You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

121 KiB

AI 生成服务

文档版本:v1.0
最后更新:2025-01-27


目录

  1. 服务概述
  2. 核心功能
  3. 与 Credit Service 集成
  4. 数据库设计
  5. 服务实现
  6. API 接口
  7. AI 提供商集成

服务概述

AI 生成服务负责处理各类 AI 内容生成任务,包括图片、视频、音效、配音等,支持多种 AI 模型和服务提供商。

职责

  • extra_data图片生成(文本转图片)
  • 视频生成(文本转视频、图片转视频)
  • 音效生成
  • 配音生成(文本转语音)
  • 字幕生成(语音转文本)
  • 任务状态管理

核心功能

应用层引用完整性保证

⚠️ 重要:本服务遵循 Jointo 技术栈规范,不使用数据库外键约束,改为在应用层保证引用完整性。

三层保证机制

  1. Repository 层:提供 exists() 方法检查记录是否存在
  2. Service 层:创建/更新前验证所有关联 ID,使用事务确保原子性
  3. 后台任务:定期检查孤儿记录并告警

优势

  • 写入性能提升 15-30%
  • 为分库分表、微服务化做准备
  • 复杂业务逻辑在应用层更易实现
  • 表结构变更无需处理外键依赖

详见:数据库设计服务实现


1. 图片生成

支持的模型

  • Stable Diffusion(开源)
  • DALL-E(OpenAI)
  • Midjourney(如果可用)

功能

  • 文本描述生成图片
  • 支持风格控制
  • 支持分辨率设置
  • 异步生成,返回任务 ID

2. 视频生成

支持的类型

  • 文本转视频(text2video)
  • 图片转视频(img2video)
  • 关键帧动画(keyframe)
  • 视频融合(fusion)
  • 视频替换(replace)

支持的模型

  • Runway Gen-2
  • Pika Labs
  • Stable Video Diffusion

3. 音效生成

功能

  • 根据描述生成音效
  • 支持音效类型选择
  • 支持时长控制

4. 配音生成

功能

  • 文本转语音(TTS)
  • 支持多种语音类型
  • 支持语速、音调控制
  • 支持多语言

支持的服务

  • OpenAI TTS
  • Azure Speech Services
  • 百度语音 / 讯飞语音

5. 字幕生成

功能

  • 语音转文本(STT)
  • 自动断句
  • 分镜看板对齐

支持的服务

  • Whisper(OpenAI)
  • 阿里云语音识别

6. 文本处理

功能

  • 剧本拆解:将完整剧本拆分为场景、镜头
  • 内容分析:提取关键信息(人物、地点、时间、动作)
  • 结构化输出:生成 JSON 格式的结构化数据
  • 智能补全:根据上下文补充缺失信息
  • 风格转换:调整文本风格和语气

支持的模型

  • GPT-4 / GPT-3.5(OpenAI)
  • Claude(Anthropic)
  • 文心一言(百度)
  • 通义千问(阿里)

应用场景

  • 剧本自动拆解为分镜
  • 场景描述生成视觉提示词
  • 对话文本生成配音脚本
  • 内容审核和合规检查

7. 剧本解析(Screenplay Parsing)

核心功能

  • 自动提取剧本元素
    • 角色识别(主角、配角、群演)
    • 场景识别(地点、时间、描述)
    • 道具识别(重要性分类)
  • 变体识别
    • 角色变体(年龄段、时代、状态)
    • 场景变体(时代、季节、状态)
    • 道具变体(状态、版本)
  • 分镜拆解
    • 自动拆分为分镜脚本
    • 识别景别和运镜
    • 估算时长
  • 自动关联
    • 分镜与角色/场景/道具自动关联
    • 变体自动匹配
  • 数据持久化
    • 自动存储到数据库
    • 建立关联关系

工作流程

  1. 用户上传/创建剧本
  2. 触发 AI 解析任务
  3. 预扣积分,创建 AI 任务
  4. Celery Worker 调用 AI 模型
  5. AI 返回结构化 JSON 数据
  6. 自动存储角色/场景/道具/变体
  7. 自动创建分镜记录
  8. 自动建立关联关系
  9. 确认积分消耗
  10. 返回解析结果

详细文档:参见 AI 解析剧本工作流

API 接口

POST /api/v1/screenplays/{screenplay_id}/parse

请求体

{
  "auto_create_elements": true,
  "auto_create_tags": true,
  "auto_create_storyboards": true,
  "model": "gpt-4"
}

响应

{
  "code": 200,
  "message": "Success",
  "data": {
    "jobId": "019d1234-5678-7abc-def0-222222222222",
    "taskId": "abc123-def456-ghi789",
    "status": "pending",
    "estimatedCredits": 50
  }
}

AI 输出格式

AI 模型返回包含以下结构的 JSON 数据:

{
  "characters": [
    {
      "name": "张三",
      "description": "男主角,30岁,程序员",
      "role_type": "main",
      "meta_data": {"age": 30, "gender": "male"}
    }
  ],
  "scenes": [
    {
      "scene_number": 1,
      "title": "咖啡厅",
      "location": "市中心星巴克",
      "time_of_day": "afternoon",
      "description": "温馨的咖啡厅"
    }
  ],
  "props": [
    {
      "name": "笔记本电脑",
      "description": "张三的工作电脑",
      "category": "电子设备",
      "importance": "normal"
    }
  ],
  "character_tags": [
    {
      "character_name": "张三",
      "tag_key": "youth",
      "tag_label": "少年",
      "description": "15岁的张三"
    }
  ],
  "scene_tags": [...],
  "prop_tags": [...],
  "storyboards": [
    {
      "shot_number": "001",
      "title": "开场",
      "description": "张三坐在咖啡厅里",
      "dialogue": "又是一个平凡的下午...",
      "shot_size": "medium_shot",
      "camera_movement": "static",
      "estimated_duration": 5.5,
      "characters": ["张三"],
      "character_tags": {"张三": "adult"},
      "scenes": ["咖啡厅"],
      "props": ["笔记本电脑"]
    }
  ]
}

自动存储逻辑

  1. 存储角色:批量插入 screenplay_characters 表,返回角色 ID 映射
  2. 存储场景:批量插入 screenplay_scenes 表,返回场景 ID 映射
  3. 存储道具:批量插入 screenplay_props 表,返回道具 ID 映射
  4. 存储标签:调用 ScreenplayTagService.store_tags() 批量插入 screenplay_element_tags 表,返回标签 ID 映射
  5. 存储分镜:批量插入 storyboards 表,同时建立关联关系

标签存储详细流程

# Celery Worker 调用 Screenplay Service
result = await screenplay_service.store_parsed_elements(
    screenplay_id=screenplay_id,
    parsed_data=parsed_data
)

# Screenplay Service 内部调用 Screenplay Tag Service
from app.services.screenplay_tag_service import ScreenplayTagService
tag_service = ScreenplayTagService(db)

tag_id_maps = await tag_service.store_tags(
    screenplay_id=screenplay_id,
    parsed_data=parsed_data,
    character_id_map=character_id_map,
    scene_id_map=scene_id_map,
    prop_id_map=prop_id_map
)

# 返回的 tag_id_maps 结构
{
    'character_tags': {
        '张三-youth': UUID('019d1234-5678-7abc-def0-444444444444'),
        '张三-adult': UUID('019d1234-5678-7abc-def0-555555555555')
    },
    'scene_tags': {
        '花果山-daytime': UUID('019d1234-5678-7abc-def0-666666666666'),
        '花果山-night': UUID('019d1234-5678-7abc-def0-777777777777')
    },
    'prop_tags': {
        '金箍棒-new': UUID('019d1234-5678-7abc-def0-888888888888')
    }
}

ScreenplayTagService.store_tags() 实现

async def store_tags(
    self,
    screenplay_id: UUID,
    parsed_data: Dict[str, Any],
    character_id_map: Dict[str, UUID],
    scene_id_map: Dict[str, UUID],
    prop_id_map: Dict[str, UUID]
) -> Dict[str, Dict[str, UUID]]:
    """存储 AI 解析的标签(供 AI Service 调用)"""
    logger.info("存储 AI 解析的标签: 剧本=%s", screenplay_id)
    
    tag_id_maps = {
        'character_tags': {},
        'scene_tags': {},
        'prop_tags': {}
    }
    
    # 1. 存储角色标签
    for char_name, tags in parsed_data.get('character_tags', {}).items():
        character_id = character_id_map.get(char_name)
        if not character_id:
            continue
        
        for tag_data in tags:
            tag = await self.repository.create(ScreenplayElementTag(
                screenplay_id=screenplay_id,
                element_type=ElementType.CHARACTER,
                element_id=character_id,
                element_name=char_name,
                tag_key=tag_data['tag_key'],
                tag_label=tag_data['tag_label'],
                description=tag_data.get('description'),
                meta_data=tag_data.get('meta_data', {})
            ))
            
            map_key = f"{char_name}-{tag_data['tag_key']}"
            tag_id_maps['character_tags'][map_key] = tag.tag_id
        
        # 更新角色的 has_tags 标志
        await self._update_element_has_tags(ElementType.CHARACTER, character_id, True)
    
    # 2. 存储场景标签(逻辑类似)
    for scene_name, tags in parsed_data.get('scene_tags', {}).items():
        scene_id = scene_id_map.get(scene_name)
        if not scene_id:
            continue
        
        for tag_data in tags:
            tag = await self.repository.create(ScreenplayElementTag(
                screenplay_id=screenplay_id,
                element_type=ElementType.SCENE,
                element_id=scene_id,
                element_name=scene_name,
                tag_key=tag_data['tag_key'],
                tag_label=tag_data['tag_label'],
                description=tag_data.get('description'),
                meta_data=tag_data.get('meta_data', {})
            ))
            
            map_key = f"{scene_name}-{tag_data['tag_key']}"
            tag_id_maps['scene_tags'][map_key] = tag.tag_id
        
        await self._update_element_has_tags(ElementType.SCENE, scene_id, True)
    
    # 3. 存储道具标签(逻辑类似)
    for prop_name, tags in parsed_data.get('prop_tags', {}).items():
        prop_id = prop_id_map.get(prop_name)
        if not prop_id:
            continue
        
        for tag_data in tags:
            tag = await self.repository.create(ScreenplayElementTag(
                screenplay_id=screenplay_id,
                element_type=ElementType.PROP,
                element_id=prop_id,
                element_name=prop_name,
                tag_key=tag_data['tag_key'],
                tag_label=tag_data['tag_label'],
                description=tag_data.get('description'),
                meta_data=tag_data.get('meta_data', {})
            ))
            
            map_key = f"{prop_name}-{tag_data['tag_key']}"
            tag_id_maps['prop_tags'][map_key] = tag.tag_id
        
        await self._update_element_has_tags(ElementType.PROP, prop_id, True)
    
    logger.info(
        "标签存储完成: 角色标签=%d, 场景标签=%d, 道具标签=%d",
        len(tag_id_maps['character_tags']),
        len(tag_id_maps['scene_tags']),
        len(tag_id_maps['prop_tags'])
    )
    
    return tag_id_maps

自动关联逻辑

分镜创建时,根据 AI 返回的名称数组和标签映射,自动查找对应的数据库 ID:

# 示例:角色关联
for char_name in storyboard_data['characters']:
    character_id = character_id_map.get(char_name)
    if character_id:
        screenplay_character_ids.append(character_id)
        
        # 检查标签
        tag_key = storyboard_data['character_tags'].get(char_name)
        if tag_key:
            tag_id = tag_id_maps['character_tags'].get(f"{char_name}-{tag_key}")
            if tag_id:
                screenplay_character_tag_ids.append(tag_id)

涉及的服务

  • AI Service:调用 AI 模型进行解析
  • Screenplay Service:管理剧本和剧本元素
  • Screenplay Tag Service:管理标签(新增)
  • Storyboard Service:管理分镜
  • Credit Service:管理积分扣除

与 Credit Service 集成

集成架构

AI 服务与积分服务采用同步调用 + 事务保证的集成模式:

AI Service (依赖) → Credit Service (被依赖)

积分扣除流程

1. 任务创建时预扣积分

# AI Service 创建任务时
async def generate_image(self, user_id: UUID, ...):
    # 1. 计算所需积分
    model_config = await self._get_model(model_name, 'image')
    credits_needed = model_config.credits_per_unit
    
    # 2. 调用 Credit Service 预扣积分(同步,事务保证)
    try:
        consumption_log = await self.credit_service.consume_credits(
            user_id=user_id,
            amount=credits_needed,
            feature_type=1,  # FeatureType.IMAGE_GENERATION
            ai_job_id=None,  # 稍后更新
            task_params={...}
        )
    except InsufficientCreditsError:
        raise ValidationError("积分不足")
    
    # 3. 创建 AI 任务,关联 consumption_log
    job = await self.job_repository.create({
        'consumption_log_id': consumption_log.consumption_id,
        ...
    })
    
    # 4. 更新 consumption_log 的 ai_job_id
    consumption_log.ai_job_id = job.ai_job_id
    await self.db.commit()

2. 任务完成时确认扣除

# Celery Worker 任务完成
@celery_app.task
def generate_image_task(job_id: UUID):
    try:
        # 执行 AI 生成
        result = call_ai_provider(...)
        
        # 更新任务状态
        await ai_service.update_job(job_id, {'status': 3, ...})  # AIJobStatus.COMPLETED
        
        # 确认积分消耗
        job = await ai_service.get_job(job_id)
        await credit_service.confirm_consumption(
            consumption_id=job.consumption_log_id,
            resource_id=result.resource_id
        )
    except Exception as e:
        # 任务失败,退还积分
        await ai_service.update_job(job_id, {'status': 4, ...})  # AIJobStatus.FAILED
        await credit_service.refund_credits(
            consumption_id=job.consumption_log_id,
            reason=str(e)
        )

数据关联

-- AI Service 表关联 Credit Service(无外键约束,应用层保证引用完整性)
ALTER TABLE ai_jobs 
ADD COLUMN consumption_log_id UUID;

CREATE INDEX idx_ai_jobs_consumption_log_id ON ai_jobs (consumption_log_id) 
    WHERE consumption_log_id IS NOT NULL;

-- Credit Service 表关联 AI Service(无外键约束,应用层保证引用完整性)
ALTER TABLE credit_consumption_logs 
ADD COLUMN ai_job_id UUID;

CREATE INDEX idx_credit_consumption_logs_ai_job_id ON credit_consumption_logs (ai_job_id) 
    WHERE ai_job_id IS NOT NULL;

COMMENT ON COLUMN ai_jobs.consumption_log_id IS '积分消耗日志 ID - 应用层验证';
COMMENT ON COLUMN credit_consumption_logs.ai_job_id IS 'AI 任务 ID - 应用层验证';

