You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

12 KiB

AI Tasks 与 Providers 实现

日期:2026-01-29
类型:Feature
影响范围:AI Service, Celery Tasks, AI Providers

概述

实现 Celery Tasks 异步任务处理和 AI Providers 基础框架,完成积分确认和退还流程的闭环。

变更内容

1. AI Providers 基础框架

1.1 BaseAIProvider(基类)

class BaseAIProvider(ABC):
    """AI Provider 基类"""
    
    @abstractmethod
    async def generate_image(self, prompt: str, ...) -> Dict[str, Any]
    
    @abstractmethod
    async def generate_video(self, video_type: str, ...) -> Dict[str, Any]
    
    @abstractmethod
    async def generate_sound(self, description: str, ...) -> Dict[str, Any]
    
    @abstractmethod
    async def generate_voice(self, text: str, ...) -> Dict[str, Any]
    
    @abstractmethod
    async def generate_subtitle(self, audio_url: str, ...) -> Dict[str, Any]
    
    @abstractmethod
    async def process_text(self, task_type: str, text: str, ...) -> Dict[str, Any]

设计说明

  • 定义统一的接口规范
  • 所有方法返回标准化的 Dict 结构
  • 支持异步调用

1.2 MockAIProvider(Mock 实现)

class MockAIProvider(BaseAIProvider):
    """Mock AI Provider - 返回模拟数据"""
    
    async def generate_image(self, prompt: str, ...) -> Dict[str, Any]:
        # 模拟 AI 处理时间
        await asyncio.sleep(2)
        
        return {
            'image_url': 'https://mock-storage.jointo.ai/images/mock_xxx.png',
            'thumbnail_url': 'https://mock-storage.jointo.ai/images/mock_xxx_thumb.png',
            'metadata': {...}
        }

功能

  • 模拟 AI 处理时间(1-5 秒)
  • 返回符合规范的数据结构
  • 用于开发和测试
  • 无需真实 AI API 即可测试完整流程

1.3 AIProviderFactory(工厂类)

class AIProviderFactory:
    """AI Provider 工厂类"""
    
    @staticmethod
    def create_provider(model_name: str, config: Optional[Dict] = None) -> BaseAIProvider:
        # 目前所有模型都使用 Mock Provider
        # 后续可以根据 model_name 返回不同的 Provider
        return MockAIProvider(model_name, config)

扩展性

# 未来可以这样扩展:
if model_name.startswith('gpt-'):
    return OpenAIProvider(model_name, config)
elif model_name.startswith('stable-diffusion'):
    return StabilityAIProvider(model_name, config)
elif model_name.startswith('runway'):
    return RunwayProvider(model_name, config)

2. Celery Tasks 实现

2.1 任务基类

class AITask(Task):
    """AI 任务基类"""
    
    def on_failure(self, exc, task_id, args, kwargs, einfo):
        """任务失败时的回调"""
        logger.error(f"任务失败: task_id={task_id}, error={str(exc)}")

2.2 已实现的任务

  • generate_image_task - 图片生成
  • generate_video_task - 视频生成
  • generate_sound_task - 音效生成
  • generate_voice_task - 配音生成
  • generate_subtitle_task - 字幕生成
  • process_text_task - 文本处理

2.3 任务执行流程(以图片生成为例)

