# AI 生成服务 > **文档版本**:v1.0 > **最后更新**:2025-01-27 --- ## 目录 1. [服务概述](#服务概述) 2. [核心功能](#核心功能) 3. [与 Credit Service 集成](#与-credit-service-集成) 4. [数据库设计](#数据库设计) 5. [服务实现](#服务实现) 6. [API 接口](#api-接口) 7. [AI 提供商集成](#ai-提供商集成) --- ## 服务概述 AI 生成服务负责处理各类 AI 内容生成任务,包括图片、视频、音效、配音等,支持多种 AI 模型和服务提供商。 ### 职责 - extra_data图片生成(文本转图片) - 视频生成(文本转视频、图片转视频) - 音效生成 - 配音生成(文本转语音) - 字幕生成(语音转文本) - 任务状态管理 --- ## 核心功能 ### 应用层引用完整性保证 ⚠️ **重要**:本服务遵循 Jointo 技术栈规范,**不使用数据库外键约束**,改为在应用层保证引用完整性。 **三层保证机制**: 1. **Repository 层**:提供 `exists()` 方法检查记录是否存在 2. **Service 层**:创建/更新前验证所有关联 ID,使用事务确保原子性 3. **后台任务**:定期检查孤儿记录并告警 **优势**: - 写入性能提升 15-30% - 为分库分表、微服务化做准备 - 复杂业务逻辑在应用层更易实现 - 表结构变更无需处理外键依赖 详见:[数据库设计](#数据库设计) 和 [服务实现](#服务实现) --- ### 1. 图片生成 **支持的模型**: - Stable Diffusion(开源) - DALL-E(OpenAI) - Midjourney(如果可用) **功能**: - 文本描述生成图片 - 支持风格控制 - 支持分辨率设置 - 异步生成,返回任务 ID ### 2. 视频生成 **支持的类型**: - **文本转视频**(text2video) - **图片转视频**(img2video) - **关键帧动画**(keyframe) - **视频融合**(fusion) - **视频替换**(replace) **支持的模型**: - Runway Gen-2 - Pika Labs - Stable Video Diffusion ### 3. 音效生成 **功能**: - 根据描述生成音效 - 支持音效类型选择 - 支持时长控制 ### 4. 配音生成 **功能**: - 文本转语音(TTS) - 支持多种语音类型 - 支持语速、音调控制 - 支持多语言 **支持的服务**: - OpenAI TTS - Azure Speech Services - 百度语音 / 讯飞语音 ### 5. 字幕生成 **功能**: - 语音转文本(STT) - 自动断句 - 分镜看板对齐 **支持的服务**: - Whisper(OpenAI) - 阿里云语音识别 ### 6. 文本处理 **功能**: - **剧本拆解**:将完整剧本拆分为场景、镜头 - **内容分析**:提取关键信息(人物、地点、时间、动作) - **结构化输出**:生成 JSON 格式的结构化数据 - **智能补全**:根据上下文补充缺失信息 - **风格转换**:调整文本风格和语气 **支持的模型**: - GPT-4 / GPT-3.5(OpenAI) - Claude(Anthropic) - 文心一言(百度) - 通义千问(阿里) **应用场景**: - 剧本自动拆解为分镜 - 场景描述生成视觉提示词 - 对话文本生成配音脚本 - 内容审核和合规检查 ### 7. 剧本解析(Screenplay Parsing) **核心功能**: - **自动提取剧本元素**: - 角色识别(主角、配角、群演) - 场景识别(地点、时间、描述) - 道具识别(重要性分类) - **变体识别**: - 角色变体(年龄段、时代、状态) - 场景变体(时代、季节、状态) - 道具变体(状态、版本) - **分镜拆解**: - 自动拆分为分镜脚本 - 识别景别和运镜 - 估算时长 - **自动关联**: - 分镜与角色/场景/道具自动关联 - 变体自动匹配 - **数据持久化**: - 自动存储到数据库 - 建立关联关系 **工作流程**: 1. 用户上传/创建剧本 2. 触发 AI 解析任务 3. 预扣积分,创建 AI 任务 4. Celery Worker 调用 AI 模型 5. AI 返回结构化 JSON 数据 6. 自动存储角色/场景/道具/变体 7. 自动创建分镜记录 8. 自动建立关联关系 9. 确认积分消耗 10. 返回解析结果 **详细文档**:参见 [AI 解析剧本工作流](../../workflows/screenplay-ai-parse-workflow.md) **API 接口**: ``` POST /api/v1/screenplays/{screenplay_id}/parse ``` **请求体**: ```json { "auto_create_elements": true, "auto_create_tags": true, "auto_create_storyboards": true, "model": "gpt-4" } ``` **响应**: ```json { "code": 200, "message": "Success", "data": { "jobId": "019d1234-5678-7abc-def0-222222222222", "taskId": "abc123-def456-ghi789", "status": "pending", "estimatedCredits": 50 } } ``` **AI 输出格式**: AI 模型返回包含以下结构的 JSON 数据: ```json { "characters": [ { "name": "张三", "description": "男主角,30岁,程序员", "role_type": "main", "meta_data": {"age": 30, "gender": "male"} } ], "scenes": [ { "scene_number": 1, "title": "咖啡厅", "location": "市中心星巴克", "time_of_day": "afternoon", "description": "温馨的咖啡厅" } ], "props": [ { "name": "笔记本电脑", "description": "张三的工作电脑", "category": "电子设备", "importance": "normal" } ], "character_tags": [ { "character_name": "张三", "tag_key": "youth", "tag_label": "少年", "description": "15岁的张三" } ], "scene_tags": [...], "prop_tags": [...], "storyboards": [ { "shot_number": "001", "title": "开场", "description": "张三坐在咖啡厅里", "dialogue": "又是一个平凡的下午...", "shot_size": "medium_shot", "camera_movement": "static", "estimated_duration": 5.5, "characters": ["张三"], "character_tags": {"张三": "adult"}, "scenes": ["咖啡厅"], "props": ["笔记本电脑"] } ] } ``` **自动存储逻辑**: 1. **存储角色**:批量插入 `screenplay_characters` 表,返回角色 ID 映射 2. **存储场景**:批量插入 `screenplay_scenes` 表,返回场景 ID 映射 3. **存储道具**:批量插入 `screenplay_props` 表,返回道具 ID 映射 4. **存储标签**:调用 `ScreenplayTagService.store_tags()` 批量插入 `screenplay_element_tags` 表,返回标签 ID 映射 5. **存储分镜**:批量插入 `storyboards` 表,同时建立关联关系 **标签存储详细流程**: ```python # Celery Worker 调用 Screenplay Service result = await screenplay_service.store_parsed_elements( screenplay_id=screenplay_id, parsed_data=parsed_data ) # Screenplay Service 内部调用 Screenplay Tag Service from app.services.screenplay_tag_service import ScreenplayTagService tag_service = ScreenplayTagService(db) tag_id_maps = await tag_service.store_tags( screenplay_id=screenplay_id, parsed_data=parsed_data, character_id_map=character_id_map, scene_id_map=scene_id_map, prop_id_map=prop_id_map ) # 返回的 tag_id_maps 结构 { 'character_tags': { '张三-youth': UUID('019d1234-5678-7abc-def0-444444444444'), '张三-adult': UUID('019d1234-5678-7abc-def0-555555555555') }, 'scene_tags': { '花果山-daytime': UUID('019d1234-5678-7abc-def0-666666666666'), '花果山-night': UUID('019d1234-5678-7abc-def0-777777777777') }, 'prop_tags': { '金箍棒-new': UUID('019d1234-5678-7abc-def0-888888888888') } } ``` **ScreenplayTagService.store_tags() 实现**: ```python async def store_tags( self, screenplay_id: UUID, parsed_data: Dict[str, Any], character_id_map: Dict[str, UUID], scene_id_map: Dict[str, UUID], prop_id_map: Dict[str, UUID] ) -> Dict[str, Dict[str, UUID]]: """存储 AI 解析的标签(供 AI Service 调用)""" logger.info("存储 AI 解析的标签: 剧本=%s", screenplay_id) tag_id_maps = { 'character_tags': {}, 'scene_tags': {}, 'prop_tags': {} } # 1. 存储角色标签 for char_name, tags in parsed_data.get('character_tags', {}).items(): character_id = character_id_map.get(char_name) if not character_id: continue for tag_data in tags: tag = await self.repository.create(ScreenplayElementTag( screenplay_id=screenplay_id, element_type=ElementType.CHARACTER, element_id=character_id, element_name=char_name, tag_key=tag_data['tag_key'], tag_label=tag_data['tag_label'], description=tag_data.get('description'), meta_data=tag_data.get('meta_data', {}) )) map_key = f"{char_name}-{tag_data['tag_key']}" tag_id_maps['character_tags'][map_key] = tag.tag_id # 更新角色的 has_tags 标志 await self._update_element_has_tags(ElementType.CHARACTER, character_id, True) # 2. 存储场景标签(逻辑类似) for scene_name, tags in parsed_data.get('scene_tags', {}).items(): scene_id = scene_id_map.get(scene_name) if not scene_id: continue for tag_data in tags: tag = await self.repository.create(ScreenplayElementTag( screenplay_id=screenplay_id, element_type=ElementType.SCENE, element_id=scene_id, element_name=scene_name, tag_key=tag_data['tag_key'], tag_label=tag_data['tag_label'], description=tag_data.get('description'), meta_data=tag_data.get('meta_data', {}) )) map_key = f"{scene_name}-{tag_data['tag_key']}" tag_id_maps['scene_tags'][map_key] = tag.tag_id await self._update_element_has_tags(ElementType.SCENE, scene_id, True) # 3. 存储道具标签(逻辑类似) for prop_name, tags in parsed_data.get('prop_tags', {}).items(): prop_id = prop_id_map.get(prop_name) if not prop_id: continue for tag_data in tags: tag = await self.repository.create(ScreenplayElementTag( screenplay_id=screenplay_id, element_type=ElementType.PROP, element_id=prop_id, element_name=prop_name, tag_key=tag_data['tag_key'], tag_label=tag_data['tag_label'], description=tag_data.get('description'), meta_data=tag_data.get('meta_data', {}) )) map_key = f"{prop_name}-{tag_data['tag_key']}" tag_id_maps['prop_tags'][map_key] = tag.tag_id await self._update_element_has_tags(ElementType.PROP, prop_id, True) logger.info( "标签存储完成: 角色标签=%d, 场景标签=%d, 道具标签=%d", len(tag_id_maps['character_tags']), len(tag_id_maps['scene_tags']), len(tag_id_maps['prop_tags']) ) return tag_id_maps ``` **自动关联逻辑**: 分镜创建时,根据 AI 返回的名称数组和标签映射,自动查找对应的数据库 ID: ```python # 示例:角色关联 for char_name in storyboard_data['characters']: character_id = character_id_map.get(char_name) if character_id: screenplay_character_ids.append(character_id) # 检查标签 tag_key = storyboard_data['character_tags'].get(char_name) if tag_key: tag_id = tag_id_maps['character_tags'].get(f"{char_name}-{tag_key}") if tag_id: screenplay_character_tag_ids.append(tag_id) ``` **涉及的服务**: - **AI Service**:调用 AI 模型进行解析 - **Screenplay Service**:管理剧本和剧本元素 - **Screenplay Tag Service**:管理标签(新增) - **Storyboard Service**:管理分镜 - **Credit Service**:管理积分扣除 --- ## 与 Credit Service 集成 ### 集成架构 AI 服务与积分服务采用**同步调用 + 事务保证**的集成模式: ``` AI Service (依赖) → Credit Service (被依赖) ``` ### 积分扣除流程 #### 1. 任务创建时预扣积分 ```python # AI Service 创建任务时 async def generate_image(self, user_id: UUID, ...): # 1. 计算所需积分 model_config = await self._get_model(model_name, 'image') credits_needed = model_config.credits_per_unit # 2. 调用 Credit Service 预扣积分(同步,事务保证) try: consumption_log = await self.credit_service.consume_credits( user_id=user_id, amount=credits_needed, feature_type=1, # FeatureType.IMAGE_GENERATION ai_job_id=None, # 稍后更新 task_params={...} ) except InsufficientCreditsError: raise ValidationError("积分不足") # 3. 创建 AI 任务,关联 consumption_log job = await self.job_repository.create({ 'consumption_log_id': consumption_log.consumption_id, ... }) # 4. 更新 consumption_log 的 ai_job_id consumption_log.ai_job_id = job.ai_job_id await self.db.commit() ``` #### 2. 任务完成时确认扣除 ```python # Celery Worker 任务完成 @celery_app.task def generate_image_task(job_id: UUID): try: # 执行 AI 生成 result = call_ai_provider(...) # 更新任务状态 await ai_service.update_job(job_id, {'status': 3, ...}) # AIJobStatus.COMPLETED # 确认积分消耗 job = await ai_service.get_job(job_id) await credit_service.confirm_consumption( consumption_id=job.consumption_log_id, resource_id=result.resource_id ) except Exception as e: # 任务失败,退还积分 await ai_service.update_job(job_id, {'status': 4, ...}) # AIJobStatus.FAILED await credit_service.refund_credits( consumption_id=job.consumption_log_id, reason=str(e) ) ``` ### 数据关联 ```sql -- AI Service 表关联 Credit Service(无外键约束,应用层保证引用完整性) ALTER TABLE ai_jobs ADD COLUMN consumption_log_id UUID; CREATE INDEX idx_ai_jobs_consumption_log_id ON ai_jobs (consumption_log_id) WHERE consumption_log_id IS NOT NULL; -- Credit Service 表关联 AI Service(无外键约束,应用层保证引用完整性) ALTER TABLE credit_consumption_logs ADD COLUMN ai_job_id UUID; CREATE INDEX idx_credit_consumption_logs_ai_job_id ON credit_consumption_logs (ai_job_id) WHERE ai_job_id IS NOT NULL; COMMENT ON COLUMN ai_jobs.consumption_log_id IS '积分消耗日志 ID - 应用层验证'; COMMENT ON COLUMN credit_consumption_logs.ai_job_id IS 'AI 任务 ID - 应用层验证'; ``` ### 事务保证 使用数据库事务确保积分扣除和任务创建的原子性: ```python async with self.db.begin(): # 1. 扣除积分 consumption_log = await credit_service.consume_credits(...) # 2. 创建任务 job = await ai_service.create_job(...) # 3. 关联 consumption_log.ai_job_id = job.ai_job_id # 提交事务(失败自动回滚) ``` --- ## 数据库设计 AI 服务需要以下数据表支撑任务管理、成本控制和配额管理。 ### 3.1 ai_jobs(AI 任务表) **核心表**,记录所有 AI 生成任务的状态和结果。 ```sql -- Python 枚举定义(app/models/ai_job.py) -- class AIJobType(IntEnum): -- IMAGE = 1 # 图片生成 -- VIDEO = 2 # 视频生成 -- SOUND = 3 # 音效生成 -- VOICE = 4 # 配音生成 -- SUBTITLE = 5 # 字幕生成 -- TEXT_PROCESSING = 6 # 文本处理(剧本拆解等) -- RESOURCE = 7 # 资源生成 -- STORYBOARD_SCRIPT = 8 # 分镜脚本生成 -- SCRIPT_GENERATION = 9 # 剧本生成 -- class AIJobStatus(IntEnum): -- PENDING = 1 # 等待处理 -- PROCESSING = 2 # 处理中 -- COMPLETED = 3 # 已完成 -- FAILED = 4 # 失败 -- CANCELLED = 5 # 已取消 CREATE TABLE ai_jobs ( ai_job_id UUID PRIMARY KEY, -- AI 任务唯一标识 -- 关联信息(无外键约束,应用层保证引用完整性) user_id UUID NOT NULL, -- 用户 ID project_id UUID, -- 项目 ID(可选) storyboard_id UUID, -- 分镜 ID(可选) -- 积分关联(无外键约束,应用层保证引用完整性) consumption_log_id UUID, -- 积分消耗日志 ID -- 任务信息 job_type SMALLINT NOT NULL, -- 任务类型(1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成) status SMALLINT NOT NULL DEFAULT 1, -- 任务状态(1=等待处理 2=处理中 3=已完成 4=失败 5=已取消) -- 输入输出 input_data JSONB NOT NULL DEFAULT '{}', -- 输入参数(prompt, 配置等) output_data JSONB, -- 输出结果(URL, 元数据等) -- AI 模型信息(无外键约束,应用层保证引用完整性) model_id UUID, -- AI 模型 ID model_name TEXT, -- 使用的模型名称(冗余字段) -- 任务状态 progress INTEGER NOT NULL DEFAULT 0 CHECK (progress >= 0 AND progress <= 100), -- 任务进度(0-100) error_message TEXT, -- 错误信息 task_id TEXT, -- Celery 异步任务 ID -- 时间信息 estimated_completion_at TIMESTAMPTZ, -- 预计完成时间 started_at TIMESTAMPTZ, -- 开始处理时间 completed_at TIMESTAMPTZ, -- 完成时间 created_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 创建时间 updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 更新时间 -- 成本信息 cost NUMERIC(10, 4), -- 实际成本(元) credits_used INTEGER NOT NULL DEFAULT 0 -- 使用的积分 ); -- 字段注释 COMMENT ON TABLE ai_jobs IS 'AI 任务表 - 应用层保证引用完整性'; COMMENT ON COLUMN ai_jobs.ai_job_id IS 'AI 任务唯一标识'; COMMENT ON COLUMN ai_jobs.user_id IS '用户 ID - 应用层验证'; COMMENT ON COLUMN ai_jobs.project_id IS '项目 ID(可选)- 应用层验证'; COMMENT ON COLUMN ai_jobs.storyboard_id IS '分镜 ID(可选)- 应用层验证'; COMMENT ON COLUMN ai_jobs.consumption_log_id IS '积分消耗日志 ID - 应用层验证'; COMMENT ON COLUMN ai_jobs.job_type IS '任务类型(1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成)'; COMMENT ON COLUMN ai_jobs.status IS '任务状态(1=等待处理 2=处理中 3=已完成 4=失败 5=已取消)'; COMMENT ON COLUMN ai_jobs.input_data IS '输入参数(prompt, 配置等)'; COMMENT ON COLUMN ai_jobs.output_data IS '输出结果(URL, 元数据等)'; COMMENT ON COLUMN ai_jobs.model_id IS 'AI 模型 ID - 应用层验证'; COMMENT ON COLUMN ai_jobs.model_name IS '使用的模型名称(冗余字段)'; COMMENT ON COLUMN ai_jobs.progress IS '任务进度(0-100)'; COMMENT ON COLUMN ai_jobs.error_message IS '错误信息'; COMMENT ON COLUMN ai_jobs.task_id IS 'Celery 异步任务 ID'; COMMENT ON COLUMN ai_jobs.estimated_completion_at IS '预计完成时间'; COMMENT ON COLUMN ai_jobs.started_at IS '开始处理时间'; COMMENT ON COLUMN ai_jobs.completed_at IS '完成时间'; COMMENT ON COLUMN ai_jobs.created_at IS '创建时间'; COMMENT ON COLUMN ai_jobs.updated_at IS '更新时间'; COMMENT ON COLUMN ai_jobs.cost IS '实际成本(元)'; COMMENT ON COLUMN ai_jobs.credits_used IS '使用的积分'; -- 索引 CREATE INDEX idx_ai_jobs_user_id ON ai_jobs (user_id); CREATE INDEX idx_ai_jobs_project_id ON ai_jobs (project_id) WHERE project_id IS NOT NULL; CREATE INDEX idx_ai_jobs_storyboard_id ON ai_jobs (storyboard_id) WHERE storyboard_id IS NOT NULL; CREATE INDEX idx_ai_jobs_type ON ai_jobs (job_type); CREATE INDEX idx_ai_jobs_status ON ai_jobs (status); CREATE INDEX idx_ai_jobs_model_id ON ai_jobs (model_id) WHERE model_id IS NOT NULL; CREATE INDEX idx_ai_jobs_consumption_log_id ON ai_jobs (consumption_log_id) WHERE consumption_log_id IS NOT NULL; CREATE INDEX idx_ai_jobs_created_at ON ai_jobs (created_at); CREATE INDEX idx_ai_jobs_status_created_at ON ai_jobs (status, created_at) WHERE status IN (1, 2); -- PENDING, PROCESSING CREATE INDEX idx_ai_jobs_input_data_gin ON ai_jobs USING GIN (input_data); CREATE INDEX idx_ai_jobs_output_data_gin ON ai_jobs USING GIN (output_data) WHERE output_data IS NOT NULL; -- 触发器 CREATE TRIGGER update_ai_jobs_updated_at BEFORE UPDATE ON ai_jobs FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); ``` **字段说明**: - `input_data`:存储任务输入参数 ```json { "prompt": "一只可爱的猫咪", "width": 1024, "height": 1024, "style": "realistic", "temperature": 0.7 } ``` - `output_data`:存储任务输出结果 ```json { "file_url": "https://storage.jointo.ai/ai-generated/images/abc123def456.png", "file_size": 1024000, "checksum": "abc123def456...", "mime_type": "image/png", "original_url": "https://ai-provider.com/temp/xyz789.png", "meta_data": { "width": 1024, "height": 1024, "format": "png" } } ``` **字段说明**: - `file_url`: 自有 OSS 存储的文件 URL(永久有效) - `file_size`: 文件大小(字节) - `checksum`: 文件 SHA256 校验和(用于去重) - `mime_type`: 文件 MIME 类型 - `original_url`: AI 提供商返回的原始 URL(可选,用于调试) - `meta_data`: 文件元数据(宽度、高度、格式等) --- ### 3.2 ai_models(AI 模型配置表) **推荐表**,管理可用的 AI 模型及其定价配置。 ```sql -- Python 枚举定义(app/models/ai_model.py) -- class AIModelType(IntEnum): -- TEXT = 1 # 文本模型(GPT, Claude 等) -- IMAGE = 2 # 图片模型(DALL-E, Stable Diffusion 等) -- VIDEO = 3 # 视频模型(Runway, Pika 等) -- AUDIO = 4 # 音频模型(TTS, STT 等) -- class AIProvider(IntEnum): -- OPENAI = 1 -- ANTHROPIC = 2 -- GOOGLE = 3 # Google Gemini -- STABILITY = 4 -- RUNWAY = 5 -- PIKA = 6 -- ELEVENLABS = 7 -- AZURE = 8 -- BAIDU = 9 -- ALIYUN = 10 -- CUSTOM = 99 # 自定义提供商 -- class UnitType(IntEnum): -- TOKEN = 1 # Token(文本模型) -- IMAGE = 2 # 图片(图片模型) -- SECOND = 3 # 秒(视频/音频模型) -- REQUEST = 4 # 请求(通用计费单位) CREATE TABLE ai_models ( model_id UUID PRIMARY KEY, -- AI 模型唯一标识 -- 模型信息 model_name TEXT NOT NULL UNIQUE, -- 模型名称(如:gpt-4, dall-e-3, runway-gen2) display_name TEXT NOT NULL, -- 显示名称 description TEXT, -- 模型描述 -- 分类 provider SMALLINT NOT NULL, -- 提供商(1=OpenAI 2=Anthropic 3=Google 4=Stability 5=Runway 6=Pika 7=ElevenLabs 8=Azure 9=Baidu 10=Aliyun 99=自定义) model_type SMALLINT NOT NULL, -- 模型类型(1=文本 2=图片 3=视频 4=音频) -- 定价(按使用量计费) cost_per_unit NUMERIC(10, 4) NOT NULL, -- 单位成本(元) unit_type SMALLINT NOT NULL, -- 单位类型(1=Token 2=图片 3=秒 4=请求) credits_per_unit INTEGER NOT NULL, -- 每单位消耗积分 -- 配置 config JSONB NOT NULL DEFAULT '{}', -- 模型特定配置 -- 限制 rate_limit INTEGER, -- 速率限制(每分钟请求数) daily_quota INTEGER, -- 每日配额 -- 状态 is_active BOOLEAN NOT NULL DEFAULT true, -- 是否启用 is_beta BOOLEAN NOT NULL DEFAULT false, -- 是否为测试版 -- 审计 created_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 创建时间 updated_at TIMESTAMPTZ NOT NULL DEFAULT now() -- 更新时间 ); -- 字段注释 COMMENT ON COLUMN ai_models.model_id IS 'AI 模型唯一标识'; COMMENT ON COLUMN ai_models.model_name IS '模型名称(如:gpt-4, dall-e-3, runway-gen2)'; COMMENT ON COLUMN ai_models.display_name IS '显示名称'; COMMENT ON COLUMN ai_models.description IS '模型描述'; COMMENT ON COLUMN ai_models.provider IS '提供商(1=OpenAI 2=Anthropic 3=Google 4=Stability 5=Runway 6=Pika 7=ElevenLabs 8=Azure 9=Baidu 10=Aliyun 99=自定义)'; COMMENT ON COLUMN ai_models.model_type IS '模型类型(1=文本 2=图片 3=视频 4=音频)'; COMMENT ON COLUMN ai_models.cost_per_unit IS '单位成本(元)'; COMMENT ON COLUMN ai_models.unit_type IS '单位类型(1=Token 2=图片 3=秒 4=请求)'; COMMENT ON COLUMN ai_models.credits_per_unit IS '每单位消耗积分'; COMMENT ON COLUMN ai_models.config IS '模型特定配置'; COMMENT ON COLUMN ai_models.rate_limit IS '速率限制(每分钟请求数)'; COMMENT ON COLUMN ai_models.daily_quota IS '每日配额'; COMMENT ON COLUMN ai_models.is_active IS '是否启用'; COMMENT ON COLUMN ai_models.is_beta IS '是否为测试版'; COMMENT ON COLUMN ai_models.created_at IS '创建时间'; COMMENT ON COLUMN ai_models.updated_at IS '更新时间'; -- 索引 CREATE INDEX idx_ai_models_provider ON ai_models (provider) WHERE is_active = true; CREATE INDEX idx_ai_models_type ON ai_models (model_type) WHERE is_active = true; CREATE INDEX idx_ai_models_is_active ON ai_models (is_active); CREATE INDEX idx_ai_models_config_gin ON ai_models USING GIN (config); -- 触发器 CREATE TRIGGER update_ai_models_updated_at BEFORE UPDATE ON ai_models FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); ``` **设计说明**: - 动态配置模型,无需修改代码 - 支持多种计费单位(token、图片、秒、请求) - 使用 JSONB 存储模型特定配置 - 支持速率限制和配额管理 - **新增提供商**: - `google`:支持 Gemini Pro 和 Gemini 1.5 Pro - `custom`:支持自定义 AI 提供商(用户自建模型、第三方 API 等) **Custom 提供商使用场景**: - 用户自建的 AI 模型服务 - 第三方 AI API 集成 - 企业内部 AI 服务 - 测试和开发环境 **Custom 提供商配置示例**: ```json { "api_endpoint": "https://custom-ai.example.com/v1/generate", "auth_type": "bearer", "headers": { "Authorization": "Bearer ${API_KEY}" }, "request_format": "json", "response_format": "json" } ``` --- ### 3.3 ai_usage_logs(AI 使用日志表) **计费必需表**,记录每次 AI 调用的详细使用情况。 ```sql CREATE TABLE ai_usage_logs ( usage_log_id UUID PRIMARY KEY, -- 使用日志唯一标识 -- 关联信息(无外键约束,应用层保证引用完整性) user_id UUID NOT NULL, -- 用户 ID ai_job_id UUID NOT NULL, -- AI 任务 ID model_id UUID, -- AI 模型 ID -- 使用量 units_used NUMERIC(10, 2) NOT NULL, -- 使用的单位数(tokens, 秒数等) unit_type SMALLINT NOT NULL CHECK (unit_type BETWEEN 1 AND 4), -- 单位类型(1=Token 2=图片 3=秒 4=请求) -- 成本 cost NUMERIC(10, 4) NOT NULL, -- 实际成本(元) credits_used INTEGER NOT NULL, -- 消耗的积分 -- 元数据 meta_data JSONB NOT NULL DEFAULT '{}', -- 使用元数据 -- 时间 created_at TIMESTAMPTZ NOT NULL DEFAULT now() -- 创建时间 ); -- 字段注释 COMMENT ON TABLE ai_usage_logs IS 'AI 使用日志表 - 应用层保证引用完整性'; COMMENT ON COLUMN ai_usage_logs.usage_log_id IS '使用日志唯一标识'; COMMENT ON COLUMN ai_usage_logs.user_id IS '用户 ID - 应用层验证'; COMMENT ON COLUMN ai_usage_logs.ai_job_id IS 'AI 任务 ID - 应用层验证'; COMMENT ON COLUMN ai_usage_logs.model_id IS 'AI 模型 ID - 应用层验证'; COMMENT ON COLUMN ai_usage_logs.units_used IS '使用的单位数(tokens, 秒数等)'; COMMENT ON COLUMN ai_usage_logs.unit_type IS '单位类型(1=Token 2=图片 3=秒 4=请求)'; COMMENT ON COLUMN ai_usage_logs.cost IS '实际成本(元)'; COMMENT ON COLUMN ai_usage_logs.credits_used IS '消耗的积分'; COMMENT ON COLUMN ai_usage_logs.meta_data IS '使用元数据'; COMMENT ON COLUMN ai_usage_logs.created_at IS '创建时间'; -- 索引 CREATE INDEX idx_ai_usage_logs_user_id ON ai_usage_logs (user_id); CREATE INDEX idx_ai_usage_logs_ai_job_id ON ai_usage_logs (ai_job_id); CREATE INDEX idx_ai_usage_logs_model_id ON ai_usage_logs (model_id) WHERE model_id IS NOT NULL; CREATE INDEX idx_ai_usage_logs_created_at ON ai_usage_logs (created_at); CREATE INDEX idx_ai_usage_logs_user_created ON ai_usage_logs (user_id, created_at); CREATE INDEX idx_ai_usage_logs_meta_data_gin ON ai_usage_logs USING GIN (meta_data); -- 分区策略(按月分区,便于归档) -- CREATE TABLE ai_usage_logs_2026_01 PARTITION OF ai_usage_logs -- FOR VALUES FROM ('2026-01-01') TO ('2026-02-01'); ``` **设计说明**: - 每次 AI 调用都记录详细日志 - 支持成本分析和用户账单生成 - 建议按月分区,便于历史数据归档 - 使用 JSONB 存储详细的使用元数据 --- ### 3.4 ai_quotas(AI 配额表) **推荐表**,管理用户的 AI 使用配额和限流。 ```sql -- Python 枚举定义(app/models/ai_quota.py) -- class QuotaPeriod(IntEnum): -- DAILY = 1 # 每日配额 -- MONTHLY = 2 # 每月配额 -- TOTAL = 3 # 总配额 CREATE TABLE ai_quotas ( quota_id UUID PRIMARY KEY, -- 配额唯一标识 -- 用户信息(无外键约束,应用层保证引用完整性) user_id UUID NOT NULL, -- 用户 ID -- 配额类型 quota_type TEXT NOT NULL, -- 配额类型(如:image_generation, video_generation, text_processing) period SMALLINT NOT NULL CHECK (period BETWEEN 1 AND 3), -- 配额周期(1=每日 2=每月 3=总计) -- 配额限制 total_quota INTEGER NOT NULL, -- 总配额 used_quota INTEGER NOT NULL DEFAULT 0, -- 已使用配额 -- 重置时间 reset_at TIMESTAMPTZ NOT NULL, -- 重置时间 -- 审计 created_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 创建时间 updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 更新时间 -- 唯一约束 CONSTRAINT ai_quotas_unique UNIQUE (user_id, quota_type, period) NULLS NOT DISTINCT, -- 检查约束 CONSTRAINT ai_quotas_used_check CHECK (used_quota >= 0 AND used_quota <= total_quota) ); -- 字段注释 COMMENT ON TABLE ai_quotas IS 'AI 配额表 - 应用层保证引用完整性'; COMMENT ON COLUMN ai_quotas.quota_id IS '配额唯一标识'; COMMENT ON COLUMN ai_quotas.user_id IS '用户 ID - 应用层验证'; COMMENT ON COLUMN ai_quotas.quota_type IS '配额类型(如:image_generation, video_generation, text_processing)'; COMMENT ON COLUMN ai_quotas.period IS '配额周期(1=每日 2=每月 3=总计)'; COMMENT ON COLUMN ai_quotas.total_quota IS '总配额'; COMMENT ON COLUMN ai_quotas.used_quota IS '已使用配额'; COMMENT ON COLUMN ai_quotas.reset_at IS '重置时间'; COMMENT ON COLUMN ai_quotas.created_at IS '创建时间'; COMMENT ON COLUMN ai_quotas.updated_at IS '更新时间'; -- 索引 CREATE INDEX idx_ai_quotas_user_id ON ai_quotas (user_id); CREATE INDEX idx_ai_quotas_type ON ai_quotas (quota_type); CREATE INDEX idx_ai_quotas_reset_at ON ai_quotas (reset_at); CREATE INDEX idx_ai_quotas_user_type ON ai_quotas (user_id, quota_type); -- 触发器 CREATE TRIGGER update_ai_quotas_updated_at BEFORE UPDATE ON ai_quotas FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); -- 自动重置配额的函数 CREATE OR REPLACE FUNCTION reset_expired_quotas() RETURNS void AS $$ BEGIN -- 重置每日配额 (period = 1) UPDATE ai_quotas SET used_quota = 0, reset_at = reset_at + INTERVAL '1 day', updated_at = now() WHERE period = 1 AND reset_at <= now(); -- 重置每月配额 (period = 2) UPDATE ai_quotas SET used_quota = 0, reset_at = reset_at + INTERVAL '1 month', updated_at = now() WHERE period = 2 AND reset_at <= now(); END; $$ LANGUAGE plpgsql; -- 定时任务(需要配合 pg_cron 或应用层定时任务) -- SELECT cron.schedule('reset-ai-quotas', '0 0 * * *', 'SELECT reset_expired_quotas()'); ``` **设计说明**: - 支持每日、每月、总配额三种类型 - 自动重置过期配额 - 防止配额超限(CHECK 约束) - 支持不同类型的配额管理 --- ### 3.5 数据表关系图 ``` ┌─────────────┐ │ users │ └──────┬──────┘ │ ├──────────────────────────────────┐ │ │ ▼ ▼ ┌─────────────┐ ┌─────────────┐ │ ai_jobs │───────────────────>│ ai_models │ └──────┬──────┘ └─────────────┘ │ ▲ │ │ ▼ │ ┌─────────────┐ │ │ai_usage_logs│──────────────────────────┘ └─────────────┘ ┌─────────────┐ │ ai_quotas │ └─────────────┘ ▲ │ └──────── users ``` --- ## Pydantic Schema 定义 ### 请求 Schema ```python # app/schemas/ai_job.py from pydantic import BaseModel, Field, field_validator from uuid import UUID from typing import Optional, Dict, Any from enum import IntEnum class AIJobType(IntEnum): """AI 任务类型""" IMAGE = 1 VIDEO = 2 SOUND = 3 VOICE = 4 SUBTITLE = 5 TEXT_PROCESSING = 6 RESOURCE = 7 STORYBOARD_SCRIPT = 8 SCRIPT_GENERATION = 9 class VideoType(str): """视频生成类型""" TEXT2VIDEO = "text2video" IMG2VIDEO = "img2video" KEYFRAME = "keyframe" FUSION = "fusion" REPLACE = "replace" class TextProcessingTaskType(str): """文本处理任务类型""" SCREENPLAY_PARSE = "screenplay_parse" CONTENT_ANALYSIS = "content_analysis" STYLE_TRANSFORM = "style_transform" PROMPT_GENERATION = "prompt_generation" # ==================== 图片生成 ==================== class ImageGenerationRequest(BaseModel): """图片生成请求""" prompt: str = Field(..., description="提示词", min_length=1, max_length=2000) model: Optional[str] = Field(None, description="模型名称") width: int = Field(1024, description="宽度", ge=256, le=2048) height: int = Field(1024, description="高度", ge=256, le=2048) style: Optional[str] = Field(None, description="风格") class Config: json_schema_extra = { "example": { "prompt": "一只可爱的猫咪在花园里玩耍", "model": "stable_diffusion", "width": 1024, "height": 1024, "style": "realistic" } } # ==================== 视频生成 ==================== class VideoGenerationRequest(BaseModel): """视频生成请求""" video_type: str = Field(..., description="视频类型: text2video, img2video, keyframe, fusion, replace") prompt: Optional[str] = Field(None, description="提示词(text2video 必需)") image_url: Optional[str] = Field(None, description="图片 URL(img2video 必需)") duration: int = Field(5, description="时长(秒)", ge=1, le=60) fps: int = Field(30, description="帧率", ge=15, le=60) model: Optional[str] = Field(None, description="模型名称") @field_validator('video_type') @classmethod def validate_video_type(cls, v): valid_types = ['text2video', 'img2video', 'keyframe', 'fusion', 'replace'] if v not in valid_types: raise ValueError(f"video_type 必须是以下之一: {', '.join(valid_types)}") return v @field_validator('prompt') @classmethod def validate_prompt(cls, v, info): if info.data.get('video_type') == 'text2video' and not v: raise ValueError("text2video 类型必须提供 prompt") return v @field_validator('image_url') @classmethod def validate_image_url(cls, v, info): if info.data.get('video_type') == 'img2video' and not v: raise ValueError("img2video 类型必须提供 image_url") return v # ==================== 音效生成 ==================== class SoundGenerationRequest(BaseModel): """音效生成请求""" description: str = Field(..., description="音效描述", min_length=1, max_length=500) duration: int = Field(5, description="时长(秒)", ge=1, le=30) sound_type: Optional[str] = Field(None, description="音效类型") model: Optional[str] = Field(None, description="模型名称") # ==================== 配音生成 ==================== class VoiceGenerationRequest(BaseModel): """配音生成请求""" text: str = Field(..., description="文本内容", min_length=1, max_length=5000) voice_type: str = Field("alloy", description="语音类型") speed: float = Field(1.0, description="语速", ge=0.5, le=2.0) language: str = Field("zh-CN", description="语言") model: Optional[str] = Field(None, description="模型名称") # ==================== 字幕生成 ==================== class SubtitleGenerationRequest(BaseModel): """字幕生成请求""" audio_url: str = Field(..., description="音频 URL") language: str = Field("zh", description="语言") model: Optional[str] = Field(None, description="模型名称") # ==================== 文本处理 ==================== class TextProcessingRequest(BaseModel): """文本处理请求""" task_type: str = Field(..., description="任务类型: screenplay_parse, content_analysis, style_transform, prompt_generation") text: str = Field(..., description="待处理文本", min_length=1, max_length=50000) model: Optional[str] = Field(None, description="模型名称") output_format: str = Field("json", description="输出格式: json, text") temperature: float = Field(0.7, description="生成温度", ge=0.0, le=2.0) max_tokens: int = Field(4000, description="最大 token 数", ge=100, le=8000) @field_validator('task_type') @classmethod def validate_task_type(cls, v): valid_types = ['screenplay_parse', 'content_analysis', 'style_transform', 'prompt_generation'] if v not in valid_types: raise ValueError(f"task_type 必须是以下之一: {', '.join(valid_types)}") return v # ==================== 统一创建请求 ==================== class AIJobCreateRequest(BaseModel): """AI 任务创建请求(统一入口)""" job_type: str = Field(..., description="任务类型: image, video, sound, voice, subtitle, textProcessing") params: Dict[str, Any] = Field(..., description="任务参数") @field_validator('job_type') @classmethod def validate_job_type(cls, v): valid_types = ['image', 'video', 'sound', 'voice', 'subtitle', 'textProcessing'] if v not in valid_types: raise ValueError(f"job_type 必须是以下之一: {', '.join(valid_types)}") return v ``` ### 响应 Schema ```python # app/schemas/ai_job.py (续) from datetime import datetime class AIJobResponse(BaseModel): """AI 任务响应""" job_id: UUID = Field(..., description="任务 ID") task_id: str = Field(..., description="Celery 任务 ID") status: str = Field(..., description="任务状态") estimated_credits: int = Field(..., description="预估积分消耗") class Config: json_schema_extra = { "example": { "job_id": "019d1234-5678-7abc-def0-111111111111", "task_id": "abc-123-def", "status": "pending", "estimated_credits": 10 } } class AIJobStatusResponse(BaseModel): """AI 任务状态响应""" job_id: UUID job_type: int status: int progress: int input_data: Dict[str, Any] output_data: Optional[Dict[str, Any]] error_message: Optional[str] model_name: Optional[str] cost: Optional[float] credits_used: int created_at: datetime updated_at: datetime started_at: Optional[datetime] completed_at: Optional[datetime] class AIJobListResponse(BaseModel): """AI 任务列表响应""" items: list[AIJobStatusResponse] total: int page: int page_size: int total_pages: int class AIModelResponse(BaseModel): """AI 模型响应""" model_id: UUID model_name: str display_name: str description: Optional[str] provider: int model_type: int cost_per_unit: float unit_type: int credits_per_unit: int is_beta: bool config: Dict[str, Any] class AIUsageStatsResponse(BaseModel): """AI 使用统计响应""" total_cost: float total_credits_used: int total_requests: int by_model: Dict[str, Dict[str, Any]] by_type: Dict[str, Dict[str, Any]] quotas: list[Dict[str, Any]] ``` --- ## 服务实现 ### AIService 类 ```python # app/services/ai_service.py from typing import Dict, Any, Optional from uuid import UUID from app.tasks.ai_tasks import ( generate_image_task, generate_video_task, generate_sound_task, generate_voice_task, generate_subtitle_task, process_text_task ) from app.repositories.ai_job_repository import AIJobRepository from app.repositories.ai_model_repository import AIModelRepository from app.repositories.ai_usage_log_repository import AIUsageLogRepository from app.repositories.ai_quota_repository import AIQuotaRepository from app.services.credit_service import CreditService from app.core.exceptions import InsufficientCreditsError, ValidationError from sqlalchemy.orm import Session class AIService: def __init__(self, db: Session): self.db = db self.job_repository = AIJobRepository(db) self.model_repository = AIModelRepository(db) self.usage_log_repository = AIUsageLogRepository(db) self.quota_repository = AIQuotaRepository(db) self.credit_service = CreditService(db) # 注入 Credit Service # ==================== 枚举验证方法 ==================== def _validate_job_type(self, job_type: int) -> None: """验证任务类型枚举值""" from app.models.ai_job import AIJobType if job_type not in AIJobType.to_dict().keys(): raise ValidationError(f"无效的任务类型: {job_type}") def _validate_job_status(self, job_status: int) -> None: """验证任务状态枚举值""" from app.models.ai_job import AIJobStatus if job_status not in AIJobStatus.to_dict().keys(): raise ValidationError(f"无效的任务状态: {job_status}") def _validate_provider(self, provider: int) -> None: """验证提供商枚举值""" from app.models.ai_model import AIModelProvider if provider not in AIModelProvider.to_dict().keys(): raise ValidationError(f"无效的提供商: {provider}") def _validate_model_type(self, model_type: int) -> None: """验证模型类型枚举值""" from app.models.ai_model import AIModelType if model_type not in AIModelType.to_dict().keys(): raise ValidationError(f"无效的模型类型: {model_type}") def _validate_unit_type(self, unit_type: int) -> None: """验证单位类型枚举值""" from app.models.ai_model import AIModelUnitType if unit_type not in AIModelUnitType.to_dict().keys(): raise ValidationError(f"无效的单位类型: {unit_type}") # ==================== 私有方法 ==================== async def _check_quota(self, user_id: UUID, quota_type: str) -> bool: """检查用户配额是否充足""" quota = await self.quota_repository.get_user_quota(user_id, quota_type) if not quota: return True # 没有配额限制 if quota.used_quota >= quota.total_quota: from app.core.exceptions import QuotaExceededError raise QuotaExceededError(f"配额不足:{quota_type}") return True def _get_feature_type(self, job_type: int) -> int: """将 AIJobType 映射到 Credit Service 的 feature_type Args: job_type: AIJobType 枚举值 Returns: Credit Service 的 feature_type 枚举值(SMALLINT) """ from app.models.ai_job import AIJobType # 直接返回 job_type,因为 AI Service 和 Credit Service 使用相同的枚举值 # AIJobType: 1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成 # FeatureType: 1=图片生成 2=视频生成 3=文本处理 4=音频生成 5=音效生成 6=配音生成 7=字幕生成 8=资源生成 9=分镜脚本生成 10=剧本生成 # 需要特殊处理的映射: # AIJobType.TEXT_PROCESSING (6) -> FeatureType.TEXT_PROCESSING (3) # AIJobType.RESOURCE (7) -> FeatureType.RESOURCE_GENERATION (8) # AIJobType.STORYBOARD_SCRIPT (8) -> FeatureType.STORYBOARD_GENERATION (9) # AIJobType.SCRIPT_GENERATION (9) -> FeatureType.SCRIPT_GENERATION (10) feature_type_map = { AIJobType.IMAGE: 1, # 图片生成 AIJobType.VIDEO: 2, # 视频生成 AIJobType.TEXT_PROCESSING: 3, # 文本处理 AIJobType.SOUND: 5, # 音效生成 AIJobType.VOICE: 6, # 配音生成 AIJobType.SUBTITLE: 7, # 字幕生成 AIJobType.RESOURCE: 8, # 资源生成 AIJobType.STORYBOARD_SCRIPT: 9, # 分镜脚本生成 AIJobType.SCRIPT_GENERATION: 10 # 剧本生成 } return feature_type_map.get(job_type, 1) # 默认返回图片生成 async def _get_model(self, model_name: Optional[str], model_type: int) -> Dict[str, Any]: """获取模型配置 Args: model_name: 模型名称 model_type: 模型类型(AIModelType 枚举值) """ # 验证模型类型枚举值 self._validate_model_type(model_type) if model_name: model = await self.model_repository.get_by_name(model_name) else: # 获取默认模型 model = await self.model_repository.get_default_model(model_type) if not model or not model.is_active: from app.core.exceptions import NotFoundError raise NotFoundError(f"模型不可用: {model_name}") return model async def _record_usage( self, user_id: UUID, job_id: UUID, model_id: UUID, units_used: float, unit_type: int, # UnitType 枚举值 cost: float, credits_used: int, meta_data: Dict[str, Any] ) -> None: """记录使用日志""" # 验证单位类型枚举值 self._validate_unit_type(unit_type) await self.usage_log_repository.create({ 'user_id': user_id, 'ai_job_id': job_id, 'model_id': model_id, 'units_used': units_used, 'unit_type': unit_type, 'cost': cost, 'credits_used': credits_used, 'meta_data': meta_data }) async def generate_image( self, user_id: UUID, prompt: str, model: Optional[str] = None, width: int = 1024, height: int = 1024, style: Optional[str] = None, **kwargs ) -> Dict[str, Any]: """生成图片(异步) 应用层引用完整性保证: 1. 验证 user_id 是否存在 2. 验证 model_id 是否存在 3. 使用事务确保原子性 """ # 1. 验证用户是否存在(应用层引用完整性保证) from app.repositories.user_repository import UserRepository user_repo = UserRepository(self.db) if not await user_repo.exists(user_id): raise ValidationError("用户不存在") # 2. 检查配额 await self._check_quota(user_id, 'image_generation') # 3. 获取模型配置(AIModelType.IMAGE = 2) from app.models.ai_model import AIModelType model_config = await self._get_model(model, AIModelType.IMAGE) # 4. 验证模型是否存在(应用层引用完整性保证) if not await self.model_repository.exists(model_config.model_id): raise ValidationError("AI 模型不存在") # 5. 计算所需积分 credits_needed = model_config.credits_per_unit # 6. 使用事务确保原子性 async with self.db.begin(): # 预扣积分(同步,事务保证) try: consumption_log = await self.credit_service.consume_credits( user_id=user_id, amount=credits_needed, feature_type=1, # FeatureType.IMAGE_GENERATION ai_job_id=None, # 稍后更新 task_params={ 'prompt': prompt, 'model': model_config.model_name, 'width': width, 'height': height, 'style': style, **kwargs } ) except InsufficientCreditsError as e: raise ValidationError(f"积分不足: {str(e)}") # 创建任务记录 from app.models.ai_job import AIJobType, AIJobStatus job = await self.job_repository.create({ 'job_type': AIJobType.IMAGE, 'status': AIJobStatus.PENDING, 'user_id': user_id, 'model_id': model_config.model_id, 'model_name': model_config.model_name, 'consumption_log_id': consumption_log.consumption_id, 'input_data': { 'prompt': prompt, 'width': width, 'height': height, 'style': style, **kwargs } }) # 更新 consumption_log 的 ai_job_id consumption_log.ai_job_id = job.ai_job_id consumption_log.task_id = str(job.ai_job_id) await self.db.commit() # 7. 提交异步任务 task = generate_image_task.delay( job_id=job.ai_job_id, prompt=prompt, model=model_config.model_name, width=width, height=height, style=style, **kwargs ) # 8. 更新任务 ID await self.job_repository.update(job.ai_job_id, {'task_id': task.id}) return { 'job_id': str(job.ai_job_id), 'task_id': task.id, 'status': 'pending', 'estimated_cost': float(model_config.cost_per_unit), 'estimated_credits': credits_needed } async def generate_video( self, user_id: UUID, video_type: str, prompt: Optional[str] = None, image_url: Optional[str] = None, duration: int = 5, fps: int = 30, model: Optional[str] = None, **kwargs ) -> Dict[str, Any]: """生成视频(异步)""" # 验证参数 if video_type == 'text2video' and not prompt: raise ValidationError("文本转视频需要提供 prompt") if video_type == 'img2video' and not image_url: raise ValidationError("图片转视频需要提供 image_url") # 检查配额 await self._check_quota(user_id, 'video_generation') # 获取模型配置(AIModelType.VIDEO = 3) from app.models.ai_model import AIModelType model_config = await self._get_model(model, AIModelType.VIDEO) # 计算所需积分 feature_type = self._get_feature_type(AIJobType.VIDEO) credits_needed = await self.credit_service.calculate_credits( feature_type=feature_type, params={ 'model_name': model_config.model_name, 'video_type': video_type, 'duration': duration, 'quality': kwargs.get('quality', 'sd') } ) # 预扣积分(同步,事务保证) try: consumption_log = await self.credit_service.consume_credits( user_id=user_id, amount=credits_needed, feature_type=feature_type, ai_job_id=None, task_params={ 'video_type': video_type, 'model': model_config.model_name, 'duration': duration, 'fps': fps, **kwargs } ) except InsufficientCreditsError as e: raise ValidationError(f"积分不足: {str(e)}") # 创建任务记录 from app.models.ai_job import AIJobType, AIJobStatus job = await self.job_repository.create({ 'job_type': AIJobType.VIDEO, 'status': AIJobStatus.PENDING, 'user_id': user_id, 'model_id': model_config.model_id, 'model_name': model_config.model_name, 'consumption_log_id': consumption_log.consumption_id, 'credits_used': credits_needed, 'input_data': { 'video_type': video_type, 'prompt': prompt, 'image_url': image_url, 'duration': duration, 'fps': fps, **kwargs } }) # 更新 consumption_log 的 ai_job_id consumption_log.ai_job_id = job.ai_job_id consumption_log.task_id = str(job.ai_job_id) await self.db.commit() # 提交异步任务 task = generate_video_task.delay( job_id=job.ai_job_id, video_type=video_type, prompt=prompt, image_url=image_url, duration=duration, fps=fps, **kwargs ) # 更新任务 ID await self.job_repository.update(job.ai_job_id, {'task_id': task.id}) return { 'job_id': str(job.ai_job_id), 'task_id': task.id, 'status': 'pending', 'estimated_credits': credits_needed } async def generate_sound( self, user_id: UUID, description: str, duration: int = 5, sound_type: Optional[str] = None, model: Optional[str] = None, **kwargs ) -> Dict[str, Any]: """生成音效(异步)""" # 检查配额 await self._check_quota(user_id, 'sound_generation') # 获取模型配置(AIModelType.AUDIO = 4) from app.models.ai_model import AIModelType model_config = await self._get_model(model, AIModelType.AUDIO) # 计算所需积分 feature_type = self._get_feature_type(AIJobType.SOUND) credits_needed = await self.credit_service.calculate_credits( feature_type=feature_type, params={ 'model_name': model_config.model_name, 'duration': duration } ) # 预扣积分(同步,事务保证) try: consumption_log = await self.credit_service.consume_credits( user_id=user_id, amount=credits_needed, feature_type=feature_type, ai_job_id=None, task_params={ 'description': description, 'model': model_config.model_name, 'duration': duration, 'sound_type': sound_type, **kwargs } ) except InsufficientCreditsError as e: raise ValidationError(f"积分不足: {str(e)}") # 创建任务记录 from app.models.ai_job import AIJobType, AIJobStatus job = await self.job_repository.create({ 'job_type': AIJobType.SOUND, 'status': AIJobStatus.PENDING, 'user_id': user_id, 'model_id': model_config.model_id, 'model_name': model_config.model_name, 'consumption_log_id': consumption_log.consumption_id, 'credits_used': credits_needed, 'input_data': { 'description': description, 'duration': duration, 'sound_type': sound_type, **kwargs } }) # 更新 consumption_log 的 ai_job_id consumption_log.ai_job_id = job.ai_job_id consumption_log.task_id = str(job.ai_job_id) await self.db.commit() # 提交异步任务 task = generate_sound_task.delay( job_id=job.ai_job_id, description=description, duration=duration, sound_type=sound_type, **kwargs ) # 更新任务 ID await self.job_repository.update(job.ai_job_id, {'task_id': task.id}) return { 'job_id': str(job.ai_job_id), 'task_id': task.id, 'status': 'pending', 'estimated_credits': credits_needed } async def generate_voice( self, user_id: UUID, text: str, voice_type: str = 'alloy', speed: float = 1.0, language: str = 'zh-CN', model: Optional[str] = None, **kwargs ) -> Dict[str, Any]: """生成配音(异步)""" # 检查配额 await self._check_quota(user_id, 'voice_generation') # 获取模型配置(AIModelType.AUDIO = 4) from app.models.ai_model import AIModelType model_config = await self._get_model(model, AIModelType.AUDIO) # 计算所需积分 feature_type = self._get_feature_type(AIJobType.VOICE) credits_needed = await self.credit_service.calculate_credits( feature_type=feature_type, params={ 'model_name': model_config.model_name, 'char_count': len(text) } ) # 预扣积分(同步,事务保证) try: consumption_log = await self.credit_service.consume_credits( user_id=user_id, amount=credits_needed, feature_type=feature_type, ai_job_id=None, task_params={ 'text': text, 'model': model_config.model_name, 'voice_type': voice_type, 'speed': speed, 'language': language, **kwargs } ) except InsufficientCreditsError as e: raise ValidationError(f"积分不足: {str(e)}") # 创建任务记录 from app.models.ai_job import AIJobType, AIJobStatus job = await self.job_repository.create({ 'job_type': AIJobType.VOICE, 'status': AIJobStatus.PENDING, 'user_id': user_id, 'model_id': model_config.model_id, 'model_name': model_config.model_name, 'consumption_log_id': consumption_log.consumption_id, 'credits_used': credits_needed, 'input_data': { 'text': text, 'voice_type': voice_type, 'speed': speed, 'language': language, **kwargs } }) # 更新 consumption_log 的 ai_job_id consumption_log.ai_job_id = job.ai_job_id consumption_log.task_id = str(job.ai_job_id) await self.db.commit() # 提交异步任务 task = generate_voice_task.delay( job_id=job.ai_job_id, text=text, voice_type=voice_type, speed=speed, language=language, **kwargs ) # 更新任务 ID await self.job_repository.update(job.ai_job_id, {'task_id': task.id}) return { 'job_id': str(job.ai_job_id), 'task_id': task.id, 'status': 'pending', 'estimated_credits': credits_needed } async def generate_subtitle( self, user_id: UUID, audio_url: str, language: str = 'zh', model: Optional[str] = None, **kwargs ) -> Dict[str, Any]: """生成字幕(异步)""" # 检查配额 await self._check_quota(user_id, 'subtitle_generation') # 获取模型配置(AIModelType.AUDIO = 4) from app.models.ai_model import AIModelType model_config = await self._get_model(model, AIModelType.AUDIO) # 计算所需积分(假设音频时长为 60 秒,实际应从音频文件获取) duration = kwargs.get('duration', 60) feature_type = self._get_feature_type(AIJobType.SUBTITLE) credits_needed = await self.credit_service.calculate_credits( feature_type=feature_type, params={ 'model_name': model_config.model_name, 'duration': duration } ) # 预扣积分(同步,事务保证) try: consumption_log = await self.credit_service.consume_credits( user_id=user_id, amount=credits_needed, feature_type=feature_type, ai_job_id=None, task_params={ 'audio_url': audio_url, 'model': model_config.model_name, 'language': language, 'duration': duration, **kwargs } ) except InsufficientCreditsError as e: raise ValidationError(f"积分不足: {str(e)}") # 创建任务记录 from app.models.ai_job import AIJobType, AIJobStatus job = await self.job_repository.create({ 'job_type': AIJobType.SUBTITLE, 'status': AIJobStatus.PENDING, 'user_id': user_id, 'model_id': model_config.model_id, 'model_name': model_config.model_name, 'consumption_log_id': consumption_log.consumption_id, 'credits_used': credits_needed, 'input_data': { 'audio_url': audio_url, 'language': language, 'duration': duration, **kwargs } }) # 更新 consumption_log 的 ai_job_id consumption_log.ai_job_id = job.ai_job_id consumption_log.task_id = str(job.ai_job_id) await self.db.commit() # 提交异步任务 task = generate_subtitle_task.delay( job_id=job.ai_job_id, audio_url=audio_url, language=language, **kwargs ) # 更新任务 ID await self.job_repository.update(job.ai_job_id, {'task_id': task.id}) return { 'job_id': str(job.ai_job_id), 'task_id': task.id, 'status': 'pending', 'estimated_credits': credits_needed } async def process_text( self, user_id: UUID, task_type: str, text: str, model: Optional[str] = None, **kwargs ) -> Dict[str, Any]: """处理文本(异步) Args: user_id: 用户 ID task_type: 任务类型 - screenplay_parse: 剧本拆解 - content_analysis: 内容分析 - style_transform: 风格转换 - prompt_generation: 提示词生成 text: 待处理的文本 model: 使用的模型(默认使用配置的模型) **kwargs: 其他参数 - output_format: 输出格式 (json/text) - temperature: 生成温度 (0.0-2.0) - max_tokens: 最大 token 数 """ # 验证任务类型 valid_task_types = [ 'screenplay_parse', 'content_analysis', 'style_transform', 'prompt_generation' ] if task_type not in valid_task_types: raise ValidationError(f"不支持的任务类型: {task_type}") # 创建任务记录 from app.models.ai_job import AIJobType, AIJobStatus job = await self.job_repository.create({ 'job_type': AIJobType.TEXT_PROCESSING, 'status': AIJobStatus.PENDING, 'user_id': user_id, 'input_data': { 'task_type': task_type, 'text': text, 'model': model or 'gpt-4', 'output_format': kwargs.get('output_format', 'json'), 'temperature': kwargs.get('temperature', 0.7), 'max_tokens': kwargs.get('max_tokens', 4000), **kwargs } }) # 提交异步任务 from app.tasks.ai_tasks import process_text_task task = process_text_task.delay( job_id=job.id, task_type=task_type, text=text, model=model or 'gpt-4', **kwargs ) return { 'job_id': str(job.id), 'task_id': task.id, 'status': 'pending' } async def get_job_status(self, job_id: UUID) -> Dict[str, Any]: """查询任务状态""" job = await self.job_repository.get_by_id(job_id) if not job: from app.core.exceptions import NotFoundError raise NotFoundError("任务不存在") return { 'job_id': str(job.id), 'job_type': job.job_type, 'status': job.status, 'progress': job.progress, 'input_data': job.input_data, 'output_data': job.output_data, 'error_message': job.error_message, 'model_name': job.model_name, 'cost': float(job.cost) if job.cost else None, 'credits_used': job.credits_used, 'created_at': job.created_at.isoformat(), 'updated_at': job.updated_at.isoformat(), 'started_at': job.started_at.isoformat() if job.started_at else None, 'completed_at': job.completed_at.isoformat() if job.completed_at else None } async def cancel_job(self, user_id: UUID, job_id: UUID) -> None: """取消任务""" job = await self.job_repository.get_by_id(job_id) if not job: from app.core.exceptions import NotFoundError raise NotFoundError("任务不存在") if job.user_id != user_id: from app.core.exceptions import PermissionDeniedError raise PermissionDeniedError("没有权限取消此任务") from app.models.ai_job import AIJobStatus if job.status in [AIJobStatus.COMPLETED, AIJobStatus.FAILED]: raise ValidationError("任务已完成或失败,无法取消") # 取消 Celery 任务 from app.tasks.celery_app import celery_app if job.task_id: celery_app.control.revoke(job.task_id, terminate=True) # 更新任务状态 from app.models.ai_job import AIJobStatus await self.job_repository.update(job_id, { 'status': AIJobStatus.CANCELLED, 'error_message': '用户取消' }) async def get_user_usage_stats( self, user_id: UUID, start_date: Optional[str] = None, end_date: Optional[str] = None ) -> Dict[str, Any]: """获取用户使用统计""" stats = await self.usage_log_repository.get_user_stats( user_id, start_date, end_date ) quotas = await self.quota_repository.get_user_quotas(user_id) return { 'total_cost': float(stats.get('total_cost', 0)), 'total_credits_used': stats.get('total_credits_used', 0), 'total_requests': stats.get('total_requests', 0), 'by_model': stats.get('by_model', {}), 'by_type': stats.get('by_type', {}), 'quotas': [ { 'quota_type': q.quota_type, 'period': q.period, 'total_quota': q.total_quota, 'used_quota': q.used_quota, 'remaining_quota': q.total_quota - q.used_quota, 'reset_at': q.reset_at.isoformat() } for q in quotas ] } async def get_available_models( self, model_type: Optional[int] = None ) -> list[Dict[str, Any]]: """获取可用的 AI 模型列表 Args: model_type: 模型类型(AIModelType 枚举值,可选) """ models = await self.model_repository.get_active_models(model_type) return [ { 'model_id': str(m.model_id), 'model_name': m.model_name, 'display_name': m.display_name, 'description': m.description, 'provider': m.provider, 'model_type': m.model_type, 'cost_per_unit': float(m.cost_per_unit), 'unit_type': m.unit_type, 'credits_per_unit': m.credits_per_unit, 'is_beta': m.is_beta, 'config': m.config } for m in models ] ``` --- ## API 接口 ### 1. 创建 AI 任务(统一入口) ``` POST /api/v1/ai/jobs ``` **请求体(图片生成)**: ```json { "jobType": "image", "prompt": "一只可爱的猫咪在花园里玩耍", "model": "stable_diffusion", "width": 1024, "height": 1024, "style": "realistic" } ``` **请求体(视频生成 - 文本转视频)**: ```json { "jobType": "video", "videoType": "text2video", "prompt": "一只猫咪在花园里奔跑", "duration": 5, "fps": 30 } ``` **请求体(视频生成 - 图片转视频)**: ```json { "jobType": "video", "videoType": "img2video", "imageUrl": "https://example.com/image.jpg", "duration": 5, "fps": 30 } ``` **请求体(音效生成)**: ```json { "jobType": "sound", "description": "雨声", "duration": 10, "soundType": "ambient" } ``` **请求体(配音生成)**: ```json { "jobType": "voice", "text": "欢迎来到Jointo平台", "voiceType": "alloy", "speed": 1.0, "language": "zh-CN" } ``` **请求体(字幕生成)**: ```json { "jobType": "subtitle", "audioUrl": "https://example.com/audio.mp3", "language": "zh" } ``` **请求体(文本处理 - 剧本拆解)**: ```json { "jobType": "textProcessing", "taskType": "screenplay_parse", "text": "场景1:咖啡厅 - 白天\n小明走进咖啡厅,看到小红坐在窗边...", "model": "gpt-4", "outputFormat": "json", "temperature": 0.7 } ``` **响应**: ```json { "code": 200, "message": "Success", "data": { "jobId": "019d1234-5678-7abc-def0-111111111111", "taskId": "abc-123-def", "status": "pending", "estimatedCredits": 10 } } ``` **错误响应**: ```json // 400 - 参数错误 { "code": 400, "message": "参数验证失败: prompt 不能为空", "data": null } // 402 - 积分不足 { "code": 402, "message": "积分不足,当前余额: 5,需要: 10", "data": null } // 429 - 配额超限 { "code": 429, "message": "每日配额已用完,请明天再试", "data": null } // 500 - 服务器错误 { "code": 500, "message": "AI 服务暂时不可用,请稍后重试", "data": null } ``` ### 2. 查询任务状态 ``` GET /api/v1/ai/jobs/{job_id} ``` **响应**: ```json { "code": 200, "message": "Success", "data": { "jobId": "019d1234-5678-7abc-def0-111111111111", "jobType": "image", "status": "completed", "progress": 100, "outputData": { "fileUrl": "https://storage.jointo.ai/ai-generated/images/abc123def456.png", "fileSize": 1024000, "checksum": "abc123def456...", "mimeType": "image/png", "meta_data": { "width": 1024, "height": 1024, "format": "png" } }, "createdAt": "2026-01-27T10:00:00Z", "completedAt": "2026-01-27T10:00:30Z" } } ``` ### 3. 获取任务列表 ``` GET /api/v1/ai/jobs ``` **查询参数**: - `jobType`(可选):任务类型(image/video/sound/voice/subtitle/textProcessing) - `status`(可选):任务状态(pending/processing/completed/failed/cancelled) - `page`(可选):页码(默认 1) - `pageSize`(可选):每页数量(默认 20) **响应**: ```json { "code": 200, "message": "Success", "data": { "items": [ { "jobId": "019d1234-5678-7abc-def0-111111111111", "jobType": "image", "status": "completed", "createdAt": "2026-01-27T10:00:00Z" } ], "total": 100, "page": 1, "pageSize": 20, "totalPages": 5 } } ``` ### 4. 取消任务 ``` POST /api/v1/ai/jobs/{job_id}/cancel ``` **响应**: ```json { "code": 200, "message": "任务已取消", "data": null } ``` ### 5. 获取用户使用统计 ### 5. 获取用户使用统计 ``` GET /api/v1/ai/usage/stats ``` **查询参数**: - `startDate`(可选):开始日期(YYYY-MM-DD) - `endDate`(可选):结束日期(YYYY-MM-DD) **响应**: ```json { "code": 200, "message": "Success", "data": { "totalCost": 12.50, "totalCreditsUsed": 250, "totalRequests": 45, "byModel": { "gpt-4": { "requests": 20, "cost": 8.00, "creditsUsed": 160 }, "dall-e-3": { "requests": 15, "cost": 3.50, "creditsUsed": 75 } }, "byType": { "textProcessing": { "requests": 20, "cost": 8.00 }, "image": { "requests": 15, "cost": 3.50 } }, "quotas": [ { "quotaType": "imageGeneration", "period": "daily", "totalQuota": 50, "usedQuota": 15, "remainingQuota": 35, "resetAt": "2026-01-28T00:00:00Z" }, { "quotaType": "textProcessing", "period": "monthly", "totalQuota": 1000, "usedQuota": 250, "remainingQuota": 750, "resetAt": "2026-02-01T00:00:00Z" } ] } } ``` ### 6. 获取可用模型列表 ``` GET /api/v1/ai/models ``` **查询参数**: - `type`(可选):模型类型(text/image/video/audio) **响应**: ```json { "code": 200, "message": "Success", "data": { "items": [ { "modelId": 1, "modelName": "gpt-4", "displayName": "GPT-4", "description": "OpenAI 最强大的语言模型", "provider": "openai", "modelType": "text", "costPerUnit": 0.03, "unitType": "token", "creditsPerUnit": 1, "isBeta": false, "config": { "maxTokens": 8000 } }, { "modelId": 4, "modelName": "dall-e-3", "displayName": "DALL-E 3", "description": "OpenAI 图片生成模型", "provider": "openai", "modelType": "image", "costPerUnit": 0.04, "unitType": "image", "creditsPerUnit": 10, "isBeta": false, "config": { "supportedResolutions": ["1024x1024", "1792x1024", "1024x1792"] } } ], "total": 2, "page": 1, "pageSize": 20, "totalPages": 1 } } ``` --- ## 文件存储集成 ### 对象存储服务 AI 生成的文件需要存储到对象存储服务(OSS),系统通过 `FileStorageService` 提供统一的存储接口。 **存储策略**: - **开发环境**:使用 MinIO(本地对象存储) - **生产环境**:可切换到云服务商 OSS - 阿里云 OSS - AWS S3 - 腾讯云 COS - 七牛云 Kodo **配置方式**: 通过环境变量 `STORAGE_PROVIDER` 切换存储服务: ```bash # .env STORAGE_PROVIDER=minio # 开发环境 # STORAGE_PROVIDER=aliyun # 生产环境(阿里云 OSS) # STORAGE_PROVIDER=aws # 生产环境(AWS S3) ``` ### 文件存储流程 AI 任务完成后,需要将 AI 提供商返回的临时文件下载并存储到自有 OSS: ```python # AI Task 工作流(集成文件存储) async def generate_image_task(job_id: str, ...): try: # 1. 调用 AI Provider 生成图片 result = await provider.generate_image(...) # result = {"image_url": "https://ai-provider.com/temp/abc123.png"} # 2. 下载 AI 生成的文件 async with httpx.AsyncClient() as client: response = await client.get(result['image_url']) image_data = response.content # 3. 上传到 OSS(带去重) async with async_session_maker() as session: file_storage = FileStorageService(session) file_meta_data = await file_storage.upload_file( file_content=image_data, filename=f"{job_id}.png", content_type="image/png", category="ai-generated/images", user_id=user_id ) # 4. 更新任务结果(使用自有 URL) output_data = { "file_url": file_meta_data.file_url, # 自有 OSS URL "file_size": file_meta_data.file_size, "checksum": file_meta_data.checksum, "mime_type": file_meta_data.mime_type, "original_url": result['image_url'], # 保留原始 URL(可选) "meta_data": { "width": result.get('width'), "height": result.get('height'), "format": "png" } } await _update_job_status( job_id, AIJobStatus.COMPLETED, progress=100, output_data=output_data ) # 5. 确认积分消耗 await _confirm_or_refund_credits(...) except Exception as e: # 任务失败,退还积分 await _update_job_status(job_id, AIJobStatus.FAILED, error_message=str(e)) await _confirm_or_refund_credits(..., success=False, error_message=str(e)) ``` **优势**: 1. **数据主权**:AI 生成的内容存储在自有系统中 2. **稳定性**:不依赖第三方临时 URL(如 OpenAI 的临时 URL 只有 1 小时有效期) 3. **去重优化**:通过文件校验和(SHA256)自动去重,节省存储成本 4. **灵活切换**:开发环境使用 MinIO,生产环境可无缝切换到云服务商 OSS 5. **统一管理**:所有文件通过 `FileStorageService` 统一管理,支持引用计数和自动清理 ### FileStorageService 集成 AI Service 需要注入 `FileStorageService` 依赖: ```python # app/services/ai_service.py from app.services.file_storage_service import FileStorageService class AIService: def __init__(self, db: Session): self.db = db self.job_repository = AIJobRepository(db) self.model_repository = AIModelRepository(db) self.credit_service = CreditService(db) self.file_storage = FileStorageService(db) # 注入文件存储服务 ``` 详见:[File Storage Service 文档](../resource/file-storage-service.md) --- ## AI 提供商集成 ### 提供商抽象层 ```python # app/services/ai_providers/base.py from abc import ABC, abstractmethod from typing import Dict, Any class AIProvider(ABC): @abstractmethod async def generate_image(self, prompt: str, **kwargs) -> Dict[str, Any]: """生成图片""" pass @abstractmethod async def generate_video(self, prompt: str, **kwargs) -> Dict[str, Any]: """生成视频""" pass ``` ### Stable Diffusion 提供商 ```python # app/services/ai_providers/stable_diffusion.py import httpx from app.services.ai_providers.base import AIProvider class StableDiffusionProvider(AIProvider): def __init__(self, api_key: str, api_url: str): self.api_key = api_key self.api_url = api_url async def generate_image(self, prompt: str, **kwargs) -> Dict[str, Any]: """调用 Stable Diffusion API 生成图片""" async with httpx.AsyncClient() as client: response = await client.post( f"{self.api_url}/generate", json={ "prompt": prompt, "width": kwargs.get('width', 1024), "height": kwargs.get('height', 1024), "steps": kwargs.get('steps', 50), "cfg_scale": kwargs.get('cfg_scale', 7.5) }, headers={"Authorization": f"Bearer {self.api_key}"}, timeout=300 ) return response.json() async def generate_video(self, prompt: str, **kwargs) -> Dict[str, Any]: """Stable Diffusion 不支持视频生成""" raise NotImplementedError("Stable Diffusion 不支持视频生成") ``` ### OpenAI 提供商 ```python # app/services/ai_providers/openai.py import openai from app.services.ai_providers.base import AIProvider class OpenAIProvider(AIProvider): def __init__(self, api_key: str): self.api_key = api_key openai.api_key = api_key async def generate_image(self, prompt: str, **kwargs) -> Dict[str, Any]: """调用 DALL-E API 生成图片""" response = await openai.Image.acreate( prompt=prompt, n=1, size=f"{kwargs.get('width', 1024)}x{kwargs.get('height', 1024)}" ) return { 'image_url': response['data'][0]['url'] } async def generate_voice(self, text: str, **kwargs) -> Dict[str, Any]: """调用 OpenAI TTS API 生成配音""" response = await openai.Audio.create( model="tts-1", voice=kwargs.get('voice_type', 'alloy'), input=text, speed=kwargs.get('speed', 1.0) ) return { 'audio_url': response['url'] } async def process_text(self, task_type: str, text: str, **kwargs) -> Dict[str, Any]: """调用 GPT API 处理文本""" # 根据任务类型构建系统提示词 system_prompts = { 'screenplay_parse': """你是一个专业的剧本分析助手。请将输入的剧本文本拆解为结构化的场景和镜头信息。 输出 JSON 格式,包含:scenes (场景列表),每个场景包含 scene_number, location, time, description, characters, shots。 每个 shot 包含 shot_number, shot_size, camera_movement, description, duration。""", 'content_analysis': """你是一个内容分析专家。请分析输入文本,提取关键信息。 输出 JSON 格式,包含:characters (人物列表), locations (地点列表), timeline (时间线), themes (主题), emotions (情感)。""", 'style_transform': """你是一个文本风格转换专家。请根据要求调整文本的风格和语气。 保持原意,但改变表达方式。""", 'prompt_generation': """你是一个 AI 绘画提示词专家。请将输入的场景描述转换为详细的 AI 绘画提示词。 包含:主体、环境、光线、色调、风格、镜头角度等要素。输出英文提示词。""" } messages = [ {"role": "system", "content": system_prompts.get(task_type, "")}, {"role": "user", "content": text} ] response = await openai.ChatCompletion.acreate( model=kwargs.get('model', 'gpt-4'), messages=messages, temperature=kwargs.get('temperature', 0.7), max_tokens=kwargs.get('max_tokens', 4000) ) content = response.choices[0].message.content # 如果要求 JSON 格式,尝试解析 if kwargs.get('output_format') == 'json': import json try: content = json.loads(content) except json.JSONDecodeError: # 如果解析失败,尝试提取 JSON 部分 import re json_match = re.search(r'\{.*\}', content, re.DOTALL) if json_match: content = json.loads(json_match.group()) return { 'result': content, 'model': response.model, 'usage': { 'prompt_tokens': response.usage.prompt_tokens, 'completion_tokens': response.usage.completion_tokens, 'total_tokens': response.usage.total_tokens } } ``` ### 提供商工厂 ```python # app/services/ai_service_factory.py from app.services.ai_providers.stable_diffusion import StableDiffusionProvider from app.services.ai_providers.openai import OpenAIProvider from app.config import settings class AIServiceFactory: @staticmethod def create_provider(provider_type: str): """创建 AI 提供商实例""" if provider_type == 'stable_diffusion': return StableDiffusionProvider( api_key=settings.STABLE_DIFFUSION_API_KEY, api_url=settings.STABLE_DIFFUSION_API_URL ) elif provider_type == 'openai': return OpenAIProvider( api_key=settings.OPENAI_API_KEY ) else: raise ValueError(f"Unknown provider: {provider_type}") ``` --- ## 相关文档 - [数据库设计](../../database-design.md) - [异步任务处理](../07-async-tasks.md) - [系统架构设计](../03-system-design.md) --- ## 附录:Repository 实现示例 ### 基础 Repository(引用完整性保证) ```python # app/repositories/base_repository.py from typing import Optional, TypeVar, Generic from uuid import UUID from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession T = TypeVar('T') class BaseRepository(Generic[T]): """基础 Repository - 提供引用完整性验证""" def __init__(self, db: AsyncSession, model: type[T]): self.db = db self.model = model async def exists(self, id: UUID) -> bool: """检查记录是否存在(排除软删除) 这是应用层保证引用完整性的核心方法 """ query = select(self.model.id).where( self.model.id == id ).limit(1) # 如果模型有 deleted_at 字段,排除软删除记录 if hasattr(self.model, 'deleted_at'): query = query.where(self.model.deleted_at.is_(None)) result = await self.db.execute(query) return result.scalar_one_or_none() is not None async def get_by_id(self, id: UUID) -> Optional[T]: """根据 ID 获取记录""" query = select(self.model).where(self.model.id == id) if hasattr(self.model, 'deleted_at'): query = query.where(self.model.deleted_at.is_(None)) result = await self.db.execute(query) return result.scalar_one_or_none() ``` ### AIJobRepository ```python # app/repositories/ai_job_repository.py from typing import List, Optional from uuid import UUID from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from app.models.ai_job import AIJob, AIJobType, AIJobStatus from app.repositories.base_repository import BaseRepository class AIJobRepository(BaseRepository[AIJob]): """AI 任务数据访问层""" def __init__(self, db: AsyncSession): super().__init__(db, AIJob) async def create(self, job_data: dict) -> AIJob: """创建 AI 任务""" job = AIJob(**job_data) self.db.add(job) await self.db.commit() await self.db.refresh(job) return job async def update(self, job_id: UUID, update_data: dict) -> AIJob: """更新 AI 任务""" job = await self.get_by_id(job_id) if not job: raise ValueError("任务不存在") for key, value in update_data.items(): setattr(job, key, value) await self.db.commit() await self.db.refresh(job) return job async def get_user_jobs( self, user_id: UUID, job_type: Optional[int] = None, status: Optional[int] = None, page: int = 1, page_size: int = 20 ) -> List[AIJob]: """获取用户的 AI 任务列表""" query = select(AIJob).where(AIJob.user_id == user_id) if job_type: query = query.where(AIJob.job_type == job_type) if status: query = query.where(AIJob.status == status) query = query.order_by(AIJob.created_at.desc()) query = query.offset((page - 1) * page_size).limit(page_size) result = await self.db.execute(query) return result.scalars().all() ``` ### AIModelRepository ```python # app/repositories/ai_model_repository.py from typing import Optional, List from uuid import UUID from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from app.models.ai_model import AIModel, AIModelType from app.repositories.base_repository import BaseRepository from app.core.exceptions import ValidationError class AIModelRepository(BaseRepository[AIModel]): """AI 模型数据访问层""" def __init__(self, db: AsyncSession): super().__init__(db, AIModel) async def get_by_name(self, model_name: str) -> Optional[AIModel]: """根据模型名称获取模型""" query = select(AIModel).where(AIModel.model_name == model_name) result = await self.db.execute(query) return result.scalar_one_or_none() async def get_default_model(self, model_type: int) -> Optional[AIModel]: """获取指定类型的默认模型(最便宜的活跃模型) Args: model_type: 模型类型(AIModelType 枚举值) """ # 验证模型类型枚举值 if model_type not in AIModelType.__members__.values(): raise ValidationError(f"无效的模型类型: {model_type}") query = select(AIModel).where( AIModel.model_type == model_type, AIModel.is_active == True ).order_by(AIModel.cost_per_unit.asc()) result = await self.db.execute(query) return result.scalar_one_or_none() async def get_active_models( self, model_type: Optional[int] = None ) -> List[AIModel]: """获取所有活跃模型 Args: model_type: 模型类型(AIModelType 枚举值,可选) """ # 如果指定了模型类型,验证枚举值 if model_type is not None and model_type not in AIModelType.__members__.values(): raise ValidationError(f"无效的模型类型: {model_type}") query = select(AIModel).where(AIModel.is_active == True) if model_type: query = query.where(AIModel.model_type == model_type) query = query.order_by(AIModel.model_type, AIModel.cost_per_unit) result = await self.db.execute(query) return result.scalars().all() ``` ### AIQuotaRepository ```python # app/repositories/ai_quota_repository.py from typing import Optional, List from uuid import UUID from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from app.models.ai_quota import AIQuota, QuotaPeriod from app.repositories.base_repository import BaseRepository from datetime import datetime, timedelta class AIQuotaRepository(BaseRepository[AIQuota]): """AI 配额数据访问层""" def __init__(self, db: AsyncSession): super().__init__(db, AIQuota) async def get_user_quota( self, user_id: UUID, quota_type: str, period: int = QuotaPeriod.DAILY ) -> Optional[AIQuota]: """获取用户指定类型的配额 Args: user_id: 用户 ID quota_type: 配额类型 period: 配额周期(QuotaPeriod 枚举值,默认 1=DAILY) """ query = select(AIQuota).where( AIQuota.user_id == user_id, AIQuota.quota_type == quota_type, AIQuota.period == period ) result = await self.db.execute(query) return result.scalar_one_or_none() async def get_user_quotas(self, user_id: UUID) -> List[AIQuota]: """获取用户所有配额""" query = select(AIQuota).where(AIQuota.user_id == user_id) result = await self.db.execute(query) return result.scalars().all() async def increment_usage( self, user_id: UUID, quota_type: str, amount: int = 1 ) -> None: """增加配额使用量""" # 更新每日配额 daily_quota = await self.get_user_quota(user_id, quota_type, QuotaPeriod.DAILY) if daily_quota: daily_quota.used_quota += amount await self.db.commit() # 更新每月配额 monthly_quota = await self.get_user_quota(user_id, quota_type, QuotaPeriod.MONTHLY) if monthly_quota: monthly_quota.used_quota += amount await self.db.commit() async def create_default_quotas(self, user_id: UUID) -> None: """为新用户创建默认配额""" default_quotas = [ { 'user_id': user_id, 'quota_type': 'image_generation', 'period': QuotaPeriod.DAILY, 'total_quota': 50, 'reset_at': datetime.now(timezone.utc) + timedelta(days=1) }, { 'user_id': user_id, 'quota_type': 'image_generation', 'period': QuotaPeriod.MONTHLY, 'total_quota': 1000, 'reset_at': datetime.now(timezone.utc) + timedelta(days=30) }, { 'user_id': user_id, 'quota_type': 'text_processing', 'period': QuotaPeriod.DAILY, 'total_quota': 100, 'reset_at': datetime.now(timezone.utc) + timedelta(days=1) }, { 'user_id': user_id, 'quota_type': 'text_processing', 'period': QuotaPeriod.MONTHLY, 'total_quota': 2000, 'reset_at': datetime.now(timezone.utc) + timedelta(days=30) } ] for quota_data in default_quotas: quota = AIQuota(**quota_data) self.db.add(quota) await self.db.commit() ``` --- **文档版本**:v1.0 **最后更新**:2025-01-27 --- ## 附录:Python 枚举定义 ### app/models/ai_job.py ```python from enum import IntEnum from sqlmodel import SQLModel, Field from sqlalchemy import Column, SmallInteger, Index, text from sqlalchemy.dialects.postgresql import UUID as PG_UUID, JSONB from datetime import datetime, timezone from uuid import UUID from typing import Optional from app.utils.id_generator import generate_uuid class AIJobType(IntEnum): """AI 任务类型""" IMAGE = 1 # 图片生成 VIDEO = 2 # 视频生成 SOUND = 3 # 音效生成 VOICE = 4 # 配音生成 SUBTITLE = 5 # 字幕生成 TEXT_PROCESSING = 6 # 文本处理(剧本拆解等) RESOURCE = 7 # 资源生成 STORYBOARD_SCRIPT = 8 # 分镜脚本生成 SCRIPT_GENERATION = 9 # 剧本生成 class AIJobStatus(IntEnum): """AI 任务状态""" PENDING = 1 # 等待处理 PROCESSING = 2 # 处理中 COMPLETED = 3 # 已完成 FAILED = 4 # 失败 CANCELLED = 5 # 已取消 class AIJob(SQLModel, table=True): """AI 任务表 - 不使用外键约束,应用层保证引用完整性""" __tablename__ = "ai_jobs" # 主键 ai_job_id: UUID = Field( sa_column=Column( PG_UUID(as_uuid=True), primary_key=True, default=generate_uuid ) ) # 关联字段(无外键约束,仅索引) user_id: UUID = Field( sa_column=Column(PG_UUID(as_uuid=True), nullable=False, index=True), description="用户 ID - 应用层验证" ) project_id: Optional[UUID] = Field( default=None, sa_column=Column(PG_UUID(as_uuid=True), nullable=True, index=True), description="项目 ID - 应用层验证" ) storyboard_id: Optional[UUID] = Field( default=None, sa_column=Column(PG_UUID(as_uuid=True), nullable=True, index=True), description="分镜 ID - 应用层验证" ) consumption_log_id: Optional[UUID] = Field( default=None, sa_column=Column(PG_UUID(as_uuid=True), nullable=True, index=True), description="积分消耗日志 ID - 应用层验证" ) model_id: Optional[UUID] = Field( default=None, sa_column=Column(PG_UUID(as_uuid=True), nullable=True, index=True), description="AI 模型 ID - 应用层验证" ) # 任务信息(使用 SMALLINT 存储枚举) job_type: int = Field( sa_column=Column(SmallInteger, nullable=False), description="任务类型:1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成" ) status: int = Field( default=AIJobStatus.PENDING, sa_column=Column(SmallInteger, nullable=False, default=1, index=True), description="任务状态:1=等待处理 2=处理中 3=已完成 4=失败 5=已取消" ) # JSONB 字段 input_data: dict = Field( default_factory=dict, sa_column=Column(JSONB, nullable=False, default={}), description="输入参数" ) output_data: Optional[dict] = Field( default=None, sa_column=Column(JSONB, nullable=True), description="输出结果" ) # 其他字段 model_name: Optional[str] = Field(default=None, max_length=255) progress: int = Field(default=0, ge=0, le=100) error_message: Optional[str] = Field(default=None) task_id: Optional[str] = Field(default=None) cost: Optional[float] = Field(default=None) credits_used: int = Field(default=0) # 时间戳 estimated_completion_at: Optional[datetime] = Field(default=None) started_at: Optional[datetime] = Field(default=None) completed_at: Optional[datetime] = Field(default=None) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) # 表级索引 __table_args__ = ( Index('idx_ai_jobs_status_created_at', 'status', 'created_at', postgresql_where=text('status IN (1, 2)')), Index('idx_ai_jobs_input_data_gin', 'input_data', postgresql_using='gin'), Index('idx_ai_jobs_output_data_gin', 'output_data', postgresql_using='gin', postgresql_where=text('output_data IS NOT NULL')), ) ``` ### app/models/ai_model.py ```python from enum import IntEnum class AIModelType(IntEnum): """AI 模型类型""" TEXT = 1 # 文本模型(GPT, Claude 等) IMAGE = 2 # 图片模型(DALL-E, Stable Diffusion 等) VIDEO = 3 # 视频模型(Runway, Pika 等) AUDIO = 4 # 音频模型(TTS, STT 等) class AIProvider(IntEnum): """AI 提供商""" OPENAI = 1 ANTHROPIC = 2 GOOGLE = 3 # Google Gemini STABILITY = 4 RUNWAY = 5 PIKA = 6 ELEVENLABS = 7 AZURE = 8 BAIDU = 9 ALIYUN = 10 CUSTOM = 99 # 自定义提供商 class UnitType(IntEnum): """计费单位类型""" TOKEN = 1 # Token(文本模型) IMAGE = 2 # 图片(图片模型) SECOND = 3 # 秒(视频/音频模型) REQUEST = 4 # 请求(通用计费单位) ``` ### app/models/ai_quota.py ```python from enum import IntEnum class QuotaPeriod(IntEnum): """配额周期""" DAILY = 1 # 每日配额 MONTHLY = 2 # 每月配额 TOTAL = 3 # 总配额 ``` --- ## 测试规范 ### 单元测试 ```python # server/tests/unit/test_ai_service.py import pytest from uuid import UUID from unittest.mock import AsyncMock, MagicMock, patch from app.services.ai_service import AIService from app.models.ai_job import AIJobType, AIJobStatus from app.models.ai_model import AIModelType from app.core.exceptions import ValidationError, InsufficientCreditsError, NotFoundError @pytest.fixture def mock_db(): """模拟数据库会话""" db = AsyncMock() db.begin = AsyncMock() db.commit = AsyncMock() db.rollback = AsyncMock() return db @pytest.fixture def ai_service(mock_db): """创建 AI Service 实例""" return AIService(mock_db) class TestAIService: """AI Service 单元测试""" @pytest.mark.asyncio async def test_generate_image_success(self, ai_service, mock_db): """测试图片生成成功""" # 准备测试数据 user_id = UUID('019d1234-5678-7abc-def0-111111111111') prompt = "一只可爱的猫咪" # Mock 依赖 with patch.object(ai_service, '_check_quota', return_value=True), \ patch.object(ai_service, '_get_model') as mock_get_model, \ patch.object(ai_service.job_repository, 'create') as mock_create_job, \ patch.object(ai_service.credit_service, 'consume_credits') as mock_consume: # 设置 mock 返回值 mock_model = MagicMock() mock_model.model_id = UUID('019d1234-5678-7abc-def0-222222222222') mock_model.model_name = 'stable_diffusion' mock_model.credits_per_unit = 10 mock_model.cost_per_unit = 0.5 mock_get_model.return_value = mock_model mock_consumption = MagicMock() mock_consumption.consumption_id = UUID('019d1234-5678-7abc-def0-333333333333') mock_consume.return_value = mock_consumption mock_job = MagicMock() mock_job.ai_job_id = UUID('019d1234-5678-7abc-def0-444444444444') mock_create_job.return_value = mock_job # 执行测试 result = await ai_service.generate_image( user_id=user_id, prompt=prompt, width=1024, height=1024 ) # 验证结果 assert 'job_id' in result assert 'task_id' in result assert result['status'] == 'pending' assert result['estimated_credits'] == 10 # 验证调用 mock_consume.assert_called_once() mock_create_job.assert_called_once() @pytest.mark.asyncio async def test_generate_image_insufficient_credits(self, ai_service): """测试积分不足""" user_id = UUID('019d1234-5678-7abc-def0-111111111111') with patch.object(ai_service, '_check_quota', return_value=True), \ patch.object(ai_service, '_get_model') as mock_get_model, \ patch.object(ai_service.credit_service, 'consume_credits') as mock_consume: mock_model = MagicMock() mock_model.credits_per_unit = 10 mock_get_model.return_value = mock_model # 模拟积分不足 mock_consume.side_effect = InsufficientCreditsError("积分不足") # 验证抛出异常 with pytest.raises(ValidationError, match="积分不足"): await ai_service.generate_image( user_id=user_id, prompt="test" ) @pytest.mark.asyncio async def test_get_job_status_not_found(self, ai_service): """测试查询不存在的任务""" job_id = UUID('019d1234-5678-7abc-def0-111111111111') with patch.object(ai_service.job_repository, 'get_by_id', return_value=None): with pytest.raises(NotFoundError, match="任务不存在"): await ai_service.get_job_status(job_id) @pytest.mark.asyncio async def test_cancel_job_success(self, ai_service): """测试取消任务成功""" user_id = UUID('019d1234-5678-7abc-def0-111111111111') job_id = UUID('019d1234-5678-7abc-def0-222222222222') mock_job = MagicMock() mock_job.user_id = user_id mock_job.status = AIJobStatus.PENDING mock_job.task_id = 'celery-task-123' with patch.object(ai_service.job_repository, 'get_by_id', return_value=mock_job), \ patch.object(ai_service.job_repository, 'update') as mock_update, \ patch('app.tasks.celery_app.celery_app.control.revoke') as mock_revoke: await ai_service.cancel_job(user_id, job_id) mock_revoke.assert_called_once_with('celery-task-123', terminate=True) mock_update.assert_called_once() ``` ### 集成测试 ```python # server/tests/integration/test_ai_api.py import pytest from httpx import AsyncClient from uuid import UUID from app.main import app from app.core.database import get_db from app.models.ai_job import AIJobType, AIJobStatus @pytest.fixture async def client(): """创建测试客户端""" async with AsyncClient(app=app, base_url="http://test") as ac: yield ac @pytest.fixture async def auth_headers(client): """获取认证 token""" # 登录获取 token response = await client.post("/api/v1/auth/login", json={ "email": "test@example.com", "password": "testpass123" }) token = response.json()["data"]["accessToken"] return {"Authorization": f"Bearer {token}"} class TestAIAPI: """AI API 集成测试""" @pytest.mark.asyncio async def test_create_image_job(self, client, auth_headers): """测试创建图片生成任务""" response = await client.post( "/api/v1/ai/jobs", headers=auth_headers, json={ "jobType": "image", "prompt": "一只可爱的猫咪", "width": 1024, "height": 1024 } ) assert response.status_code == 200 data = response.json() assert data["code"] == 200 assert "jobId" in data["data"] assert "taskId" in data["data"] assert data["data"]["status"] == "pending" @pytest.mark.asyncio async def test_create_job_invalid_params(self, client, auth_headers): """测试无效参数""" response = await client.post( "/api/v1/ai/jobs", headers=auth_headers, json={ "jobType": "image", "prompt": "", # 空提示词 "width": 1024 } ) assert response.status_code == 400 data = response.json() assert data["code"] == 400 assert "prompt" in data["message"].lower() @pytest.mark.asyncio async def test_get_job_status(self, client, auth_headers): """测试查询任务状态""" # 先创建任务 create_response = await client.post( "/api/v1/ai/jobs", headers=auth_headers, json={ "jobType": "image", "prompt": "test" } ) job_id = create_response.json()["data"]["jobId"] # 查询状态 response = await client.get( f"/api/v1/ai/jobs/{job_id}", headers=auth_headers ) assert response.status_code == 200 data = response.json() assert data["code"] == 200 assert data["data"]["jobId"] == job_id assert "status" in data["data"] @pytest.mark.asyncio async def test_cancel_job(self, client, auth_headers): """测试取消任务""" # 先创建任务 create_response = await client.post( "/api/v1/ai/jobs", headers=auth_headers, json={ "jobType": "image", "prompt": "test" } ) job_id = create_response.json()["data"]["jobId"] # 取消任务 response = await client.post( f"/api/v1/ai/jobs/{job_id}/cancel", headers=auth_headers ) assert response.status_code == 200 data = response.json() assert data["code"] == 200 assert "已取消" in data["message"] @pytest.mark.asyncio async def test_get_usage_stats(self, client, auth_headers): """测试获取使用统计""" response = await client.get( "/api/v1/ai/usage/stats", headers=auth_headers ) assert response.status_code == 200 data = response.json() assert data["code"] == 200 assert "totalCost" in data["data"] assert "totalCreditsUsed" in data["data"] assert "quotas" in data["data"] @pytest.mark.asyncio async def test_get_available_models(self, client, auth_headers): """测试获取可用模型列表""" response = await client.get( "/api/v1/ai/models", headers=auth_headers ) assert response.status_code == 200 data = response.json() assert data["code"] == 200 assert "items" in data["data"] assert len(data["data"]["items"]) > 0 ``` ### 测试运行 ```bash # 运行所有 AI Service 测试 docker exec jointo-server-app pytest tests/unit/test_ai_service.py -v # 运行集成测试 docker exec jointo-server-app pytest tests/integration/test_ai_api.py -v # 运行所有测试并生成覆盖率报告 docker exec jointo-server-app pytest tests/ --cov=app/services/ai_service --cov-report=html # 运行特定测试 docker exec jointo-server-app pytest tests/unit/test_ai_service.py::TestAIService::test_generate_image_success -v ``` --- ## 数据库迁移脚本 ### 创建 AI 服务表 ```python # server/alembic/versions/20260129_1700_create_ai_service_tables.py """create ai service tables Revision ID: abc123def456 Revises: 3a3a2a1417de Create Date: 2026-01-29 17:00:00.000000 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = 'abc123def456' down_revision = '3a3a2a1417de' branch_labels = None depends_on = None def upgrade() -> None: """创建 AI 服务相关表(无外键约束,应用层保证引用完整性)""" # 1. 创建 ai_models 表 op.create_table( 'ai_models', sa.Column('model_id', postgresql.UUID(as_uuid=True), nullable=False), sa.Column('model_name', sa.Text(), nullable=False), sa.Column('display_name', sa.Text(), nullable=False), sa.Column('description', sa.Text(), nullable=True), sa.Column('provider', sa.SmallInteger(), nullable=False), sa.Column('model_type', sa.SmallInteger(), nullable=False), sa.Column('cost_per_unit', sa.Numeric(10, 4), nullable=False), sa.Column('unit_type', sa.SmallInteger(), nullable=False), sa.Column('credits_per_unit', sa.Integer(), nullable=False), sa.Column('config', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}'), sa.Column('rate_limit', sa.Integer(), nullable=True), sa.Column('daily_quota', sa.Integer(), nullable=True), sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'), sa.Column('is_beta', sa.Boolean(), nullable=False, server_default='false'), sa.Column('created_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')), sa.Column('updated_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')), sa.PrimaryKeyConstraint('model_id'), sa.UniqueConstraint('model_name'), comment='AI 模型配置表' ) # 索引 op.create_index('idx_ai_models_provider', 'ai_models', ['provider'], postgresql_where=sa.text('is_active = true')) op.create_index('idx_ai_models_type', 'ai_models', ['model_type'], postgresql_where=sa.text('is_active = true')) op.create_index('idx_ai_models_is_active', 'ai_models', ['is_active']) op.create_index('idx_ai_models_config_gin', 'ai_models', ['config'], postgresql_using='gin') # 触发器 op.execute(""" CREATE TRIGGER update_ai_models_updated_at BEFORE UPDATE ON ai_models FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); """) # 2. 创建 ai_jobs 表(无外键约束) op.create_table( 'ai_jobs', sa.Column('ai_job_id', postgresql.UUID(as_uuid=True), nullable=False), sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False), sa.Column('project_id', postgresql.UUID(as_uuid=True), nullable=True), sa.Column('storyboard_id', postgresql.UUID(as_uuid=True), nullable=True), sa.Column('consumption_log_id', postgresql.UUID(as_uuid=True), nullable=True), sa.Column('job_type', sa.SmallInteger(), nullable=False), sa.Column('status', sa.SmallInteger(), nullable=False, server_default='1'), sa.Column('input_data', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}'), sa.Column('output_data', postgresql.JSONB(astext_type=sa.Text()), nullable=True), sa.Column('model_id', postgresql.UUID(as_uuid=True), nullable=True), sa.Column('model_name', sa.Text(), nullable=True), sa.Column('progress', sa.Integer(), nullable=False, server_default='0'), sa.Column('error_message', sa.Text(), nullable=True), sa.Column('task_id', sa.Text(), nullable=True), sa.Column('estimated_completion_at', sa.TIMESTAMP(timezone=True), nullable=True), sa.Column('started_at', sa.TIMESTAMP(timezone=True), nullable=True), sa.Column('completed_at', sa.TIMESTAMP(timezone=True), nullable=True), sa.Column('created_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')), sa.Column('updated_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')), sa.Column('cost', sa.Numeric(10, 4), nullable=True), sa.Column('credits_used', sa.Integer(), nullable=False, server_default='0'), sa.PrimaryKeyConstraint('ai_job_id'), sa.CheckConstraint('progress >= 0 AND progress <= 100', name='ai_jobs_progress_check'), comment='AI 任务表 - 应用层保证引用完整性' ) # 索引(关联字段必须有索引) op.create_index('idx_ai_jobs_user_id', 'ai_jobs', ['user_id']) op.create_index('idx_ai_jobs_project_id', 'ai_jobs', ['project_id'], postgresql_where=sa.text('project_id IS NOT NULL')) op.create_index('idx_ai_jobs_storyboard_id', 'ai_jobs', ['storyboard_id'], postgresql_where=sa.text('storyboard_id IS NOT NULL')) op.create_index('idx_ai_jobs_type', 'ai_jobs', ['job_type']) op.create_index('idx_ai_jobs_status', 'ai_jobs', ['status']) op.create_index('idx_ai_jobs_model_id', 'ai_jobs', ['model_id'], postgresql_where=sa.text('model_id IS NOT NULL')) op.create_index('idx_ai_jobs_consumption_log_id', 'ai_jobs', ['consumption_log_id'], postgresql_where=sa.text('consumption_log_id IS NOT NULL')) op.create_index('idx_ai_jobs_created_at', 'ai_jobs', ['created_at']) op.create_index('idx_ai_jobs_status_created_at', 'ai_jobs', ['status', 'created_at'], postgresql_where=sa.text('status IN (1, 2)')) op.create_index('idx_ai_jobs_input_data_gin', 'ai_jobs', ['input_data'], postgresql_using='gin') op.create_index('idx_ai_jobs_output_data_gin', 'ai_jobs', ['output_data'], postgresql_using='gin', postgresql_where=sa.text('output_data IS NOT NULL')) # 触发器 op.execute(""" CREATE TRIGGER update_ai_jobs_updated_at BEFORE UPDATE ON ai_jobs FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); """) # 字段注释 op.execute(""" COMMENT ON COLUMN ai_jobs.user_id IS '用户 ID - 应用层验证'; COMMENT ON COLUMN ai_jobs.project_id IS '项目 ID - 应用层验证'; COMMENT ON COLUMN ai_jobs.storyboard_id IS '分镜 ID - 应用层验证'; COMMENT ON COLUMN ai_jobs.consumption_log_id IS '积分消耗日志 ID - 应用层验证'; COMMENT ON COLUMN ai_jobs.model_id IS 'AI 模型 ID - 应用层验证'; COMMENT ON COLUMN ai_jobs.job_type IS '任务类型(1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成)'; COMMENT ON COLUMN ai_jobs.status IS '任务状态(1=等待处理 2=处理中 3=已完成 4=失败 5=已取消)'; """) # 3. 创建 ai_usage_logs 表(无外键约束) op.create_table( 'ai_usage_logs', sa.Column('usage_log_id', postgresql.UUID(as_uuid=True), nullable=False), sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False), sa.Column('ai_job_id', postgresql.UUID(as_uuid=True), nullable=False), sa.Column('model_id', postgresql.UUID(as_uuid=True), nullable=True), sa.Column('units_used', sa.Numeric(10, 2), nullable=False), sa.Column('unit_type', sa.SmallInteger(), nullable=False), sa.Column('cost', sa.Numeric(10, 4), nullable=False), sa.Column('credits_used', sa.Integer(), nullable=False), sa.Column('meta_data', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}'), sa.Column('created_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')), sa.PrimaryKeyConstraint('usage_log_id'), sa.CheckConstraint('unit_type BETWEEN 1 AND 4', name='ai_usage_logs_unit_type_check'), comment='AI 使用日志表 - 应用层保证引用完整性' ) # 索引 op.create_index('idx_ai_usage_logs_user_id', 'ai_usage_logs', ['user_id']) op.create_index('idx_ai_usage_logs_ai_job_id', 'ai_usage_logs', ['ai_job_id']) op.create_index('idx_ai_usage_logs_model_id', 'ai_usage_logs', ['model_id'], postgresql_where=sa.text('model_id IS NOT NULL')) op.create_index('idx_ai_usage_logs_created_at', 'ai_usage_logs', ['created_at']) op.create_index('idx_ai_usage_logs_user_created', 'ai_usage_logs', ['user_id', 'created_at']) op.create_index('idx_ai_usage_logs_meta_data_gin', 'ai_usage_logs', ['meta_data'], postgresql_using='gin') # 字段注释 op.execute(""" COMMENT ON COLUMN ai_usage_logs.user_id IS '用户 ID - 应用层验证'; COMMENT ON COLUMN ai_usage_logs.ai_job_id IS 'AI 任务 ID - 应用层验证'; COMMENT ON COLUMN ai_usage_logs.model_id IS 'AI 模型 ID - 应用层验证'; COMMENT ON COLUMN ai_usage_logs.unit_type IS '单位类型(1=Token 2=图片 3=秒 4=请求)'; """) # 4. 创建 ai_quotas 表(无外键约束) op.create_table( 'ai_quotas', sa.Column('quota_id', postgresql.UUID(as_uuid=True), nullable=False), sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False), sa.Column('quota_type', sa.Text(), nullable=False), sa.Column('period', sa.SmallInteger(), nullable=False), sa.Column('total_quota', sa.Integer(), nullable=False), sa.Column('used_quota', sa.Integer(), nullable=False, server_default='0'), sa.Column('reset_at', sa.TIMESTAMP(timezone=True), nullable=False), sa.Column('created_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')), sa.Column('updated_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')), sa.PrimaryKeyConstraint('quota_id'), sa.UniqueConstraint('user_id', 'quota_type', 'period', name='ai_quotas_unique'), sa.CheckConstraint('period BETWEEN 1 AND 3', name='ai_quotas_period_check'), sa.CheckConstraint('used_quota >= 0 AND used_quota <= total_quota', name='ai_quotas_used_check'), comment='AI 配额表 - 应用层保证引用完整性' ) # 索引 op.create_index('idx_ai_quotas_user_id', 'ai_quotas', ['user_id']) op.create_index('idx_ai_quotas_type', 'ai_quotas', ['quota_type']) op.create_index('idx_ai_quotas_reset_at', 'ai_quotas', ['reset_at']) op.create_index('idx_ai_quotas_user_type', 'ai_quotas', ['user_id', 'quota_type']) # 触发器 op.execute(""" CREATE TRIGGER update_ai_quotas_updated_at BEFORE UPDATE ON ai_quotas FOR EACH ROW EXECUTE FUNCTION update_updated_at_column(); """) # 字段注释 op.execute(""" COMMENT ON COLUMN ai_quotas.user_id IS '用户 ID - 应用层验证'; COMMENT ON COLUMN ai_quotas.period IS '配额周期(1=每日 2=每月 3=总计)'; """) # 5. 创建配额重置函数 op.execute(""" CREATE OR REPLACE FUNCTION reset_expired_quotas() RETURNS void AS $$ BEGIN -- 重置每日配额 (period = 1) UPDATE ai_quotas SET used_quota = 0, reset_at = reset_at + INTERVAL '1 day', updated_at = now() WHERE period = 1 AND reset_at <= now(); -- 重置每月配额 (period = 2) UPDATE ai_quotas SET used_quota = 0, reset_at = reset_at + INTERVAL '1 month', updated_at = now() WHERE period = 2 AND reset_at <= now(); END; $$ LANGUAGE plpgsql; """) def downgrade() -> None: """删除 AI 服务相关表""" op.execute("DROP FUNCTION IF EXISTS reset_expired_quotas();") op.drop_table('ai_quotas') op.drop_table('ai_usage_logs') op.drop_table('ai_jobs') op.drop_table('ai_models') ``` ### 运行迁移 ```bash # 生成迁移文件 docker exec jointo-server-app alembic revision --autogenerate -m "create ai service tables" # 应用迁移 docker exec jointo-server-app alembic upgrade head # 查看当前版本 docker exec jointo-server-app alembic current # 回滚迁移 docker exec jointo-server-app alembic downgrade -1 ```