事务保证

使用数据库事务确保积分扣除和任务创建的原子性:

async with self.db.begin():
    # 1. 扣除积分
    consumption_log = await credit_service.consume_credits(...)
    
    # 2. 创建任务
    job = await ai_service.create_job(...)
    
    # 3. 关联
    consumption_log.ai_job_id = job.ai_job_id
    
    # 提交事务(失败自动回滚)

数据库设计

AI 服务需要以下数据表支撑任务管理、成本控制和配额管理。

3.1 ai_jobs(AI 任务表)

核心表,记录所有 AI 生成任务的状态和结果。

-- Python 枚举定义(app/models/ai_job.py)
-- class AIJobType(IntEnum):
--     IMAGE = 1              # 图片生成
--     VIDEO = 2              # 视频生成
--     SOUND = 3              # 音效生成
--     VOICE = 4              # 配音生成
--     SUBTITLE = 5           # 字幕生成
--     TEXT_PROCESSING = 6    # 文本处理(剧本拆解等)
--     RESOURCE = 7           # 资源生成
--     STORYBOARD_SCRIPT = 8  # 分镜脚本生成
--     SCRIPT_GENERATION = 9  # 剧本生成

-- class AIJobStatus(IntEnum):
--     PENDING = 1      # 等待处理
--     PROCESSING = 2   # 处理中
--     COMPLETED = 3    # 已完成
--     FAILED = 4       # 失败
--     CANCELLED = 5    # 已取消

CREATE TABLE ai_jobs (
    ai_job_id UUID PRIMARY KEY, -- AI 任务唯一标识
    
    -- 关联信息(无外键约束,应用层保证引用完整性)
    user_id UUID NOT NULL, -- 用户 ID
    project_id UUID, -- 项目 ID(可选)
    storyboard_id UUID, -- 分镜 ID(可选)
    
    -- 积分关联(无外键约束,应用层保证引用完整性)
    consumption_log_id UUID, -- 积分消耗日志 ID
    
    -- 任务信息
    job_type SMALLINT NOT NULL, -- 任务类型(1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成)
    status SMALLINT NOT NULL DEFAULT 1, -- 任务状态(1=等待处理 2=处理中 3=已完成 4=失败 5=已取消)
    
    -- 输入输出
    input_data JSONB NOT NULL DEFAULT '{}', -- 输入参数(prompt, 配置等)
    output_data JSONB, -- 输出结果(URL, 元数据等)
    
    -- AI 模型信息(无外键约束,应用层保证引用完整性)
    model_id UUID, -- AI 模型 ID
    model_name TEXT, -- 使用的模型名称(冗余字段)
    
    -- 任务状态
    progress INTEGER NOT NULL DEFAULT 0 CHECK (progress >= 0 AND progress <= 100), -- 任务进度(0-100)
    error_message TEXT, -- 错误信息
    task_id TEXT, -- Celery 异步任务 ID
    
    -- 时间信息
    estimated_completion_at TIMESTAMPTZ, -- 预计完成时间
    started_at TIMESTAMPTZ, -- 开始处理时间
    completed_at TIMESTAMPTZ, -- 完成时间
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 创建时间
    updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 更新时间
    
    -- 成本信息
    cost NUMERIC(10, 4), -- 实际成本(元)
    credits_used INTEGER NOT NULL DEFAULT 0 -- 使用的积分
);

-- 字段注释
COMMENT ON TABLE ai_jobs IS 'AI 任务表 - 应用层保证引用完整性';
COMMENT ON COLUMN ai_jobs.ai_job_id IS 'AI 任务唯一标识';
COMMENT ON COLUMN ai_jobs.user_id IS '用户 ID - 应用层验证';
COMMENT ON COLUMN ai_jobs.project_id IS '项目 ID(可选)- 应用层验证';
COMMENT ON COLUMN ai_jobs.storyboard_id IS '分镜 ID(可选)- 应用层验证';
COMMENT ON COLUMN ai_jobs.consumption_log_id IS '积分消耗日志 ID - 应用层验证';
COMMENT ON COLUMN ai_jobs.job_type IS '任务类型(1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成)';
COMMENT ON COLUMN ai_jobs.status IS '任务状态(1=等待处理 2=处理中 3=已完成 4=失败 5=已取消)';
COMMENT ON COLUMN ai_jobs.input_data IS '输入参数(prompt, 配置等)';
COMMENT ON COLUMN ai_jobs.output_data IS '输出结果(URL, 元数据等)';
COMMENT ON COLUMN ai_jobs.model_id IS 'AI 模型 ID - 应用层验证';
COMMENT ON COLUMN ai_jobs.model_name IS '使用的模型名称(冗余字段)';
COMMENT ON COLUMN ai_jobs.progress IS '任务进度(0-100)';
COMMENT ON COLUMN ai_jobs.error_message IS '错误信息';
COMMENT ON COLUMN ai_jobs.task_id IS 'Celery 异步任务 ID';
COMMENT ON COLUMN ai_jobs.estimated_completion_at IS '预计完成时间';
COMMENT ON COLUMN ai_jobs.started_at IS '开始处理时间';
COMMENT ON COLUMN ai_jobs.completed_at IS '完成时间';
COMMENT ON COLUMN ai_jobs.created_at IS '创建时间';
COMMENT ON COLUMN ai_jobs.updated_at IS '更新时间';
COMMENT ON COLUMN ai_jobs.cost IS '实际成本(元)';
COMMENT ON COLUMN ai_jobs.credits_used IS '使用的积分';

-- 索引
CREATE INDEX idx_ai_jobs_user_id ON ai_jobs (user_id);
CREATE INDEX idx_ai_jobs_project_id ON ai_jobs (project_id) WHERE project_id IS NOT NULL;
CREATE INDEX idx_ai_jobs_storyboard_id ON ai_jobs (storyboard_id) WHERE storyboard_id IS NOT NULL;
CREATE INDEX idx_ai_jobs_type ON ai_jobs (job_type);
CREATE INDEX idx_ai_jobs_status ON ai_jobs (status);
CREATE INDEX idx_ai_jobs_model_id ON ai_jobs (model_id) WHERE model_id IS NOT NULL;
CREATE INDEX idx_ai_jobs_consumption_log_id ON ai_jobs (consumption_log_id) WHERE consumption_log_id IS NOT NULL;
CREATE INDEX idx_ai_jobs_created_at ON ai_jobs (created_at);
CREATE INDEX idx_ai_jobs_status_created_at ON ai_jobs (status, created_at) 
    WHERE status IN (1, 2);  -- PENDING, PROCESSING
CREATE INDEX idx_ai_jobs_input_data_gin ON ai_jobs USING GIN (input_data);
CREATE INDEX idx_ai_jobs_output_data_gin ON ai_jobs USING GIN (output_data) 
    WHERE output_data IS NOT NULL;

-- 触发器
CREATE TRIGGER update_ai_jobs_updated_at
    BEFORE UPDATE ON ai_jobs
    FOR EACH ROW
    EXECUTE FUNCTION update_updated_at_column();

字段说明

  • input_data:存储任务输入参数

    {
      "prompt": "一只可爱的猫咪",
      "width": 1024,
      "height": 1024,
      "style": "realistic",
      "temperature": 0.7
    }
    
  • output_data:存储任务输出结果

    {
      "file_url": "https://storage.jointo.ai/ai-generated/images/abc123def456.png",
      "file_size": 1024000,
      "checksum": "abc123def456...",
      "mime_type": "image/png",
      "original_url": "https://ai-provider.com/temp/xyz789.png",
      "meta_data": {
        "width": 1024,
        "height": 1024,
        "format": "png"
      }
    }
    

    字段说明

    • file_url: 自有 OSS 存储的文件 URL(永久有效)
    • file_size: 文件大小(字节)
    • checksum: 文件 SHA256 校验和(用于去重)
    • mime_type: 文件 MIME 类型
    • original_url: AI 提供商返回的原始 URL(可选,用于调试)
    • meta_data: 文件元数据(宽度、高度、格式等)

3.2 ai_models(AI 模型配置表)

推荐表,管理可用的 AI 模型及其定价配置。

-- Python 枚举定义(app/models/ai_model.py)
-- class AIModelType(IntEnum):
--     TEXT = 1    # 文本模型(GPT, Claude 等)
--     IMAGE = 2   # 图片模型(DALL-E, Stable Diffusion 等)
--     VIDEO = 3   # 视频模型(Runway, Pika 等)
--     AUDIO = 4   # 音频模型(TTS, STT 等)

-- class AIProvider(IntEnum):
--     OPENAI = 1
--     ANTHROPIC = 2
--     GOOGLE = 3       # Google Gemini
--     STABILITY = 4
--     RUNWAY = 5
--     PIKA = 6
--     ELEVENLABS = 7
--     AZURE = 8
--     BAIDU = 9
--     ALIYUN = 10
--     CUSTOM = 99      # 自定义提供商

-- class UnitType(IntEnum):
--     TOKEN = 1    # Token(文本模型)
--     IMAGE = 2    # 图片(图片模型)
--     SECOND = 3   # 秒(视频/音频模型)
--     REQUEST = 4  # 请求(通用计费单位)

CREATE TABLE ai_models (
    model_id UUID PRIMARY KEY, -- AI 模型唯一标识
    
    -- 模型信息
    model_name TEXT NOT NULL UNIQUE, -- 模型名称(如:gpt-4, dall-e-3, runway-gen2)
    display_name TEXT NOT NULL, -- 显示名称
    description TEXT, -- 模型描述
    
    -- 分类
    provider SMALLINT NOT NULL, -- 提供商(1=OpenAI 2=Anthropic 3=Google 4=Stability 5=Runway 6=Pika 7=ElevenLabs 8=Azure 9=Baidu 10=Aliyun 99=自定义)
    model_type SMALLINT NOT NULL, -- 模型类型(1=文本 2=图片 3=视频 4=音频)
    
    -- 定价(按使用量计费)
    cost_per_unit NUMERIC(10, 4) NOT NULL, -- 单位成本(元)
    unit_type SMALLINT NOT NULL, -- 单位类型(1=Token 2=图片 3=秒 4=请求)
    credits_per_unit INTEGER NOT NULL, -- 每单位消耗积分
    
    -- 配置
    config JSONB NOT NULL DEFAULT '{}', -- 模型特定配置
    
    -- 限制
    rate_limit INTEGER, -- 速率限制(每分钟请求数)
    daily_quota INTEGER, -- 每日配额
    
    -- 状态
    is_active BOOLEAN NOT NULL DEFAULT true, -- 是否启用
    is_beta BOOLEAN NOT NULL DEFAULT false, -- 是否为测试版
    
    -- 审计
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 创建时间
    updated_at TIMESTAMPTZ NOT NULL DEFAULT now() -- 更新时间
);

-- 字段注释
COMMENT ON COLUMN ai_models.model_id IS 'AI 模型唯一标识';
COMMENT ON COLUMN ai_models.model_name IS '模型名称(如:gpt-4, dall-e-3, runway-gen2)';
COMMENT ON COLUMN ai_models.display_name IS '显示名称';
COMMENT ON COLUMN ai_models.description IS '模型描述';
COMMENT ON COLUMN ai_models.provider IS '提供商(1=OpenAI 2=Anthropic 3=Google 4=Stability 5=Runway 6=Pika 7=ElevenLabs 8=Azure 9=Baidu 10=Aliyun 99=自定义)';
COMMENT ON COLUMN ai_models.model_type IS '模型类型(1=文本 2=图片 3=视频 4=音频)';
COMMENT ON COLUMN ai_models.cost_per_unit IS '单位成本(元)';
COMMENT ON COLUMN ai_models.unit_type IS '单位类型(1=Token 2=图片 3=秒 4=请求)';
COMMENT ON COLUMN ai_models.credits_per_unit IS '每单位消耗积分';
COMMENT ON COLUMN ai_models.config IS '模型特定配置';
COMMENT ON COLUMN ai_models.rate_limit IS '速率限制(每分钟请求数)';
COMMENT ON COLUMN ai_models.daily_quota IS '每日配额';
COMMENT ON COLUMN ai_models.is_active IS '是否启用';
COMMENT ON COLUMN ai_models.is_beta IS '是否为测试版';
COMMENT ON COLUMN ai_models.created_at IS '创建时间';
COMMENT ON COLUMN ai_models.updated_at IS '更新时间';

-- 索引
CREATE INDEX idx_ai_models_provider ON ai_models (provider) WHERE is_active = true;
CREATE INDEX idx_ai_models_type ON ai_models (model_type) WHERE is_active = true;
CREATE INDEX idx_ai_models_is_active ON ai_models (is_active);
CREATE INDEX idx_ai_models_config_gin ON ai_models USING GIN (config);

-- 触发器
CREATE TRIGGER update_ai_models_updated_at
    BEFORE UPDATE ON ai_models
    FOR EACH ROW
    EXECUTE FUNCTION update_updated_at_column();

设计说明

  • 动态配置模型,无需修改代码
  • 支持多种计费单位(token、图片、秒、请求)
  • 使用 JSONB 存储模型特定配置
  • 支持速率限制和配额管理
  • 新增提供商
    • google:支持 Gemini Pro 和 Gemini 1.5 Pro
    • custom:支持自定义 AI 提供商(用户自建模型、第三方 API 等)

Custom 提供商使用场景

  • 用户自建的 AI 模型服务
  • 第三方 AI API 集成
  • 企业内部 AI 服务
  • 测试和开发环境

Custom 提供商配置示例

{
  "api_endpoint": "https://custom-ai.example.com/v1/generate",
  "auth_type": "bearer",
  "headers": {
    "Authorization": "Bearer ${API_KEY}"
  },
  "request_format": "json",
  "response_format": "json"
}

3.3 ai_usage_logs(AI 使用日志表)

计费必需表,记录每次 AI 调用的详细使用情况。

CREATE TABLE ai_usage_logs (
    usage_log_id UUID PRIMARY KEY, -- 使用日志唯一标识
    
    -- 关联信息(无外键约束,应用层保证引用完整性)
    user_id UUID NOT NULL, -- 用户 ID
    ai_job_id UUID NOT NULL, -- AI 任务 ID
    model_id UUID, -- AI 模型 ID
    
    -- 使用量
    units_used NUMERIC(10, 2) NOT NULL, -- 使用的单位数(tokens, 秒数等)
    unit_type SMALLINT NOT NULL CHECK (unit_type BETWEEN 1 AND 4), -- 单位类型(1=Token 2=图片 3=秒 4=请求)
    
    -- 成本
    cost NUMERIC(10, 4) NOT NULL, -- 实际成本(元)
    credits_used INTEGER NOT NULL, -- 消耗的积分
    
    -- 元数据
    meta_data JSONB NOT NULL DEFAULT '{}', -- 使用元数据
    
    -- 时间
    created_at TIMESTAMPTZ NOT NULL DEFAULT now() -- 创建时间
);