@celery_app.task(base=AITask, bind=True, max_retries=3)
def generate_image_task(self, job_id: str, user_id: str, prompt: str, model: str, ...):
    async def _execute():
        try:
            # 1. 更新任务状态为处理中
            await _update_job_status(job_id, AIJobStatus.PROCESSING, progress=10)
            
            # 2. 获取任务详情(获取 consumption_log_id)
            async with async_session_maker() as session:
                job_repo = AIJobRepository(session)
                job = await job_repo.get_by_id(job_id)
                consumption_log_id = job.consumption_log_id
            
            # 3. 创建 AI Provider
            provider = AIProviderFactory.create_provider(model, config)
            
            # 4. 调用 AI 生成
            await _update_job_status(job_id, AIJobStatus.PROCESSING, progress=30)
            result = await provider.generate_image(prompt=prompt, width=width, height=height, ...)
            
            # 5. 更新任务状态为完成
            await _update_job_status(job_id, AIJobStatus.COMPLETED, progress=100, output_data=result)
            
            # 6. 确认积分消耗
            await _confirm_or_refund_credits(
                job_id=job_id,
                consumption_log_id=consumption_log_id,
                success=True
            )
            
            return result
            
        except Exception as e:
            # 更新任务状态为失败
            await _update_job_status(job_id, AIJobStatus.FAILED, error_message=str(e))
            
            # 退还积分
            await _confirm_or_refund_credits(
                job_id=job_id,
                consumption_log_id=consumption_log_id,
                success=False,
                error_message=str(e)
            )
            
            # 重试任务
            if self.request.retries < self.max_retries:
                raise self.retry(exc=e, countdown=60 * (self.request.retries + 1))
            
            raise
    
    return asyncio.run(_execute())

2.4 积分确认和退还

async def _confirm_or_refund_credits(
    job_id: str,
    consumption_log_id: Optional[str],
    success: bool,
    resource_id: Optional[str] = None,
    error_message: Optional[str] = None
):
    """确认或退还积分(异步)"""
    async with async_session_maker() as session:
        credit_service = CreditService(session)
        
        if success:
            # 任务成功,确认积分消耗
            await credit_service.confirm_consumption(
                consumption_id=UUID(consumption_log_id),
                resource_id=UUID(resource_id) if resource_id else None
            )
        else:
            # 任务失败,退还积分
            await credit_service.refund_credits(
                consumption_id=UUID(consumption_log_id),
                reason=error_message or "任务执行失败"
            )
        
        await session.commit()

3. 核心特性

3.1 异步任务处理

  • 使用 Celery 异步执行
  • 支持任务重试(最多 3 次)
  • 重试间隔递增(60s, 120s, 180s)
  • 详细的日志记录

3.2 状态管理

  • PENDING → PROCESSING → COMPLETED
  • PENDING → PROCESSING → FAILED
  • 实时更新进度(0% → 10% → 30% → 100%)
  • 记录开始时间和完成时间

3.3 积分流程闭环

  • 任务创建时:预扣积分(AIService)
  • 任务成功时:确认消耗(Celery Task)
  • 任务失败时:退还积分(Celery Task)

3.4 错误处理

  • 捕获所有异常
  • 记录详细错误信息
  • 自动重试机制
  • 失败后退还积分

完整流程示例

1. 用户发起图片生成请求

# API 层
result = await ai_service.generate_image(
    user_id="019d1234-5678-7abc-def0-111111111111",
    prompt="一只可爱的猫咪",
    width=1024,
    height=1024
)

# 返回
{
    "job_id": "019d1234-5678-7abc-def0-222222222222",
    "task_id": "abc-123-def",
    "status": "pending",
    "estimated_credits": 10
}

2. AIService 预扣积分并创建任务

# 在事务中执行
async with self.db.begin():
    # 预扣积分
    consumption_log = await credit_service.consume_credits(
        user_id=UUID(user_id),
        amount=10,
        feature_type=FeatureType.IMAGE_GENERATION
    )
    
    # 创建任务
    job = await job_repository.create({
        'status': AIJobStatus.PENDING,
        'consumption_log_id': consumption_log.consumption_id,
        ...
    })
    
    # 关联
    consumption_log.ai_job_id = job.ai_job_id

3. Celery Worker 执行任务

# 后台异步执行
@celery_app.task
def generate_image_task(job_id, ...):
    # 1. 更新状态:PENDING → PROCESSING
    await _update_job_status(job_id, AIJobStatus.PROCESSING)
    
    # 2. 调用 AI Provider
    provider = AIProviderFactory.create_provider(model)
    result = await provider.generate_image(prompt, ...)
    
    # 3. 更新状态:PROCESSING → COMPLETED
    await _update_job_status(job_id, AIJobStatus.COMPLETED, output_data=result)
    
    # 4. 确认积分消耗
    await credit_service.confirm_consumption(consumption_log_id)