-- 字段注释
COMMENT ON TABLE ai_usage_logs IS 'AI 使用日志表 - 应用层保证引用完整性';
COMMENT ON COLUMN ai_usage_logs.usage_log_id IS '使用日志唯一标识';
COMMENT ON COLUMN ai_usage_logs.user_id IS '用户 ID - 应用层验证';
COMMENT ON COLUMN ai_usage_logs.ai_job_id IS 'AI 任务 ID - 应用层验证';
COMMENT ON COLUMN ai_usage_logs.model_id IS 'AI 模型 ID - 应用层验证';
COMMENT ON COLUMN ai_usage_logs.units_used IS '使用的单位数(tokens, 秒数等)';
COMMENT ON COLUMN ai_usage_logs.unit_type IS '单位类型(1=Token 2=图片 3=秒 4=请求)';
COMMENT ON COLUMN ai_usage_logs.cost IS '实际成本(元)';
COMMENT ON COLUMN ai_usage_logs.credits_used IS '消耗的积分';
COMMENT ON COLUMN ai_usage_logs.meta_data IS '使用元数据';
COMMENT ON COLUMN ai_usage_logs.created_at IS '创建时间';

-- 索引
CREATE INDEX idx_ai_usage_logs_user_id ON ai_usage_logs (user_id);
CREATE INDEX idx_ai_usage_logs_ai_job_id ON ai_usage_logs (ai_job_id);
CREATE INDEX idx_ai_usage_logs_model_id ON ai_usage_logs (model_id) WHERE model_id IS NOT NULL;
CREATE INDEX idx_ai_usage_logs_created_at ON ai_usage_logs (created_at);
CREATE INDEX idx_ai_usage_logs_user_created ON ai_usage_logs (user_id, created_at);
CREATE INDEX idx_ai_usage_logs_meta_data_gin ON ai_usage_logs USING GIN (meta_data);

-- 分区策略(按月分区,便于归档)
-- CREATE TABLE ai_usage_logs_2026_01 PARTITION OF ai_usage_logs
--     FOR VALUES FROM ('2026-01-01') TO ('2026-02-01');

设计说明

  • 每次 AI 调用都记录详细日志
  • 支持成本分析和用户账单生成
  • 建议按月分区,便于历史数据归档
  • 使用 JSONB 存储详细的使用元数据

3.4 ai_quotas(AI 配额表)

推荐表,管理用户的 AI 使用配额和限流。

-- Python 枚举定义(app/models/ai_quota.py)
-- class QuotaPeriod(IntEnum):
--     DAILY = 1    # 每日配额
--     MONTHLY = 2  # 每月配额
--     TOTAL = 3    # 总配额

CREATE TABLE ai_quotas (
    quota_id UUID PRIMARY KEY, -- 配额唯一标识
    
    -- 用户信息(无外键约束,应用层保证引用完整性)
    user_id UUID NOT NULL, -- 用户 ID
    
    -- 配额类型
    quota_type TEXT NOT NULL, -- 配额类型(如:image_generation, video_generation, text_processing)
    period SMALLINT NOT NULL CHECK (period BETWEEN 1 AND 3), -- 配额周期(1=每日 2=每月 3=总计)
    
    -- 配额限制
    total_quota INTEGER NOT NULL, -- 总配额
    used_quota INTEGER NOT NULL DEFAULT 0, -- 已使用配额
    
    -- 重置时间
    reset_at TIMESTAMPTZ NOT NULL, -- 重置时间
    
    -- 审计
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 创建时间
    updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 更新时间
    
    -- 唯一约束
    CONSTRAINT ai_quotas_unique UNIQUE (user_id, quota_type, period) NULLS NOT DISTINCT,
    
    -- 检查约束
    CONSTRAINT ai_quotas_used_check CHECK (used_quota >= 0 AND used_quota <= total_quota)
);

-- 字段注释
COMMENT ON TABLE ai_quotas IS 'AI 配额表 - 应用层保证引用完整性';
COMMENT ON COLUMN ai_quotas.quota_id IS '配额唯一标识';
COMMENT ON COLUMN ai_quotas.user_id IS '用户 ID - 应用层验证';
COMMENT ON COLUMN ai_quotas.quota_type IS '配额类型(如:image_generation, video_generation, text_processing)';
COMMENT ON COLUMN ai_quotas.period IS '配额周期(1=每日 2=每月 3=总计)';
COMMENT ON COLUMN ai_quotas.total_quota IS '总配额';
COMMENT ON COLUMN ai_quotas.used_quota IS '已使用配额';
COMMENT ON COLUMN ai_quotas.reset_at IS '重置时间';
COMMENT ON COLUMN ai_quotas.created_at IS '创建时间';
COMMENT ON COLUMN ai_quotas.updated_at IS '更新时间';

-- 索引
CREATE INDEX idx_ai_quotas_user_id ON ai_quotas (user_id);
CREATE INDEX idx_ai_quotas_type ON ai_quotas (quota_type);
CREATE INDEX idx_ai_quotas_reset_at ON ai_quotas (reset_at);
CREATE INDEX idx_ai_quotas_user_type ON ai_quotas (user_id, quota_type);

-- 触发器
CREATE TRIGGER update_ai_quotas_updated_at
    BEFORE UPDATE ON ai_quotas
    FOR EACH ROW
    EXECUTE FUNCTION update_updated_at_column();

-- 自动重置配额的函数
CREATE OR REPLACE FUNCTION reset_expired_quotas()
RETURNS void AS $$
BEGIN
    -- 重置每日配额 (period = 1)
    UPDATE ai_quotas
    SET used_quota = 0,
        reset_at = reset_at + INTERVAL '1 day',
        updated_at = now()
    WHERE period = 1 AND reset_at <= now();
    
    -- 重置每月配额 (period = 2)
    UPDATE ai_quotas
    SET used_quota = 0,
        reset_at = reset_at + INTERVAL '1 month',
        updated_at = now()
    WHERE period = 2 AND reset_at <= now();
END;
$$ LANGUAGE plpgsql;

-- 定时任务(需要配合 pg_cron 或应用层定时任务)
-- SELECT cron.schedule('reset-ai-quotas', '0 0 * * *', 'SELECT reset_expired_quotas()');

设计说明

  • 支持每日、每月、总配额三种类型
  • 自动重置过期配额
  • 防止配额超限(CHECK 约束)
  • 支持不同类型的配额管理

3.5 数据表关系图

┌─────────────┐
│   users     │
└──────┬──────┘
       │
       ├──────────────────────────────────┐
       │                                  │
       ▼                                  ▼
┌─────────────┐                    ┌─────────────┐
│  ai_jobs    │───────────────────>│ ai_models   │
└──────┬──────┘                    └─────────────┘
       │                                  ▲
       │                                  │
       ▼                                  │
┌─────────────┐                          │
│ai_usage_logs│──────────────────────────┘
└─────────────┘

┌─────────────┐
│ ai_quotas   │
└─────────────┘
       ▲
       │
       └──────── users

Pydantic Schema 定义

请求 Schema

# app/schemas/ai_job.py
from pydantic import BaseModel, Field, field_validator
from uuid import UUID
from typing import Optional, Dict, Any
from enum import IntEnum

class AIJobType(IntEnum):
    """AI 任务类型"""
    IMAGE = 1
    VIDEO = 2
    SOUND = 3
    VOICE = 4
    SUBTITLE = 5
    TEXT_PROCESSING = 6
    RESOURCE = 7
    STORYBOARD_SCRIPT = 8
    SCRIPT_GENERATION = 9

class VideoType(str):
    """视频生成类型"""
    TEXT2VIDEO = "text2video"
    IMG2VIDEO = "img2video"
    KEYFRAME = "keyframe"
    FUSION = "fusion"
    REPLACE = "replace"

class TextProcessingTaskType(str):
    """文本处理任务类型"""
    SCREENPLAY_PARSE = "screenplay_parse"
    CONTENT_ANALYSIS = "content_analysis"
    STYLE_TRANSFORM = "style_transform"
    PROMPT_GENERATION = "prompt_generation"

# ==================== 图片生成 ====================

class ImageGenerationRequest(BaseModel):
    """图片生成请求"""
    prompt: str = Field(..., description="提示词", min_length=1, max_length=2000)
    model: Optional[str] = Field(None, description="模型名称")
    width: int = Field(1024, description="宽度", ge=256, le=2048)
    height: int = Field(1024, description="高度", ge=256, le=2048)
    style: Optional[str] = Field(None, description="风格")
    
    class Config:
        json_schema_extra = {
            "example": {
                "prompt": "一只可爱的猫咪在花园里玩耍",
                "model": "stable_diffusion",
                "width": 1024,
                "height": 1024,
                "style": "realistic"
            }
        }

# ==================== 视频生成 ====================

class VideoGenerationRequest(BaseModel):
    """视频生成请求"""
    video_type: str = Field(..., description="视频类型: text2video, img2video, keyframe, fusion, replace")
    prompt: Optional[str] = Field(None, description="提示词(text2video 必需)")
    image_url: Optional[str] = Field(None, description="图片 URL(img2video 必需)")
    duration: int = Field(5, description="时长(秒)", ge=1, le=60)
    fps: int = Field(30, description="帧率", ge=15, le=60)
    model: Optional[str] = Field(None, description="模型名称")
    
    @field_validator('video_type')
    @classmethod
    def validate_video_type(cls, v):
        valid_types = ['text2video', 'img2video', 'keyframe', 'fusion', 'replace']
        if v not in valid_types:
            raise ValueError(f"video_type 必须是以下之一: {', '.join(valid_types)}")
        return v
    
    @field_validator('prompt')
    @classmethod
    def validate_prompt(cls, v, info):
        if info.data.get('video_type') == 'text2video' and not v:
            raise ValueError("text2video 类型必须提供 prompt")
        return v
    
    @field_validator('image_url')
    @classmethod
    def validate_image_url(cls, v, info):
        if info.data.get('video_type') == 'img2video' and not v:
            raise ValueError("img2video 类型必须提供 image_url")
        return v

# ==================== 音效生成 ====================

class SoundGenerationRequest(BaseModel):
    """音效生成请求"""
    description: str = Field(..., description="音效描述", min_length=1, max_length=500)
    duration: int = Field(5, description="时长(秒)", ge=1, le=30)
    sound_type: Optional[str] = Field(None, description="音效类型")
    model: Optional[str] = Field(None, description="模型名称")

# ==================== 配音生成 ====================

class VoiceGenerationRequest(BaseModel):
    """配音生成请求"""
    text: str = Field(..., description="文本内容", min_length=1, max_length=5000)
    voice_type: str = Field("alloy", description="语音类型")
    speed: float = Field(1.0, description="语速", ge=0.5, le=2.0)
    language: str = Field("zh-CN", description="语言")
    model: Optional[str] = Field(None, description="模型名称")

# ==================== 字幕生成 ====================

class SubtitleGenerationRequest(BaseModel):
    """字幕生成请求"""
    audio_url: str = Field(..., description="音频 URL")
    language: str = Field("zh", description="语言")
    model: Optional[str] = Field(None, description="模型名称")

# ==================== 文本处理 ====================

class TextProcessingRequest(BaseModel):
    """文本处理请求"""
    task_type: str = Field(..., description="任务类型: screenplay_parse, content_analysis, style_transform, prompt_generation")
    text: str = Field(..., description="待处理文本", min_length=1, max_length=50000)
    model: Optional[str] = Field(None, description="模型名称")
    output_format: str = Field("json", description="输出格式: json, text")
    temperature: float = Field(0.7, description="生成温度", ge=0.0, le=2.0)
    max_tokens: int = Field(4000, description="最大 token 数", ge=100, le=8000)
    
    @field_validator('task_type')
    @classmethod
    def validate_task_type(cls, v):
        valid_types = ['screenplay_parse', 'content_analysis', 'style_transform', 'prompt_generation']
        if v not in valid_types:
            raise ValueError(f"task_type 必须是以下之一: {', '.join(valid_types)}")
        return v

# ==================== 统一创建请求 ====================

class AIJobCreateRequest(BaseModel):
    """AI 任务创建请求(统一入口)"""
    job_type: str = Field(..., description="任务类型: image, video, sound, voice, subtitle, textProcessing")
    params: Dict[str, Any] = Field(..., description="任务参数")
    
    @field_validator('job_type')
    @classmethod
    def validate_job_type(cls, v):
        valid_types = ['image', 'video', 'sound', 'voice', 'subtitle', 'textProcessing']
        if v not in valid_types:
            raise ValueError(f"job_type 必须是以下之一: {', '.join(valid_types)}")
        return v

响应 Schema

# app/schemas/ai_job.py (续)
from datetime import datetime

class AIJobResponse(BaseModel):
    """AI 任务响应"""
    job_id: UUID = Field(..., description="任务 ID")
    task_id: str = Field(..., description="Celery 任务 ID")
    status: str = Field(..., description="任务状态")
    estimated_credits: int = Field(..., description="预估积分消耗")
    
    class Config:
        json_schema_extra = {
            "example": {
                "job_id": "019d1234-5678-7abc-def0-111111111111",
                "task_id": "abc-123-def",
                "status": "pending",
                "estimated_credits": 10
            }
        }

class AIJobStatusResponse(BaseModel):
    """AI 任务状态响应"""
    job_id: UUID
    job_type: int
    status: int
    progress: int
    input_data: Dict[str, Any]
    output_data: Optional[Dict[str, Any]]
    error_message: Optional[str]
    model_name: Optional[str]
    cost: Optional[float]
    credits_used: int
    created_at: datetime
    updated_at: datetime
    started_at: Optional[datetime]
    completed_at: Optional[datetime]

class AIJobListResponse(BaseModel):
    """AI 任务列表响应"""
    items: list[AIJobStatusResponse]
    total: int
    page: int
    page_size: int
    total_pages: int

class AIModelResponse(BaseModel):
    """AI 模型响应"""
    model_id: UUID
    model_name: str
    display_name: str
    description: Optional[str]
    provider: int
    model_type: int
    cost_per_unit: float
    unit_type: int
    credits_per_unit: int
    is_beta: bool
    config: Dict[str, Any]

class AIUsageStatsResponse(BaseModel):
    """AI 使用统计响应"""
    total_cost: float
    total_credits_used: int
    total_requests: int
    by_model: Dict[str, Dict[str, Any]]
    by_type: Dict[str, Dict[str, Any]]
    quotas: list[Dict[str, Any]]