4. 用户查询任务状态

status = await ai_service.get_job_status(job_id)

# 返回
{
    "job_id": "019d1234-5678-7abc-def0-222222222222",
    "status": 3,  # COMPLETED
    "progress": 100,
    "output_data": {
        "image_url": "https://mock-storage.jointo.ai/images/xxx.png",
        "thumbnail_url": "https://mock-storage.jointo.ai/images/xxx_thumb.png"
    },
    "credits_used": 10
}

技术规范遵循

1. 异步编程

  • 使用 asyncio 和 async/await
  • 异步数据库会话
  • 异步 AI Provider 调用

2. 事务管理

  • 使用 async_session_maker
  • 确保数据一致性
  • 异常时自动回滚

3. 日志记录

  • 使用统一的 logging 模块
  • 记录任务开始、进度、完成、失败
  • 记录积分确认和退还

4. 错误处理

  • 捕获所有异常
  • 详细的错误信息
  • 自动重试机制
  • 失败后清理资源

影响范围

新增文件

  • server/app/services/ai_providers/__init__.py
  • server/app/services/ai_providers/base.py
  • server/app/services/ai_providers/mock_provider.py
  • server/app/services/ai_providers/factory.py
  • server/app/tasks/ai_tasks.py
  • docs/server/changelogs/2026-01-29-ai-tasks-implementation.md

修改文件

数据库变更


待完成工作

高优先级

  1. 真实 AI Providers 实现
    • OpenAI Provider(GPT-4, DALL-E, Whisper)
    • Stability AI Provider(Stable Diffusion)
    • Runway Provider(Gen-2)
    • 添加 API 密钥管理
    • 添加速率限制

中优先级

  1. 任务监控和管理

    • 任务队列监控
    • 任务超时处理
    • 任务优先级管理
    • 任务统计和分析
  2. 单元测试和集成测试

    • 测试 Mock Provider
    • 测试 Celery Tasks
    • 测试积分确认和退还
    • 测试重试机制

使用示例

1. 测试 Mock Provider

from app.services.ai_providers import AIProviderFactory

# 创建 Provider
provider = AIProviderFactory.create_provider('mock-model')

# 生成图片
result = await provider.generate_image(
    prompt="一只可爱的猫咪",
    width=1024,
    height=1024
)

print(result)
# {
#     'image_url': 'https://mock-storage.jointo.ai/images/mock_xxx.png',
#     'thumbnail_url': 'https://mock-storage.jointo.ai/images/mock_xxx_thumb.png',
#     'metadata': {...}
# }

2. 手动触发任务

from app.tasks.ai_tasks import generate_image_task

# 触发任务
task = generate_image_task.delay(
    job_id="019d1234-5678-7abc-def0-222222222222",
    user_id="019d1234-5678-7abc-def0-111111111111",
    prompt="一只可爱的猫咪",
    model="mock-model",
    width=1024,
    height=1024
)

print(f"Task ID: {task.id}")

3. 查询任务结果

from celery.result import AsyncResult

# 查询任务状态
result = AsyncResult(task_id)

print(f"Status: {result.status}")
print(f"Result: {result.result}")

注意事项

1. Celery Worker 必须运行

# 启动 Celery Worker
docker-compose up -d jointo-server-celery-ai

2. 异步数据库会话

Tasks 中使用独立的数据库会话:

async with async_session_maker() as session:
    # 使用 session
    ...

3. 任务重试

  • 最多重试 3 次
  • 重试间隔递增:60s, 120s, 180s
  • 重试前会退还积分

4. Mock Provider 限制

  • 仅用于开发和测试
  • 返回模拟数据
  • 不调用真实 AI API

相关文档


作者

  • Kiro AI Assistant
  • 日期:2026-01-29