服务实现

AIService 类

# app/services/ai_service.py
from typing import Dict, Any, Optional
from uuid import UUID
from app.tasks.ai_tasks import (
    generate_image_task,
    generate_video_task,
    generate_sound_task,
    generate_voice_task,
    generate_subtitle_task,
    process_text_task
)
from app.repositories.ai_job_repository import AIJobRepository
from app.repositories.ai_model_repository import AIModelRepository
from app.repositories.ai_usage_log_repository import AIUsageLogRepository
from app.repositories.ai_quota_repository import AIQuotaRepository
from app.services.credit_service import CreditService
from app.core.exceptions import InsufficientCreditsError, ValidationError
from sqlalchemy.orm import Session

class AIService:
    def __init__(self, db: Session):
        self.db = db
        self.job_repository = AIJobRepository(db)
        self.model_repository = AIModelRepository(db)
        self.usage_log_repository = AIUsageLogRepository(db)
        self.quota_repository = AIQuotaRepository(db)
        self.credit_service = CreditService(db)  # 注入 Credit Service

    # ==================== 枚举验证方法 ====================

    def _validate_job_type(self, job_type: int) -> None:
        """验证任务类型枚举值"""
        from app.models.ai_job import AIJobType
        if job_type not in AIJobType.to_dict().keys():
            raise ValidationError(f"无效的任务类型: {job_type}")

    def _validate_job_status(self, job_status: int) -> None:
        """验证任务状态枚举值"""
        from app.models.ai_job import AIJobStatus
        if job_status not in AIJobStatus.to_dict().keys():
            raise ValidationError(f"无效的任务状态: {job_status}")

    def _validate_provider(self, provider: int) -> None:
        """验证提供商枚举值"""
        from app.models.ai_model import AIModelProvider
        if provider not in AIModelProvider.to_dict().keys():
            raise ValidationError(f"无效的提供商: {provider}")

    def _validate_model_type(self, model_type: int) -> None:
        """验证模型类型枚举值"""
        from app.models.ai_model import AIModelType
        if model_type not in AIModelType.to_dict().keys():
            raise ValidationError(f"无效的模型类型: {model_type}")

    def _validate_unit_type(self, unit_type: int) -> None:
        """验证单位类型枚举值"""
        from app.models.ai_model import AIModelUnitType
        if unit_type not in AIModelUnitType.to_dict().keys():
            raise ValidationError(f"无效的单位类型: {unit_type}")

    # ==================== 私有方法 ====================

    async def _check_quota(self, user_id: UUID, quota_type: str) -> bool:
        """检查用户配额是否充足"""
        quota = await self.quota_repository.get_user_quota(user_id, quota_type)
        if not quota:
            return True  # 没有配额限制
        
        if quota.used_quota >= quota.total_quota:
            from app.core.exceptions import QuotaExceededError
            raise QuotaExceededError(f"配额不足:{quota_type}")
        
        return True

    def _get_feature_type(self, job_type: int) -> int:
        """将 AIJobType 映射到 Credit Service 的 feature_type
        
        Args:
            job_type: AIJobType 枚举值
            
        Returns:
            Credit Service 的 feature_type 枚举值(SMALLINT)
        """
        from app.models.ai_job import AIJobType
        
        # 直接返回 job_type,因为 AI Service 和 Credit Service 使用相同的枚举值
        # AIJobType: 1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成
        # FeatureType: 1=图片生成 2=视频生成 3=文本处理 4=音频生成 5=音效生成 6=配音生成 7=字幕生成 8=资源生成 9=分镜脚本生成 10=剧本生成
        
        # 需要特殊处理的映射:
        # AIJobType.TEXT_PROCESSING (6) -> FeatureType.TEXT_PROCESSING (3)
        # AIJobType.RESOURCE (7) -> FeatureType.RESOURCE_GENERATION (8)
        # AIJobType.STORYBOARD_SCRIPT (8) -> FeatureType.STORYBOARD_GENERATION (9)
        # AIJobType.SCRIPT_GENERATION (9) -> FeatureType.SCRIPT_GENERATION (10)
        
        feature_type_map = {
            AIJobType.IMAGE: 1,           # 图片生成
            AIJobType.VIDEO: 2,           # 视频生成
            AIJobType.TEXT_PROCESSING: 3,   # 文本处理
            AIJobType.SOUND: 5,            # 音效生成
            AIJobType.VOICE: 6,            # 配音生成
            AIJobType.SUBTITLE: 7,         # 字幕生成
            AIJobType.RESOURCE: 8,         # 资源生成
            AIJobType.STORYBOARD_SCRIPT: 9, # 分镜脚本生成
            AIJobType.SCRIPT_GENERATION: 10 # 剧本生成
        }
        
        return feature_type_map.get(job_type, 1)  # 默认返回图片生成

    async def _get_model(self, model_name: Optional[str], model_type: int) -> Dict[str, Any]:
        """获取模型配置
        
        Args:
            model_name: 模型名称
            model_type: 模型类型(AIModelType 枚举值)
        """
        # 验证模型类型枚举值
        self._validate_model_type(model_type)

        if model_name:
            model = await self.model_repository.get_by_name(model_name)
        else:
            # 获取默认模型
            model = await self.model_repository.get_default_model(model_type)
        
        if not model or not model.is_active:
            from app.core.exceptions import NotFoundError
            raise NotFoundError(f"模型不可用: {model_name}")
        
        return model

    async def _record_usage(
        self,
        user_id: UUID,
        job_id: UUID,
        model_id: UUID,
        units_used: float,
        unit_type: int,  # UnitType 枚举值
        cost: float,
        credits_used: int,
        meta_data: Dict[str, Any]
    ) -> None:
        """记录使用日志"""
        # 验证单位类型枚举值
        self._validate_unit_type(unit_type)

        await self.usage_log_repository.create({
            'user_id': user_id,
            'ai_job_id': job_id,
            'model_id': model_id,
            'units_used': units_used,
            'unit_type': unit_type,
            'cost': cost,
            'credits_used': credits_used,
            'meta_data': meta_data
        })

    async def generate_image(
        self,
        user_id: UUID,
        prompt: str,
        model: Optional[str] = None,
        width: int = 1024,
        height: int = 1024,
        style: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """生成图片(异步)
        
        应用层引用完整性保证:
        1. 验证 user_id 是否存在
        2. 验证 model_id 是否存在
        3. 使用事务确保原子性
        """
        # 1. 验证用户是否存在(应用层引用完整性保证)
        from app.repositories.user_repository import UserRepository
        user_repo = UserRepository(self.db)
        if not await user_repo.exists(user_id):
            raise ValidationError("用户不存在")
        
        # 2. 检查配额
        await self._check_quota(user_id, 'image_generation')
        
        # 3. 获取模型配置(AIModelType.IMAGE = 2)
        from app.models.ai_model import AIModelType
        model_config = await self._get_model(model, AIModelType.IMAGE)
        
        # 4. 验证模型是否存在(应用层引用完整性保证)
        if not await self.model_repository.exists(model_config.model_id):
            raise ValidationError("AI 模型不存在")
        
        # 5. 计算所需积分
        credits_needed = model_config.credits_per_unit
        
        # 6. 使用事务确保原子性
        async with self.db.begin():
            # 预扣积分(同步,事务保证)
            try:
                consumption_log = await self.credit_service.consume_credits(
                    user_id=user_id,
                    amount=credits_needed,
                    feature_type=1,  # FeatureType.IMAGE_GENERATION
                    ai_job_id=None,  # 稍后更新
                    task_params={
                        'prompt': prompt,
                        'model': model_config.model_name,
                        'width': width,
                        'height': height,
                        'style': style,
                        **kwargs
                    }
                )
            except InsufficientCreditsError as e:
                raise ValidationError(f"积分不足: {str(e)}")
            
            # 创建任务记录
            from app.models.ai_job import AIJobType, AIJobStatus
            job = await self.job_repository.create({
                'job_type': AIJobType.IMAGE,
                'status': AIJobStatus.PENDING,
                'user_id': user_id,
                'model_id': model_config.model_id,
                'model_name': model_config.model_name,
                'consumption_log_id': consumption_log.consumption_id,
                'input_data': {
                    'prompt': prompt,
                    'width': width,
                    'height': height,
                    'style': style,
                    **kwargs
                }
            })

            # 更新 consumption_log 的 ai_job_id
            consumption_log.ai_job_id = job.ai_job_id
            consumption_log.task_id = str(job.ai_job_id)
            await self.db.commit()

        # 7. 提交异步任务
        task = generate_image_task.delay(
            job_id=job.ai_job_id,
            prompt=prompt,
            model=model_config.model_name,
            width=width,
            height=height,
            style=style,
            **kwargs
        )

        # 8. 更新任务 ID
        await self.job_repository.update(job.ai_job_id, {'task_id': task.id})

        return {
            'job_id': str(job.ai_job_id),
            'task_id': task.id,
            'status': 'pending',
            'estimated_cost': float(model_config.cost_per_unit),
            'estimated_credits': credits_needed
        }

    async def generate_video(
        self,
        user_id: UUID,
        video_type: str,
        prompt: Optional[str] = None,
        image_url: Optional[str] = None,
        duration: int = 5,
        fps: int = 30,
        model: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """生成视频(异步)"""
        # 验证参数
        if video_type == 'text2video' and not prompt:
            raise ValidationError("文本转视频需要提供 prompt")
        if video_type == 'img2video' and not image_url:
            raise ValidationError("图片转视频需要提供 image_url")

        # 检查配额
        await self._check_quota(user_id, 'video_generation')

        # 获取模型配置(AIModelType.VIDEO = 3)
        from app.models.ai_model import AIModelType
        model_config = await self._get_model(model, AIModelType.VIDEO)

        # 计算所需积分
        feature_type = self._get_feature_type(AIJobType.VIDEO)
        credits_needed = await self.credit_service.calculate_credits(
            feature_type=feature_type,
            params={
                'model_name': model_config.model_name,
                'video_type': video_type,
                'duration': duration,
                'quality': kwargs.get('quality', 'sd')
            }
        )

        # 预扣积分(同步,事务保证)
        try:
            consumption_log = await self.credit_service.consume_credits(
                user_id=user_id,
                amount=credits_needed,
                feature_type=feature_type,
                ai_job_id=None,
                task_params={
                    'video_type': video_type,
                    'model': model_config.model_name,
                    'duration': duration,
                    'fps': fps,
                    **kwargs
                }
            )
        except InsufficientCreditsError as e:
            raise ValidationError(f"积分不足: {str(e)}")

        # 创建任务记录
        from app.models.ai_job import AIJobType, AIJobStatus
        job = await self.job_repository.create({
            'job_type': AIJobType.VIDEO,
            'status': AIJobStatus.PENDING,
            'user_id': user_id,
            'model_id': model_config.model_id,
            'model_name': model_config.model_name,
            'consumption_log_id': consumption_log.consumption_id,
            'credits_used': credits_needed,
            'input_data': {
                'video_type': video_type,
                'prompt': prompt,
                'image_url': image_url,
                'duration': duration,
                'fps': fps,
                **kwargs
            }
        })

        # 更新 consumption_log 的 ai_job_id
        consumption_log.ai_job_id = job.ai_job_id
        consumption_log.task_id = str(job.ai_job_id)
        await self.db.commit()

        # 提交异步任务
        task = generate_video_task.delay(
            job_id=job.ai_job_id,
            video_type=video_type,
            prompt=prompt,
            image_url=image_url,
            duration=duration,
            fps=fps,
            **kwargs
        )

        # 更新任务 ID
        await self.job_repository.update(job.ai_job_id, {'task_id': task.id})

        return {
            'job_id': str(job.ai_job_id),
            'task_id': task.id,
            'status': 'pending',
            'estimated_credits': credits_needed
        }

    async def generate_sound(
        self,
        user_id: UUID,
        description: str,
        duration: int = 5,
        sound_type: Optional[str] = None,
        model: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """生成音效(异步)"""
        # 检查配额
        await self._check_quota(user_id, 'sound_generation')

        # 获取模型配置(AIModelType.AUDIO = 4)
        from app.models.ai_model import AIModelType
        model_config = await self._get_model(model, AIModelType.AUDIO)

        # 计算所需积分
        feature_type = self._get_feature_type(AIJobType.SOUND)
        credits_needed = await self.credit_service.calculate_credits(
            feature_type=feature_type,
            params={
                'model_name': model_config.model_name,
                'duration': duration
            }
        )

        # 预扣积分(同步,事务保证)
        try:
            consumption_log = await self.credit_service.consume_credits(
                user_id=user_id,
                amount=credits_needed,
                feature_type=feature_type,
                ai_job_id=None,
                task_params={
                    'description': description,
                    'model': model_config.model_name,
                    'duration': duration,
                    'sound_type': sound_type,
                    **kwargs
                }
            )
        except InsufficientCreditsError as e:
            raise ValidationError(f"积分不足: {str(e)}")

        # 创建任务记录
        from app.models.ai_job import AIJobType, AIJobStatus
        job = await self.job_repository.create({
            'job_type': AIJobType.SOUND,
            'status': AIJobStatus.PENDING,
            'user_id': user_id,
            'model_id': model_config.model_id,
            'model_name': model_config.model_name,
            'consumption_log_id': consumption_log.consumption_id,
            'credits_used': credits_needed,
            'input_data': {
                'description': description,
                'duration': duration,
                'sound_type': sound_type,
                **kwargs
            }
        })

        # 更新 consumption_log 的 ai_job_id
        consumption_log.ai_job_id = job.ai_job_id
        consumption_log.task_id = str(job.ai_job_id)
        await self.db.commit()

        # 提交异步任务
        task = generate_sound_task.delay(
            job_id=job.ai_job_id,
            description=description,
            duration=duration,
            sound_type=sound_type,
            **kwargs
        )

        # 更新任务 ID
        await self.job_repository.update(job.ai_job_id, {'task_id': task.id})

        return {
            'job_id': str(job.ai_job_id),
            'task_id': task.id,
            'status': 'pending',
            'estimated_credits': credits_needed
        }

    async def generate_voice(
        self,
        user_id: UUID,
        text: str,
        voice_type: str = 'alloy',
        speed: float = 1.0,
        language: str = 'zh-CN',
        model: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """生成配音(异步)"""
        # 检查配额
        await self._check_quota(user_id, 'voice_generation')

        # 获取模型配置(AIModelType.AUDIO = 4)
        from app.models.ai_model import AIModelType
        model_config = await self._get_model(model, AIModelType.AUDIO)

        # 计算所需积分
        feature_type = self._get_feature_type(AIJobType.VOICE)
        credits_needed = await self.credit_service.calculate_credits(
            feature_type=feature_type,
            params={
                'model_name': model_config.model_name,
                'char_count': len(text)
            }
        )

        # 预扣积分(同步,事务保证)
        try:
            consumption_log = await self.credit_service.consume_credits(
                user_id=user_id,
                amount=credits_needed,
                feature_type=feature_type,
                ai_job_id=None,
                task_params={
                    'text': text,
                    'model': model_config.model_name,
                    'voice_type': voice_type,
                    'speed': speed,
                    'language': language,
                    **kwargs
                }
            )
        except InsufficientCreditsError as e:
            raise ValidationError(f"积分不足: {str(e)}")

        # 创建任务记录
        from app.models.ai_job import AIJobType, AIJobStatus
        job = await self.job_repository.create({
            'job_type': AIJobType.VOICE,
            'status': AIJobStatus.PENDING,
            'user_id': user_id,
            'model_id': model_config.model_id,
            'model_name': model_config.model_name,
            'consumption_log_id': consumption_log.consumption_id,
            'credits_used': credits_needed,
            'input_data': {
                'text': text,
                'voice_type': voice_type,
                'speed': speed,
                'language': language,
                **kwargs
            }
        })

        # 更新 consumption_log 的 ai_job_id
        consumption_log.ai_job_id = job.ai_job_id
        consumption_log.task_id = str(job.ai_job_id)
        await self.db.commit()

        # 提交异步任务
        task = generate_voice_task.delay(
            job_id=job.ai_job_id,
            text=text,
            voice_type=voice_type,
            speed=speed,
            language=language,
            **kwargs
        )

        # 更新任务 ID
        await self.job_repository.update(job.ai_job_id, {'task_id': task.id})

        return {
            'job_id': str(job.ai_job_id),
            'task_id': task.id,
            'status': 'pending',
            'estimated_credits': credits_needed
        }

    async def generate_subtitle(
        self,
        user_id: UUID,
        audio_url: str,
        language: str = 'zh',
        model: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """生成字幕(异步)"""
        # 检查配额
        await self._check_quota(user_id, 'subtitle_generation')

        # 获取模型配置(AIModelType.AUDIO = 4)
        from app.models.ai_model import AIModelType
        model_config = await self._get_model(model, AIModelType.AUDIO)

        # 计算所需积分(假设音频时长为 60 秒,实际应从音频文件获取)
        duration = kwargs.get('duration', 60)
        feature_type = self._get_feature_type(AIJobType.SUBTITLE)
        credits_needed = await self.credit_service.calculate_credits(
            feature_type=feature_type,
            params={
                'model_name': model_config.model_name,
                'duration': duration
            }
        )

        # 预扣积分(同步,事务保证)
        try:
            consumption_log = await self.credit_service.consume_credits(
                user_id=user_id,
                amount=credits_needed,
                feature_type=feature_type,
                ai_job_id=None,
                task_params={
                    'audio_url': audio_url,
                    'model': model_config.model_name,
                    'language': language,
                    'duration': duration,
                    **kwargs
                }
            )
        except InsufficientCreditsError as e:
            raise ValidationError(f"积分不足: {str(e)}")

        # 创建任务记录
        from app.models.ai_job import AIJobType, AIJobStatus
        job = await self.job_repository.create({
            'job_type': AIJobType.SUBTITLE,
            'status': AIJobStatus.PENDING,
            'user_id': user_id,
            'model_id': model_config.model_id,
            'model_name': model_config.model_name,
            'consumption_log_id': consumption_log.consumption_id,
            'credits_used': credits_needed,
            'input_data': {
                'audio_url': audio_url,
                'language': language,
                'duration': duration,
                **kwargs
            }
        })

        # 更新 consumption_log 的 ai_job_id
        consumption_log.ai_job_id = job.ai_job_id
        consumption_log.task_id = str(job.ai_job_id)
        await self.db.commit()

        # 提交异步任务
        task = generate_subtitle_task.delay(
            job_id=job.ai_job_id,
            audio_url=audio_url,
            language=language,
            **kwargs
        )

        # 更新任务 ID
        await self.job_repository.update(job.ai_job_id, {'task_id': task.id})

        return {
            'job_id': str(job.ai_job_id),
            'task_id': task.id,
            'status': 'pending',
            'estimated_credits': credits_needed
        }

    async def process_text(
        self,
        user_id: UUID,
        task_type: str,
        text: str,
        model: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """处理文本(异步)
        
        Args:
            user_id: 用户 ID
            task_type: 任务类型
                - screenplay_parse: 剧本拆解
                - content_analysis: 内容分析
                - style_transform: 风格转换
                - prompt_generation: 提示词生成
            text: 待处理的文本
            model: 使用的模型(默认使用配置的模型)
            **kwargs: 其他参数
                - output_format: 输出格式 (json/text)
                - temperature: 生成温度 (0.0-2.0)
                - max_tokens: 最大 token 数
        """
        # 验证任务类型
        valid_task_types = [
            'screenplay_parse',
            'content_analysis', 
            'style_transform',
            'prompt_generation'
        ]
        if task_type not in valid_task_types:
            raise ValidationError(f"不支持的任务类型: {task_type}")

        # 创建任务记录
        from app.models.ai_job import AIJobType, AIJobStatus
        job = await self.job_repository.create({
            'job_type': AIJobType.TEXT_PROCESSING,
            'status': AIJobStatus.PENDING,
            'user_id': user_id,
            'input_data': {
                'task_type': task_type,
                'text': text,
                'model': model or 'gpt-4',
                'output_format': kwargs.get('output_format', 'json'),
                'temperature': kwargs.get('temperature', 0.7),
                'max_tokens': kwargs.get('max_tokens', 4000),
                **kwargs
            }
        })

        # 提交异步任务
        from app.tasks.ai_tasks import process_text_task
        task = process_text_task.delay(
            job_id=job.id,
            task_type=task_type,
            text=text,
            model=model or 'gpt-4',
            **kwargs
        )

        return {
            'job_id': str(job.id),
            'task_id': task.id,
            'status': 'pending'
        }

    async def get_job_status(self, job_id: UUID) -> Dict[str, Any]:
        """查询任务状态"""
        job = await self.job_repository.get_by_id(job_id)
        if not job:
            from app.core.exceptions import NotFoundError
            raise NotFoundError("任务不存在")

        return {
            'job_id': str(job.id),
            'job_type': job.job_type,
            'status': job.status,
            'progress': job.progress,
            'input_data': job.input_data,
            'output_data': job.output_data,
            'error_message': job.error_message,
            'model_name': job.model_name,
            'cost': float(job.cost) if job.cost else None,
            'credits_used': job.credits_used,
            'created_at': job.created_at.isoformat(),
            'updated_at': job.updated_at.isoformat(),
            'started_at': job.started_at.isoformat() if job.started_at else None,
            'completed_at': job.completed_at.isoformat() if job.completed_at else None
        }

    async def cancel_job(self, user_id: UUID, job_id: UUID) -> None:
        """取消任务"""
        job = await self.job_repository.get_by_id(job_id)
        if not job:
            from app.core.exceptions import NotFoundError
            raise NotFoundError("任务不存在")

        if job.user_id != user_id:
            from app.core.exceptions import PermissionDeniedError
            raise PermissionDeniedError("没有权限取消此任务")

        from app.models.ai_job import AIJobStatus
        if job.status in [AIJobStatus.COMPLETED, AIJobStatus.FAILED]:
            raise ValidationError("任务已完成或失败,无法取消")

        # 取消 Celery 任务
        from app.tasks.celery_app import celery_app
        if job.task_id:
            celery_app.control.revoke(job.task_id, terminate=True)

        # 更新任务状态
        from app.models.ai_job import AIJobStatus
        await self.job_repository.update(job_id, {
            'status': AIJobStatus.CANCELLED,
            'error_message': '用户取消'
        })

    async def get_user_usage_stats(
        self,
        user_id: UUID,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None
    ) -> Dict[str, Any]:
        """获取用户使用统计"""
        stats = await self.usage_log_repository.get_user_stats(
            user_id, start_date, end_date
        )
        
        quotas = await self.quota_repository.get_user_quotas(user_id)
        
        return {
            'total_cost': float(stats.get('total_cost', 0)),
            'total_credits_used': stats.get('total_credits_used', 0),
            'total_requests': stats.get('total_requests', 0),
            'by_model': stats.get('by_model', {}),
            'by_type': stats.get('by_type', {}),
            'quotas': [
                {
                    'quota_type': q.quota_type,
                    'period': q.period,
                    'total_quota': q.total_quota,
                    'used_quota': q.used_quota,
                    'remaining_quota': q.total_quota - q.used_quota,
                    'reset_at': q.reset_at.isoformat()
                }
                for q in quotas
            ]
        }

    async def get_available_models(
        self,
        model_type: Optional[int] = None
    ) -> list[Dict[str, Any]]:
        """获取可用的 AI 模型列表
        
        Args:
            model_type: 模型类型(AIModelType 枚举值,可选)
        """
        models = await self.model_repository.get_active_models(model_type)
        
        return [
            {
                'model_id': str(m.model_id),
                'model_name': m.model_name,
                'display_name': m.display_name,
                'description': m.description,
                'provider': m.provider,
                'model_type': m.model_type,
                'cost_per_unit': float(m.cost_per_unit),
                'unit_type': m.unit_type,
                'credits_per_unit': m.credits_per_unit,
                'is_beta': m.is_beta,
                'config': m.config
            }
            for m in models
        ]

API 接口

1. 创建 AI 任务(统一入口)

POST /api/v1/ai/jobs

请求体(图片生成)

{
  "jobType": "image",
  "prompt": "一只可爱的猫咪在花园里玩耍",
  "model": "stable_diffusion",
  "width": 1024,
  "height": 1024,
  "style": "realistic"
}

请求体(视频生成 - 文本转视频)

{
  "jobType": "video",
  "videoType": "text2video",
  "prompt": "一只猫咪在花园里奔跑",
  "duration": 5,
  "fps": 30
}

请求体(视频生成 - 图片转视频)

{
  "jobType": "video",
  "videoType": "img2video",
  "imageUrl": "https://example.com/image.jpg",
  "duration": 5,
  "fps": 30
}

请求体(音效生成)

{
  "jobType": "sound",
  "description": "雨声",
  "duration": 10,
  "soundType": "ambient"
}

请求体(配音生成)

{
  "jobType": "voice",
  "text": "欢迎来到Jointo平台",
  "voiceType": "alloy",
  "speed": 1.0,
  "language": "zh-CN"
}

请求体(字幕生成)

{
  "jobType": "subtitle",
  "audioUrl": "https://example.com/audio.mp3",
  "language": "zh"
}

请求体(文本处理 - 剧本拆解)

{
  "jobType": "textProcessing",
  "taskType": "screenplay_parse",
  "text": "场景1:咖啡厅 - 白天\n小明走进咖啡厅,看到小红坐在窗边...",
  "model": "gpt-4",
  "outputFormat": "json",
  "temperature": 0.7
}

响应

{
  "code": 200,
  "message": "Success",
  "data": {
    "jobId": "019d1234-5678-7abc-def0-111111111111",
    "taskId": "abc-123-def",
    "status": "pending",
    "estimatedCredits": 10
  }
}

错误响应

// 400 - 参数错误
{
  "code": 400,
  "message": "参数验证失败: prompt 不能为空",
  "data": null
}

// 402 - 积分不足
{
  "code": 402,
  "message": "积分不足,当前余额: 5,需要: 10",
  "data": null
}

// 429 - 配额超限
{
  "code": 429,
  "message": "每日配额已用完,请明天再试",
  "data": null
}

// 500 - 服务器错误
{
  "code": 500,
  "message": "AI 服务暂时不可用,请稍后重试",
  "data": null
}

2. 查询任务状态

GET /api/v1/ai/jobs/{job_id}

响应

{
  "code": 200,
  "message": "Success",
  "data": {
    "jobId": "019d1234-5678-7abc-def0-111111111111",
    "jobType": "image",
    "status": "completed",
    "progress": 100,
    "outputData": {
      "fileUrl": "https://storage.jointo.ai/ai-generated/images/abc123def456.png",
      "fileSize": 1024000,
      "checksum": "abc123def456...",
      "mimeType": "image/png",
      "meta_data": {
        "width": 1024,
        "height": 1024,
        "format": "png"
      }
    },
    "createdAt": "2026-01-27T10:00:00Z",
    "completedAt": "2026-01-27T10:00:30Z"
  }
}

3. 获取任务列表

GET /api/v1/ai/jobs

查询参数

  • jobType(可选):任务类型(image/video/sound/voice/subtitle/textProcessing)
  • status(可选):任务状态(pending/processing/completed/failed/cancelled)
  • page(可选):页码(默认 1)
  • pageSize(可选):每页数量(默认 20)

响应

{
  "code": 200,
  "message": "Success",
  "data": {
    "items": [
      {
        "jobId": "019d1234-5678-7abc-def0-111111111111",
        "jobType": "image",
        "status": "completed",
        "createdAt": "2026-01-27T10:00:00Z"
      }
    ],
    "total": 100,
    "page": 1,
    "pageSize": 20,
    "totalPages": 5
  }
}

4. 取消任务

POST /api/v1/ai/jobs/{job_id}/cancel

响应

{
  "code": 200,
  "message": "任务已取消",
  "data": null
}

5. 获取用户使用统计

5. 获取用户使用统计

GET /api/v1/ai/usage/stats

查询参数

  • startDate(可选):开始日期(YYYY-MM-DD)
  • endDate(可选):结束日期(YYYY-MM-DD)

响应

{
  "code": 200,
  "message": "Success",
  "data": {
    "totalCost": 12.50,
    "totalCreditsUsed": 250,
    "totalRequests": 45,
    "byModel": {
      "gpt-4": {
        "requests": 20,
        "cost": 8.00,
        "creditsUsed": 160
      },
      "dall-e-3": {
        "requests": 15,
        "cost": 3.50,
        "creditsUsed": 75
      }
    },
    "byType": {
      "textProcessing": {
        "requests": 20,
        "cost": 8.00
      },
      "image": {
        "requests": 15,
        "cost": 3.50
      }
    },
    "quotas": [
      {
        "quotaType": "imageGeneration",
        "period": "daily",
        "totalQuota": 50,
        "usedQuota": 15,
        "remainingQuota": 35,
        "resetAt": "2026-01-28T00:00:00Z"
      },
      {
        "quotaType": "textProcessing",
        "period": "monthly",
        "totalQuota": 1000,
        "usedQuota": 250,
        "remainingQuota": 750,
        "resetAt": "2026-02-01T00:00:00Z"
      }
    ]
  }
}

6. 获取可用模型列表

GET /api/v1/ai/models

查询参数

  • type(可选):模型类型(text/image/video/audio)

响应

{
  "code": 200,
  "message": "Success",
  "data": {
    "items": [
      {
        "modelId": 1,
        "modelName": "gpt-4",
        "displayName": "GPT-4",
        "description": "OpenAI 最强大的语言模型",
        "provider": "openai",
        "modelType": "text",
        "costPerUnit": 0.03,
        "unitType": "token",
        "creditsPerUnit": 1,
        "isBeta": false,
        "config": {
          "maxTokens": 8000
        }
      },
      {
        "modelId": 4,
        "modelName": "dall-e-3",
        "displayName": "DALL-E 3",
        "description": "OpenAI 图片生成模型",
        "provider": "openai",
        "modelType": "image",
        "costPerUnit": 0.04,
        "unitType": "image",
        "creditsPerUnit": 10,
        "isBeta": false,
        "config": {
          "supportedResolutions": ["1024x1024", "1792x1024", "1024x1792"]
        }
      }
    ],
    "total": 2,
    "page": 1,
    "pageSize": 20,
    "totalPages": 1
  }
}

文件存储集成

对象存储服务

AI 生成的文件需要存储到对象存储服务(OSS),系统通过 FileStorageService 提供统一的存储接口。

存储策略

  • 开发环境:使用 MinIO(本地对象存储)
  • 生产环境:可切换到云服务商 OSS
    • 阿里云 OSS
    • AWS S3
    • 腾讯云 COS
    • 七牛云 Kodo

配置方式

通过环境变量 STORAGE_PROVIDER 切换存储服务:

# .env
STORAGE_PROVIDER=minio  # 开发环境
# STORAGE_PROVIDER=aliyun  # 生产环境(阿里云 OSS)
# STORAGE_PROVIDER=aws  # 生产环境(AWS S3)

文件存储流程

AI 任务完成后,需要将 AI 提供商返回的临时文件下载并存储到自有 OSS:

# AI Task 工作流(集成文件存储)
async def generate_image_task(job_id: str, ...):
    try:
        # 1. 调用 AI Provider 生成图片
        result = await provider.generate_image(...)
        # result = {"image_url": "https://ai-provider.com/temp/abc123.png"}
        
        # 2. 下载 AI 生成的文件
        async with httpx.AsyncClient() as client:
            response = await client.get(result['image_url'])
            image_data = response.content
        
        # 3. 上传到 OSS(带去重)
        async with async_session_maker() as session:
            file_storage = FileStorageService(session)
            file_meta_data = await file_storage.upload_file(
                file_content=image_data,
                filename=f"{job_id}.png",
                content_type="image/png",
                category="ai-generated/images",
                user_id=user_id
            )
        
        # 4. 更新任务结果(使用自有 URL)
        output_data = {
            "file_url": file_meta_data.file_url,  # 自有 OSS URL
            "file_size": file_meta_data.file_size,
            "checksum": file_meta_data.checksum,
            "mime_type": file_meta_data.mime_type,
            "original_url": result['image_url'],  # 保留原始 URL(可选)
            "meta_data": {
                "width": result.get('width'),
                "height": result.get('height'),
                "format": "png"
            }
        }
        
        await _update_job_status(
            job_id, 
            AIJobStatus.COMPLETED, 
            progress=100, 
            output_data=output_data
        )
        
        # 5. 确认积分消耗
        await _confirm_or_refund_credits(...)
        
    except Exception as e:
        # 任务失败,退还积分
        await _update_job_status(job_id, AIJobStatus.FAILED, error_message=str(e))
        await _confirm_or_refund_credits(..., success=False, error_message=str(e))

优势

  1. 数据主权:AI 生成的内容存储在自有系统中
  2. 稳定性:不依赖第三方临时 URL(如 OpenAI 的临时 URL 只有 1 小时有效期)
  3. 去重优化:通过文件校验和(SHA256)自动去重,节省存储成本
  4. 灵活切换:开发环境使用 MinIO,生产环境可无缝切换到云服务商 OSS
  5. 统一管理:所有文件通过 FileStorageService 统一管理,支持引用计数和自动清理

FileStorageService 集成

AI Service 需要注入 FileStorageService 依赖:

# app/services/ai_service.py
from app.services.file_storage_service import FileStorageService

class AIService:
    def __init__(self, db: Session):
        self.db = db
        self.job_repository = AIJobRepository(db)
        self.model_repository = AIModelRepository(db)
        self.credit_service = CreditService(db)
        self.file_storage = FileStorageService(db)  # 注入文件存储服务

详见:File Storage Service 文档


AI 提供商集成

提供商抽象层

# app/services/ai_providers/base.py
from abc import ABC, abstractmethod
from typing import Dict, Any

class AIProvider(ABC):
    @abstractmethod
    async def generate_image(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """生成图片"""
        pass

    @abstractmethod
    async def generate_video(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """生成视频"""
        pass

Stable Diffusion 提供商

# app/services/ai_providers/stable_diffusion.py
import httpx
from app.services.ai_providers.base import AIProvider

class StableDiffusionProvider(AIProvider):
    def __init__(self, api_key: str, api_url: str):
        self.api_key = api_key
        self.api_url = api_url

    async def generate_image(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """调用 Stable Diffusion API 生成图片"""
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.api_url}/generate",
                json={
                    "prompt": prompt,
                    "width": kwargs.get('width', 1024),
                    "height": kwargs.get('height', 1024),
                    "steps": kwargs.get('steps', 50),
                    "cfg_scale": kwargs.get('cfg_scale', 7.5)
                },
                headers={"Authorization": f"Bearer {self.api_key}"},
                timeout=300
            )
            return response.json()

    async def generate_video(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """Stable Diffusion 不支持视频生成"""
        raise NotImplementedError("Stable Diffusion 不支持视频生成")

OpenAI 提供商

# app/services/ai_providers/openai.py
import openai
from app.services.ai_providers.base import AIProvider

class OpenAIProvider(AIProvider):
    def __init__(self, api_key: str):
        self.api_key = api_key
        openai.api_key = api_key

    async def generate_image(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """调用 DALL-E API 生成图片"""
        response = await openai.Image.acreate(
            prompt=prompt,
            n=1,
            size=f"{kwargs.get('width', 1024)}x{kwargs.get('height', 1024)}"
        )
        return {
            'image_url': response['data'][0]['url']
        }

    async def generate_voice(self, text: str, **kwargs) -> Dict[str, Any]:
        """调用 OpenAI TTS API 生成配音"""
        response = await openai.Audio.create(
            model="tts-1",
            voice=kwargs.get('voice_type', 'alloy'),
            input=text,
            speed=kwargs.get('speed', 1.0)
        )
        return {
            'audio_url': response['url']
        }

    async def process_text(self, task_type: str, text: str, **kwargs) -> Dict[str, Any]:
        """调用 GPT API 处理文本"""
        # 根据任务类型构建系统提示词
        system_prompts = {
            'screenplay_parse': """你是一个专业的剧本分析助手。请将输入的剧本文本拆解为结构化的场景和镜头信息。
输出 JSON 格式,包含:scenes (场景列表),每个场景包含 scene_number, location, time, description, characters, shots。
每个 shot 包含 shot_number, shot_size, camera_movement, description, duration。""",
            
            'content_analysis': """你是一个内容分析专家。请分析输入文本,提取关键信息。
输出 JSON 格式,包含:characters (人物列表), locations (地点列表), timeline (时间线), themes (主题), emotions (情感)。""",
            
            'style_transform': """你是一个文本风格转换专家。请根据要求调整文本的风格和语气。
保持原意,但改变表达方式。""",
            
            'prompt_generation': """你是一个 AI 绘画提示词专家。请将输入的场景描述转换为详细的 AI 绘画提示词。
包含:主体、环境、光线、色调、风格、镜头角度等要素。输出英文提示词。"""
        }

        messages = [
            {"role": "system", "content": system_prompts.get(task_type, "")},
            {"role": "user", "content": text}
        ]

        response = await openai.ChatCompletion.acreate(
            model=kwargs.get('model', 'gpt-4'),
            messages=messages,
            temperature=kwargs.get('temperature', 0.7),
            max_tokens=kwargs.get('max_tokens', 4000)
        )

        content = response.choices[0].message.content

        # 如果要求 JSON 格式,尝试解析
        if kwargs.get('output_format') == 'json':
            import json
            try:
                content = json.loads(content)
            except json.JSONDecodeError:
                # 如果解析失败,尝试提取 JSON 部分
                import re
                json_match = re.search(r'\{.*\}', content, re.DOTALL)
                if json_match:
                    content = json.loads(json_match.group())

        return {
            'result': content,
            'model': response.model,
            'usage': {
                'prompt_tokens': response.usage.prompt_tokens,
                'completion_tokens': response.usage.completion_tokens,
                'total_tokens': response.usage.total_tokens
            }
        }

提供商工厂

# app/services/ai_service_factory.py
from app.services.ai_providers.stable_diffusion import StableDiffusionProvider
from app.services.ai_providers.openai import OpenAIProvider
from app.config import settings

class AIServiceFactory:
    @staticmethod
    def create_provider(provider_type: str):
        """创建 AI 提供商实例"""
        if provider_type == 'stable_diffusion':
            return StableDiffusionProvider(
                api_key=settings.STABLE_DIFFUSION_API_KEY,
                api_url=settings.STABLE_DIFFUSION_API_URL
            )
        elif provider_type == 'openai':
            return OpenAIProvider(
                api_key=settings.OPENAI_API_KEY
            )
        else:
            raise ValueError(f"Unknown provider: {provider_type}")

相关文档


附录:Repository 实现示例

基础 Repository(引用完整性保证)

# app/repositories/base_repository.py
from typing import Optional, TypeVar, Generic
from uuid import UUID
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

T = TypeVar('T')

class BaseRepository(Generic[T]):
    """基础 Repository - 提供引用完整性验证"""
    
    def __init__(self, db: AsyncSession, model: type[T]):
        self.db = db
        self.model = model
    
    async def exists(self, id: UUID) -> bool:
        """检查记录是否存在(排除软删除)
        
        这是应用层保证引用完整性的核心方法
        """
        query = select(self.model.id).where(
            self.model.id == id
        ).limit(1)
        
        # 如果模型有 deleted_at 字段,排除软删除记录
        if hasattr(self.model, 'deleted_at'):
            query = query.where(self.model.deleted_at.is_(None))
        
        result = await self.db.execute(query)
        return result.scalar_one_or_none() is not None
    
    async def get_by_id(self, id: UUID) -> Optional[T]:
        """根据 ID 获取记录"""
        query = select(self.model).where(self.model.id == id)
        
        if hasattr(self.model, 'deleted_at'):
            query = query.where(self.model.deleted_at.is_(None))
        
        result = await self.db.execute(query)
        return result.scalar_one_or_none()

AIJobRepository

# app/repositories/ai_job_repository.py
from typing import List, Optional
from uuid import UUID
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.ai_job import AIJob, AIJobType, AIJobStatus
from app.repositories.base_repository import BaseRepository

class AIJobRepository(BaseRepository[AIJob]):
    """AI 任务数据访问层"""
    
    def __init__(self, db: AsyncSession):
        super().__init__(db, AIJob)
    
    async def create(self, job_data: dict) -> AIJob:
        """创建 AI 任务"""
        job = AIJob(**job_data)
        self.db.add(job)
        await self.db.commit()
        await self.db.refresh(job)
        return job
    
    async def update(self, job_id: UUID, update_data: dict) -> AIJob:
        """更新 AI 任务"""
        job = await self.get_by_id(job_id)
        if not job:
            raise ValueError("任务不存在")
        
        for key, value in update_data.items():
            setattr(job, key, value)
        
        await self.db.commit()
        await self.db.refresh(job)
        return job
    
    async def get_user_jobs(
        self,
        user_id: UUID,
        job_type: Optional[int] = None,
        status: Optional[int] = None,
        page: int = 1,
        page_size: int = 20
    ) -> List[AIJob]:
        """获取用户的 AI 任务列表"""
        query = select(AIJob).where(AIJob.user_id == user_id)
        
        if job_type:
            query = query.where(AIJob.job_type == job_type)
        if status:
            query = query.where(AIJob.status == status)
        
        query = query.order_by(AIJob.created_at.desc())
        query = query.offset((page - 1) * page_size).limit(page_size)
        
        result = await self.db.execute(query)
        return result.scalars().all()

AIModelRepository

# app/repositories/ai_model_repository.py
from typing import Optional, List
from uuid import UUID
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.ai_model import AIModel, AIModelType
from app.repositories.base_repository import BaseRepository
from app.core.exceptions import ValidationError

class AIModelRepository(BaseRepository[AIModel]):
    """AI 模型数据访问层"""
    
    def __init__(self, db: AsyncSession):
        super().__init__(db, AIModel)

    async def get_by_name(self, model_name: str) -> Optional[AIModel]:
        """根据模型名称获取模型"""
        query = select(AIModel).where(AIModel.model_name == model_name)
        result = await self.db.execute(query)
        return result.scalar_one_or_none()

    async def get_default_model(self, model_type: int) -> Optional[AIModel]:
        """获取指定类型的默认模型(最便宜的活跃模型)
        
        Args:
            model_type: 模型类型(AIModelType 枚举值)
        """
        # 验证模型类型枚举值
        if model_type not in AIModelType.__members__.values():
            raise ValidationError(f"无效的模型类型: {model_type}")

        query = select(AIModel).where(
            AIModel.model_type == model_type,
            AIModel.is_active == True
        ).order_by(AIModel.cost_per_unit.asc())
        
        result = await self.db.execute(query)
        return result.scalar_one_or_none()

    async def get_active_models(
        self,
        model_type: Optional[int] = None
    ) -> List[AIModel]:
        """获取所有活跃模型
        
        Args:
            model_type: 模型类型(AIModelType 枚举值,可选)
        """
        # 如果指定了模型类型,验证枚举值
        if model_type is not None and model_type not in AIModelType.__members__.values():
            raise ValidationError(f"无效的模型类型: {model_type}")

        query = select(AIModel).where(AIModel.is_active == True)
        
        if model_type:
            query = query.where(AIModel.model_type == model_type)
        
        query = query.order_by(AIModel.model_type, AIModel.cost_per_unit)
        
        result = await self.db.execute(query)
        return result.scalars().all()

AIQuotaRepository

# app/repositories/ai_quota_repository.py
from typing import Optional, List
from uuid import UUID
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.ai_quota import AIQuota, QuotaPeriod
from app.repositories.base_repository import BaseRepository
from datetime import datetime, timedelta

class AIQuotaRepository(BaseRepository[AIQuota]):
    """AI 配额数据访问层"""
    
    def __init__(self, db: AsyncSession):
        super().__init__(db, AIQuota)

    async def get_user_quota(
        self,
        user_id: UUID,
        quota_type: str,
        period: int = QuotaPeriod.DAILY
    ) -> Optional[AIQuota]:
        """获取用户指定类型的配额
        
        Args:
            user_id: 用户 ID
            quota_type: 配额类型
            period: 配额周期(QuotaPeriod 枚举值,默认 1=DAILY)
        """
        query = select(AIQuota).where(
            AIQuota.user_id == user_id,
            AIQuota.quota_type == quota_type,
            AIQuota.period == period
        )
        
        result = await self.db.execute(query)
        return result.scalar_one_or_none()

    async def get_user_quotas(self, user_id: UUID) -> List[AIQuota]:
        """获取用户所有配额"""
        query = select(AIQuota).where(AIQuota.user_id == user_id)
        result = await self.db.execute(query)
        return result.scalars().all()

    async def increment_usage(
        self,
        user_id: UUID,
        quota_type: str,
        amount: int = 1
    ) -> None:
        """增加配额使用量"""
        # 更新每日配额
        daily_quota = await self.get_user_quota(user_id, quota_type, QuotaPeriod.DAILY)
        if daily_quota:
            daily_quota.used_quota += amount
            await self.db.commit()
        
        # 更新每月配额
        monthly_quota = await self.get_user_quota(user_id, quota_type, QuotaPeriod.MONTHLY)
        if monthly_quota:
            monthly_quota.used_quota += amount
            await self.db.commit()

    async def create_default_quotas(self, user_id: UUID) -> None:
        """为新用户创建默认配额"""
        default_quotas = [
            {
                'user_id': user_id,
                'quota_type': 'image_generation',
                'period': QuotaPeriod.DAILY,
                'total_quota': 50,
                'reset_at': datetime.now(timezone.utc) + timedelta(days=1)
            },
            {
                'user_id': user_id,
                'quota_type': 'image_generation',
                'period': QuotaPeriod.MONTHLY,
                'total_quota': 1000,
                'reset_at': datetime.now(timezone.utc) + timedelta(days=30)
            },
            {
                'user_id': user_id,
                'quota_type': 'text_processing',
                'period': QuotaPeriod.DAILY,
                'total_quota': 100,
                'reset_at': datetime.now(timezone.utc) + timedelta(days=1)
            },
            {
                'user_id': user_id,
                'quota_type': 'text_processing',
                'period': QuotaPeriod.MONTHLY,
                'total_quota': 2000,
                'reset_at': datetime.now(timezone.utc) + timedelta(days=30)
            }
        ]
        
        for quota_data in default_quotas:
            quota = AIQuota(**quota_data)
            self.db.add(quota)
        
        await self.db.commit()

文档版本:v1.0
最后更新:2025-01-27


附录:Python 枚举定义

app/models/ai_job.py

from enum import IntEnum
from sqlmodel import SQLModel, Field
from sqlalchemy import Column, SmallInteger, Index, text
from sqlalchemy.dialects.postgresql import UUID as PG_UUID, JSONB
from datetime import datetime, timezone
from uuid import UUID
from typing import Optional
from app.utils.id_generator import generate_uuid

class AIJobType(IntEnum):
    """AI 任务类型"""
    IMAGE = 1              # 图片生成
    VIDEO = 2              # 视频生成
    SOUND = 3              # 音效生成
    VOICE = 4              # 配音生成
    SUBTITLE = 5           # 字幕生成
    TEXT_PROCESSING = 6    # 文本处理(剧本拆解等)
    RESOURCE = 7           # 资源生成
    STORYBOARD_SCRIPT = 8  # 分镜脚本生成
    SCRIPT_GENERATION = 9  # 剧本生成

class AIJobStatus(IntEnum):
    """AI 任务状态"""
    PENDING = 1      # 等待处理
    PROCESSING = 2   # 处理中
    COMPLETED = 3    # 已完成
    FAILED = 4       # 失败
    CANCELLED = 5    # 已取消

class AIJob(SQLModel, table=True):
    """AI 任务表 - 不使用外键约束,应用层保证引用完整性"""
    __tablename__ = "ai_jobs"
    
    # 主键
    ai_job_id: UUID = Field(
        sa_column=Column(
            PG_UUID(as_uuid=True),
            primary_key=True,
            default=generate_uuid
        )
    )
    
    # 关联字段(无外键约束,仅索引)
    user_id: UUID = Field(
        sa_column=Column(PG_UUID(as_uuid=True), nullable=False, index=True),
        description="用户 ID - 应用层验证"
    )
    project_id: Optional[UUID] = Field(
        default=None,
        sa_column=Column(PG_UUID(as_uuid=True), nullable=True, index=True),
        description="项目 ID - 应用层验证"
    )
    storyboard_id: Optional[UUID] = Field(
        default=None,
        sa_column=Column(PG_UUID(as_uuid=True), nullable=True, index=True),
        description="分镜 ID - 应用层验证"
    )
    consumption_log_id: Optional[UUID] = Field(
        default=None,
        sa_column=Column(PG_UUID(as_uuid=True), nullable=True, index=True),
        description="积分消耗日志 ID - 应用层验证"
    )
    model_id: Optional[UUID] = Field(
        default=None,
        sa_column=Column(PG_UUID(as_uuid=True), nullable=True, index=True),
        description="AI 模型 ID - 应用层验证"
    )
    
    # 任务信息(使用 SMALLINT 存储枚举)
    job_type: int = Field(
        sa_column=Column(SmallInteger, nullable=False),
        description="任务类型:1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成"
    )
    status: int = Field(
        default=AIJobStatus.PENDING,
        sa_column=Column(SmallInteger, nullable=False, default=1, index=True),
        description="任务状态:1=等待处理 2=处理中 3=已完成 4=失败 5=已取消"
    )
    
    # JSONB 字段
    input_data: dict = Field(
        default_factory=dict,
        sa_column=Column(JSONB, nullable=False, default={}),
        description="输入参数"
    )
    output_data: Optional[dict] = Field(
        default=None,
        sa_column=Column(JSONB, nullable=True),
        description="输出结果"
    )
    
    # 其他字段
    model_name: Optional[str] = Field(default=None, max_length=255)
    progress: int = Field(default=0, ge=0, le=100)
    error_message: Optional[str] = Field(default=None)
    task_id: Optional[str] = Field(default=None)
    cost: Optional[float] = Field(default=None)
    credits_used: int = Field(default=0)
    
    # 时间戳
    estimated_completion_at: Optional[datetime] = Field(default=None)
    started_at: Optional[datetime] = Field(default=None)
    completed_at: Optional[datetime] = Field(default=None)
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    
    # 表级索引
    __table_args__ = (
        Index('idx_ai_jobs_status_created_at', 'status', 'created_at',
              postgresql_where=text('status IN (1, 2)')),
        Index('idx_ai_jobs_input_data_gin', 'input_data', postgresql_using='gin'),
        Index('idx_ai_jobs_output_data_gin', 'output_data', 
              postgresql_using='gin',
              postgresql_where=text('output_data IS NOT NULL')),
    )

app/models/ai_model.py

from enum import IntEnum

class AIModelType(IntEnum):
    """AI 模型类型"""
    TEXT = 1    # 文本模型(GPT, Claude 等)
    IMAGE = 2   # 图片模型(DALL-E, Stable Diffusion 等)
    VIDEO = 3   # 视频模型(Runway, Pika 等)
    AUDIO = 4   # 音频模型(TTS, STT 等)

class AIProvider(IntEnum):
    """AI 提供商"""
    OPENAI = 1
    ANTHROPIC = 2
    GOOGLE = 3       # Google Gemini
    STABILITY = 4
    RUNWAY = 5
    PIKA = 6
    ELEVENLABS = 7
    AZURE = 8
    BAIDU = 9
    ALIYUN = 10
    CUSTOM = 99      # 自定义提供商

class UnitType(IntEnum):
    """计费单位类型"""
    TOKEN = 1    # Token(文本模型)
    IMAGE = 2    # 图片(图片模型)
    SECOND = 3   # 秒(视频/音频模型)
    REQUEST = 4  # 请求(通用计费单位)

app/models/ai_quota.py

from enum import IntEnum

class QuotaPeriod(IntEnum):
    """配额周期"""
    DAILY = 1    # 每日配额
    MONTHLY = 2  # 每月配额
    TOTAL = 3    # 总配额

测试规范

单元测试

# server/tests/unit/test_ai_service.py
import pytest
from uuid import UUID
from unittest.mock import AsyncMock, MagicMock, patch
from app.services.ai_service import AIService
from app.models.ai_job import AIJobType, AIJobStatus
from app.models.ai_model import AIModelType
from app.core.exceptions import ValidationError, InsufficientCreditsError, NotFoundError

@pytest.fixture
def mock_db():
    """模拟数据库会话"""
    db = AsyncMock()
    db.begin = AsyncMock()
    db.commit = AsyncMock()
    db.rollback = AsyncMock()
    return db

@pytest.fixture
def ai_service(mock_db):
    """创建 AI Service 实例"""
    return AIService(mock_db)

class TestAIService:
    """AI Service 单元测试"""
    
    @pytest.mark.asyncio
    async def test_generate_image_success(self, ai_service, mock_db):
        """测试图片生成成功"""
        # 准备测试数据
        user_id = UUID('019d1234-5678-7abc-def0-111111111111')
        prompt = "一只可爱的猫咪"
        
        # Mock 依赖
        with patch.object(ai_service, '_check_quota', return_value=True), \
             patch.object(ai_service, '_get_model') as mock_get_model, \
             patch.object(ai_service.job_repository, 'create') as mock_create_job, \
             patch.object(ai_service.credit_service, 'consume_credits') as mock_consume:
            
            # 设置 mock 返回值
            mock_model = MagicMock()
            mock_model.model_id = UUID('019d1234-5678-7abc-def0-222222222222')
            mock_model.model_name = 'stable_diffusion'
            mock_model.credits_per_unit = 10
            mock_model.cost_per_unit = 0.5
            mock_get_model.return_value = mock_model
            
            mock_consumption = MagicMock()
            mock_consumption.consumption_id = UUID('019d1234-5678-7abc-def0-333333333333')
            mock_consume.return_value = mock_consumption
            
            mock_job = MagicMock()
            mock_job.ai_job_id = UUID('019d1234-5678-7abc-def0-444444444444')
            mock_create_job.return_value = mock_job
            
            # 执行测试
            result = await ai_service.generate_image(
                user_id=user_id,
                prompt=prompt,
                width=1024,
                height=1024
            )
            
            # 验证结果
            assert 'job_id' in result
            assert 'task_id' in result
            assert result['status'] == 'pending'
            assert result['estimated_credits'] == 10
            
            # 验证调用
            mock_consume.assert_called_once()
            mock_create_job.assert_called_once()
    
    @pytest.mark.asyncio
    async def test_generate_image_insufficient_credits(self, ai_service):
        """测试积分不足"""
        user_id = UUID('019d1234-5678-7abc-def0-111111111111')
        
        with patch.object(ai_service, '_check_quota', return_value=True), \
             patch.object(ai_service, '_get_model') as mock_get_model, \
             patch.object(ai_service.credit_service, 'consume_credits') as mock_consume:
            
            mock_model = MagicMock()
            mock_model.credits_per_unit = 10
            mock_get_model.return_value = mock_model
            
            # 模拟积分不足
            mock_consume.side_effect = InsufficientCreditsError("积分不足")
            
            # 验证抛出异常
            with pytest.raises(ValidationError, match="积分不足"):
                await ai_service.generate_image(
                    user_id=user_id,
                    prompt="test"
                )
    
    @pytest.mark.asyncio
    async def test_get_job_status_not_found(self, ai_service):
        """测试查询不存在的任务"""
        job_id = UUID('019d1234-5678-7abc-def0-111111111111')
        
        with patch.object(ai_service.job_repository, 'get_by_id', return_value=None):
            with pytest.raises(NotFoundError, match="任务不存在"):
                await ai_service.get_job_status(job_id)
    
    @pytest.mark.asyncio
    async def test_cancel_job_success(self, ai_service):
        """测试取消任务成功"""
        user_id = UUID('019d1234-5678-7abc-def0-111111111111')
        job_id = UUID('019d1234-5678-7abc-def0-222222222222')
        
        mock_job = MagicMock()
        mock_job.user_id = user_id
        mock_job.status = AIJobStatus.PENDING
        mock_job.task_id = 'celery-task-123'
        
        with patch.object(ai_service.job_repository, 'get_by_id', return_value=mock_job), \
             patch.object(ai_service.job_repository, 'update') as mock_update, \
             patch('app.tasks.celery_app.celery_app.control.revoke') as mock_revoke:
            
            await ai_service.cancel_job(user_id, job_id)
            
            mock_revoke.assert_called_once_with('celery-task-123', terminate=True)
            mock_update.assert_called_once()

集成测试

# server/tests/integration/test_ai_api.py
import pytest
from httpx import AsyncClient
from uuid import UUID
from app.main import app
from app.core.database import get_db
from app.models.ai_job import AIJobType, AIJobStatus

@pytest.fixture
async def client():
    """创建测试客户端"""
    async with AsyncClient(app=app, base_url="http://test") as ac:
        yield ac

@pytest.fixture
async def auth_headers(client):
    """获取认证 token"""
    # 登录获取 token
    response = await client.post("/api/v1/auth/login", json={
        "email": "test@example.com",
        "password": "testpass123"
    })
    token = response.json()["data"]["accessToken"]
    return {"Authorization": f"Bearer {token}"}

class TestAIAPI:
    """AI API 集成测试"""
    
    @pytest.mark.asyncio
    async def test_create_image_job(self, client, auth_headers):
        """测试创建图片生成任务"""
        response = await client.post(
            "/api/v1/ai/jobs",
            headers=auth_headers,
            json={
                "jobType": "image",
                "prompt": "一只可爱的猫咪",
                "width": 1024,
                "height": 1024
            }
        )
        
        assert response.status_code == 200
        data = response.json()
        assert data["code"] == 200
        assert "jobId" in data["data"]
        assert "taskId" in data["data"]
        assert data["data"]["status"] == "pending"
    
    @pytest.mark.asyncio
    async def test_create_job_invalid_params(self, client, auth_headers):
        """测试无效参数"""
        response = await client.post(
            "/api/v1/ai/jobs",
            headers=auth_headers,
            json={
                "jobType": "image",
                "prompt": "",  # 空提示词
                "width": 1024
            }
        )
        
        assert response.status_code == 400
        data = response.json()
        assert data["code"] == 400
        assert "prompt" in data["message"].lower()
    
    @pytest.mark.asyncio
    async def test_get_job_status(self, client, auth_headers):
        """测试查询任务状态"""
        # 先创建任务
        create_response = await client.post(
            "/api/v1/ai/jobs",
            headers=auth_headers,
            json={
                "jobType": "image",
                "prompt": "test"
            }
        )
        job_id = create_response.json()["data"]["jobId"]
        
        # 查询状态
        response = await client.get(
            f"/api/v1/ai/jobs/{job_id}",
            headers=auth_headers
        )
        
        assert response.status_code == 200
        data = response.json()
        assert data["code"] == 200
        assert data["data"]["jobId"] == job_id
        assert "status" in data["data"]
    
    @pytest.mark.asyncio
    async def test_cancel_job(self, client, auth_headers):
        """测试取消任务"""
        # 先创建任务
        create_response = await client.post(
            "/api/v1/ai/jobs",
            headers=auth_headers,
            json={
                "jobType": "image",
                "prompt": "test"
            }
        )
        job_id = create_response.json()["data"]["jobId"]
        
        # 取消任务
        response = await client.post(
            f"/api/v1/ai/jobs/{job_id}/cancel",
            headers=auth_headers
        )
        
        assert response.status_code == 200
        data = response.json()
        assert data["code"] == 200
        assert "已取消" in data["message"]
    
    @pytest.mark.asyncio
    async def test_get_usage_stats(self, client, auth_headers):
        """测试获取使用统计"""
        response = await client.get(
            "/api/v1/ai/usage/stats",
            headers=auth_headers
        )
        
        assert response.status_code == 200
        data = response.json()
        assert data["code"] == 200
        assert "totalCost" in data["data"]
        assert "totalCreditsUsed" in data["data"]
        assert "quotas" in data["data"]
    
    @pytest.mark.asyncio
    async def test_get_available_models(self, client, auth_headers):
        """测试获取可用模型列表"""
        response = await client.get(
            "/api/v1/ai/models",
            headers=auth_headers
        )
        
        assert response.status_code == 200
        data = response.json()
        assert data["code"] == 200
        assert "items" in data["data"]
        assert len(data["data"]["items"]) > 0

测试运行

# 运行所有 AI Service 测试
docker exec jointo-server-app pytest tests/unit/test_ai_service.py -v

# 运行集成测试
docker exec jointo-server-app pytest tests/integration/test_ai_api.py -v

# 运行所有测试并生成覆盖率报告
docker exec jointo-server-app pytest tests/ --cov=app/services/ai_service --cov-report=html

# 运行特定测试
docker exec jointo-server-app pytest tests/unit/test_ai_service.py::TestAIService::test_generate_image_success -v

数据库迁移脚本

创建 AI 服务表

# server/alembic/versions/20260129_1700_create_ai_service_tables.py
"""create ai service tables

Revision ID: abc123def456
Revises: 3a3a2a1417de
Create Date: 2026-01-29 17:00:00.000000

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = 'abc123def456'
down_revision = '3a3a2a1417de'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """创建 AI 服务相关表(无外键约束,应用层保证引用完整性)"""
    
    # 1. 创建 ai_models 表
    op.create_table(
        'ai_models',
        sa.Column('model_id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('model_name', sa.Text(), nullable=False),
        sa.Column('display_name', sa.Text(), nullable=False),
        sa.Column('description', sa.Text(), nullable=True),
        sa.Column('provider', sa.SmallInteger(), nullable=False),
        sa.Column('model_type', sa.SmallInteger(), nullable=False),
        sa.Column('cost_per_unit', sa.Numeric(10, 4), nullable=False),
        sa.Column('unit_type', sa.SmallInteger(), nullable=False),
        sa.Column('credits_per_unit', sa.Integer(), nullable=False),
        sa.Column('config', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}'),
        sa.Column('rate_limit', sa.Integer(), nullable=True),
        sa.Column('daily_quota', sa.Integer(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
        sa.Column('is_beta', sa.Boolean(), nullable=False, server_default='false'),
        sa.Column('created_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')),
        sa.Column('updated_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')),
        sa.PrimaryKeyConstraint('model_id'),
        sa.UniqueConstraint('model_name'),
        comment='AI 模型配置表'
    )
    
    # 索引
    op.create_index('idx_ai_models_provider', 'ai_models', ['provider'], 
                    postgresql_where=sa.text('is_active = true'))
    op.create_index('idx_ai_models_type', 'ai_models', ['model_type'], 
                    postgresql_where=sa.text('is_active = true'))
    op.create_index('idx_ai_models_is_active', 'ai_models', ['is_active'])
    op.create_index('idx_ai_models_config_gin', 'ai_models', ['config'], 
                    postgresql_using='gin')
    
    # 触发器
    op.execute("""
        CREATE TRIGGER update_ai_models_updated_at
            BEFORE UPDATE ON ai_models
            FOR EACH ROW
            EXECUTE FUNCTION update_updated_at_column();
    """)
    
    # 2. 创建 ai_jobs 表(无外键约束)
    op.create_table(
        'ai_jobs',
        sa.Column('ai_job_id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('project_id', postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column('storyboard_id', postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column('consumption_log_id', postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column('job_type', sa.SmallInteger(), nullable=False),
        sa.Column('status', sa.SmallInteger(), nullable=False, server_default='1'),
        sa.Column('input_data', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}'),
        sa.Column('output_data', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column('model_id', postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column('model_name', sa.Text(), nullable=True),
        sa.Column('progress', sa.Integer(), nullable=False, server_default='0'),
        sa.Column('error_message', sa.Text(), nullable=True),
        sa.Column('task_id', sa.Text(), nullable=True),
        sa.Column('estimated_completion_at', sa.TIMESTAMP(timezone=True), nullable=True),
        sa.Column('started_at', sa.TIMESTAMP(timezone=True), nullable=True),
        sa.Column('completed_at', sa.TIMESTAMP(timezone=True), nullable=True),
        sa.Column('created_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')),
        sa.Column('updated_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')),
        sa.Column('cost', sa.Numeric(10, 4), nullable=True),
        sa.Column('credits_used', sa.Integer(), nullable=False, server_default='0'),
        sa.PrimaryKeyConstraint('ai_job_id'),
        sa.CheckConstraint('progress >= 0 AND progress <= 100', name='ai_jobs_progress_check'),
        comment='AI 任务表 - 应用层保证引用完整性'
    )
    
    # 索引(关联字段必须有索引)
    op.create_index('idx_ai_jobs_user_id', 'ai_jobs', ['user_id'])
    op.create_index('idx_ai_jobs_project_id', 'ai_jobs', ['project_id'], 
                    postgresql_where=sa.text('project_id IS NOT NULL'))
    op.create_index('idx_ai_jobs_storyboard_id', 'ai_jobs', ['storyboard_id'], 
                    postgresql_where=sa.text('storyboard_id IS NOT NULL'))
    op.create_index('idx_ai_jobs_type', 'ai_jobs', ['job_type'])
    op.create_index('idx_ai_jobs_status', 'ai_jobs', ['status'])
    op.create_index('idx_ai_jobs_model_id', 'ai_jobs', ['model_id'], 
                    postgresql_where=sa.text('model_id IS NOT NULL'))
    op.create_index('idx_ai_jobs_consumption_log_id', 'ai_jobs', ['consumption_log_id'], 
                    postgresql_where=sa.text('consumption_log_id IS NOT NULL'))
    op.create_index('idx_ai_jobs_created_at', 'ai_jobs', ['created_at'])
    op.create_index('idx_ai_jobs_status_created_at', 'ai_jobs', ['status', 'created_at'], 
                    postgresql_where=sa.text('status IN (1, 2)'))
    op.create_index('idx_ai_jobs_input_data_gin', 'ai_jobs', ['input_data'], 
                    postgresql_using='gin')
    op.create_index('idx_ai_jobs_output_data_gin', 'ai_jobs', ['output_data'], 
                    postgresql_using='gin',
                    postgresql_where=sa.text('output_data IS NOT NULL'))
    
    # 触发器
    op.execute("""
        CREATE TRIGGER update_ai_jobs_updated_at
            BEFORE UPDATE ON ai_jobs
            FOR EACH ROW
            EXECUTE FUNCTION update_updated_at_column();
    """)
    
    # 字段注释
    op.execute("""
        COMMENT ON COLUMN ai_jobs.user_id IS '用户 ID - 应用层验证';
        COMMENT ON COLUMN ai_jobs.project_id IS '项目 ID - 应用层验证';
        COMMENT ON COLUMN ai_jobs.storyboard_id IS '分镜 ID - 应用层验证';
        COMMENT ON COLUMN ai_jobs.consumption_log_id IS '积分消耗日志 ID - 应用层验证';
        COMMENT ON COLUMN ai_jobs.model_id IS 'AI 模型 ID - 应用层验证';
        COMMENT ON COLUMN ai_jobs.job_type IS '任务类型(1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成)';
        COMMENT ON COLUMN ai_jobs.status IS '任务状态(1=等待处理 2=处理中 3=已完成 4=失败 5=已取消)';
    """)
    
    # 3. 创建 ai_usage_logs 表(无外键约束)
    op.create_table(
        'ai_usage_logs',
        sa.Column('usage_log_id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('ai_job_id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('model_id', postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column('units_used', sa.Numeric(10, 2), nullable=False),
        sa.Column('unit_type', sa.SmallInteger(), nullable=False),
        sa.Column('cost', sa.Numeric(10, 4), nullable=False),
        sa.Column('credits_used', sa.Integer(), nullable=False),
        sa.Column('meta_data', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}'),
        sa.Column('created_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')),
        sa.PrimaryKeyConstraint('usage_log_id'),
        sa.CheckConstraint('unit_type BETWEEN 1 AND 4', name='ai_usage_logs_unit_type_check'),
        comment='AI 使用日志表 - 应用层保证引用完整性'
    )
    
    # 索引
    op.create_index('idx_ai_usage_logs_user_id', 'ai_usage_logs', ['user_id'])
    op.create_index('idx_ai_usage_logs_ai_job_id', 'ai_usage_logs', ['ai_job_id'])
    op.create_index('idx_ai_usage_logs_model_id', 'ai_usage_logs', ['model_id'], 
                    postgresql_where=sa.text('model_id IS NOT NULL'))
    op.create_index('idx_ai_usage_logs_created_at', 'ai_usage_logs', ['created_at'])
    op.create_index('idx_ai_usage_logs_user_created', 'ai_usage_logs', ['user_id', 'created_at'])
    op.create_index('idx_ai_usage_logs_meta_data_gin', 'ai_usage_logs', ['meta_data'], 
                    postgresql_using='gin')
    
    # 字段注释
    op.execute("""
        COMMENT ON COLUMN ai_usage_logs.user_id IS '用户 ID - 应用层验证';
        COMMENT ON COLUMN ai_usage_logs.ai_job_id IS 'AI 任务 ID - 应用层验证';
        COMMENT ON COLUMN ai_usage_logs.model_id IS 'AI 模型 ID - 应用层验证';
        COMMENT ON COLUMN ai_usage_logs.unit_type IS '单位类型(1=Token 2=图片 3=秒 4=请求)';
    """)
    
    # 4. 创建 ai_quotas 表(无外键约束)
    op.create_table(
        'ai_quotas',
        sa.Column('quota_id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column('quota_type', sa.Text(), nullable=False),
        sa.Column('period', sa.SmallInteger(), nullable=False),
        sa.Column('total_quota', sa.Integer(), nullable=False),
        sa.Column('used_quota', sa.Integer(), nullable=False, server_default='0'),
        sa.Column('reset_at', sa.TIMESTAMP(timezone=True), nullable=False),
        sa.Column('created_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')),
        sa.Column('updated_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')),
        sa.PrimaryKeyConstraint('quota_id'),
        sa.UniqueConstraint('user_id', 'quota_type', 'period', name='ai_quotas_unique'),
        sa.CheckConstraint('period BETWEEN 1 AND 3', name='ai_quotas_period_check'),
        sa.CheckConstraint('used_quota >= 0 AND used_quota <= total_quota', name='ai_quotas_used_check'),
        comment='AI 配额表 - 应用层保证引用完整性'
    )
    
    # 索引
    op.create_index('idx_ai_quotas_user_id', 'ai_quotas', ['user_id'])
    op.create_index('idx_ai_quotas_type', 'ai_quotas', ['quota_type'])
    op.create_index('idx_ai_quotas_reset_at', 'ai_quotas', ['reset_at'])
    op.create_index('idx_ai_quotas_user_type', 'ai_quotas', ['user_id', 'quota_type'])
    
    # 触发器
    op.execute("""
        CREATE TRIGGER update_ai_quotas_updated_at
            BEFORE UPDATE ON ai_quotas
            FOR EACH ROW
            EXECUTE FUNCTION update_updated_at_column();
    """)
    
    # 字段注释
    op.execute("""
        COMMENT ON COLUMN ai_quotas.user_id IS '用户 ID - 应用层验证';
        COMMENT ON COLUMN ai_quotas.period IS '配额周期(1=每日 2=每月 3=总计)';
    """)
    
    # 5. 创建配额重置函数
    op.execute("""
        CREATE OR REPLACE FUNCTION reset_expired_quotas()
        RETURNS void AS $$
        BEGIN
            -- 重置每日配额 (period = 1)
            UPDATE ai_quotas
            SET used_quota = 0,
                reset_at = reset_at + INTERVAL '1 day',
                updated_at = now()
            WHERE period = 1 AND reset_at <= now();
            
            -- 重置每月配额 (period = 2)
            UPDATE ai_quotas
            SET used_quota = 0,
                reset_at = reset_at + INTERVAL '1 month',
                updated_at = now()
            WHERE period = 2 AND reset_at <= now();
        END;
        $$ LANGUAGE plpgsql;
    """)


def downgrade() -> None:
    """删除 AI 服务相关表"""
    op.execute("DROP FUNCTION IF EXISTS reset_expired_quotas();")
    op.drop_table('ai_quotas')
    op.drop_table('ai_usage_logs')
    op.drop_table('ai_jobs')
    op.drop_table('ai_models')

运行迁移

# 生成迁移文件
docker exec jointo-server-app alembic revision --autogenerate -m "create ai service tables"

# 应用迁移
docker exec jointo-server-app alembic upgrade head

# 查看当前版本
docker exec jointo-server-app alembic current

# 回滚迁移
docker exec jointo-server-app alembic downgrade -1