AI 生成服务
文档版本:v1.0
最后更新:2025-01-27
服务概述
AI 生成服务负责处理各类 AI 内容生成任务,包括图片、视频、音效、配音等,支持多种 AI 模型和服务提供商。
职责
- 图片生成(文本转图片)
- 视频生成(文本转视频、图片转视频)
- 音效生成
- 配音生成(文本转语音)
- 字幕生成(语音转文本)
- 任务状态管理
核心功能
应用层引用完整性保证
⚠️ 重要:本服务遵循 Jointo 技术栈规范,不使用数据库外键约束,改为在应用层保证引用完整性。
三层保证机制:
- Repository 层:提供 exists() 方法检查记录是否存在
- Service 层:创建/更新前验证所有关联 ID,使用事务确保原子性
- 后台任务:定期检查孤儿记录并告警
优势:
- 写入性能提升 15-30%
- 为分库分表、微服务化做准备
- 复杂业务逻辑在应用层更易实现
- 表结构变更无需处理外键依赖
1. 图片生成
支持的模型:
- Stable Diffusion(开源)
- DALL-E(OpenAI)
- Midjourney(如果可用)
功能:
- 文本描述生成图片
- 支持风格控制
- 支持分辨率设置
- 异步生成,返回任务 ID
2. 视频生成
支持的类型:
- 文本转视频(text2video)
- 图片转视频(img2video)
- 关键帧动画(keyframe)
- 视频融合(fusion)
- 视频替换(replace)
支持的模型:
- Runway Gen-2
- Pika Labs
- Stable Video Diffusion
3. 音效生成
功能:
- 根据描述生成音效
- 支持音效类型选择
- 支持时长控制
4. 配音生成
功能:
- 文本转语音(TTS)
- 支持多种语音类型
- 支持语速、音调控制
- 支持多语言
支持的服务:
- OpenAI TTS
- Azure Speech Services
- 百度语音 / 讯飞语音
5. 字幕生成
功能:
- 语音转文本(STT)
- 自动断句
- 分镜看板对齐
支持的服务:
- Whisper(OpenAI)
- 阿里云语音识别
6. 文本处理
功能:
- 剧本拆解:将完整剧本拆分为场景、镜头
- 内容分析:提取关键信息(人物、地点、时间、动作)
- 结构化输出:生成 JSON 格式的结构化数据
- 智能补全:根据上下文补充缺失信息
- 风格转换:调整文本风格和语气
支持的模型:
- GPT-4 / GPT-3.5(OpenAI)
- Claude(Anthropic)
- 文心一言(百度)
- 通义千问(阿里)
应用场景:
- 剧本自动拆解为分镜
- 场景描述生成视觉提示词
- 对话文本生成配音脚本
- 内容审核和合规检查
7. 剧本解析(Screenplay Parsing)
核心功能:
- 自动提取剧本元素:
- 角色识别(主角、配角、群演)
- 场景识别(地点、时间、描述)
- 道具识别(重要性分类)
- 变体识别:
- 角色变体(年龄段、时代、状态)
- 场景变体(时代、季节、状态)
- 道具变体(状态、版本)
- 分镜拆解:
- 自动拆分为分镜脚本
- 识别景别和运镜
- 估算时长
- 自动关联:
- 分镜与角色/场景/道具自动关联
- 变体自动匹配
- 数据持久化:
- 自动存储到数据库
- 建立关联关系
工作流程:
- 用户上传/创建剧本
- 触发 AI 解析任务
- 预扣积分,创建 AI 任务
- Celery Worker 调用 AI 模型
- AI 返回结构化 JSON 数据
- 自动存储角色/场景/道具/变体
- 自动创建分镜记录
- 自动建立关联关系
- 确认积分消耗
- 返回解析结果
详细文档:参见 AI 解析剧本工作流
API 接口:
POST /api/v1/screenplays/{screenplay_id}/parse
请求体:
{
"auto_create_elements": true,
"auto_create_tags": true,
"auto_create_storyboards": true,
"model": "gpt-4"
}
响应:
{
"code": 200,
"message": "Success",
"data": {
"jobId": "019d1234-5678-7abc-def0-222222222222",
"taskId": "abc123-def456-ghi789",
"status": "pending",
"estimatedCredits": 50
}
}
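作为参考,下面用一个小函数演示如何按上文约定构造解析请求的 URL 与请求体(base_url 为假设参数,实际域名与鉴权方式以部署环境为准):

```python
def build_parse_request(base_url: str, screenplay_id: str,
                        model: str = "gpt-4") -> tuple[str, dict]:
    """按文档约定构造剧本解析接口的 URL 与请求体。"""
    url = f"{base_url}/api/v1/screenplays/{screenplay_id}/parse"
    payload = {
        "auto_create_elements": True,
        "auto_create_tags": True,
        "auto_create_storyboards": True,
        "model": model,
    }
    return url, payload
```

返回的 URL 与 payload 可直接交给任意 HTTP 客户端发送 POST 请求。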
AI 输出格式:
AI 模型返回包含以下结构的 JSON 数据:
{
"characters": [
{
"name": "张三",
"description": "男主角,30岁,程序员",
"role_type": "main",
"meta_data": {"age": 30, "gender": "male"}
}
],
"scenes": [
{
"scene_number": 1,
"title": "咖啡厅",
"location": "市中心星巴克",
"time_of_day": "afternoon",
"description": "温馨的咖啡厅"
}
],
"props": [
{
"name": "笔记本电脑",
"description": "张三的工作电脑",
"category": "电子设备",
"importance": "normal"
}
],
"character_tags": [
{
"character_name": "张三",
"tag_key": "youth",
"tag_label": "少年",
"description": "15岁的张三"
}
],
"scene_tags": [...],
"prop_tags": [...],
"storyboards": [
{
"shot_number": "001",
"title": "开场",
"description": "张三坐在咖啡厅里",
"dialogue": "又是一个平凡的下午...",
"shot_size": "medium_shot",
"camera_movement": "static",
"estimated_duration": 5.5,
"characters": ["张三"],
"character_tags": {"张三": "adult"},
"scenes": ["咖啡厅"],
"props": ["笔记本电脑"]
}
]
}
自动存储逻辑:
- 存储角色:批量插入 screenplay_characters 表,返回角色 ID 映射
- 存储场景:批量插入 screenplay_scenes 表,返回场景 ID 映射
- 存储道具:批量插入 screenplay_props 表,返回道具 ID 映射
- 存储标签:调用 ScreenplayTagService.store_tags() 批量插入 screenplay_element_tags 表,返回标签 ID 映射
- 存储分镜:批量插入 storyboards 表,同时建立关联关系
标签存储详细流程:
# Celery Worker 调用 Screenplay Service
result = await screenplay_service.store_parsed_elements(
screenplay_id=screenplay_id,
parsed_data=parsed_data
)
# Screenplay Service 内部调用 Screenplay Tag Service
from app.services.screenplay_tag_service import ScreenplayTagService
tag_service = ScreenplayTagService(db)
tag_id_maps = await tag_service.store_tags(
screenplay_id=screenplay_id,
parsed_data=parsed_data,
character_id_map=character_id_map,
scene_id_map=scene_id_map,
prop_id_map=prop_id_map
)
# 返回的 tag_id_maps 结构
{
'character_tags': {
'张三-youth': UUID('019d1234-5678-7abc-def0-444444444444'),
'张三-adult': UUID('019d1234-5678-7abc-def0-555555555555')
},
'scene_tags': {
'花果山-daytime': UUID('019d1234-5678-7abc-def0-666666666666'),
'花果山-night': UUID('019d1234-5678-7abc-def0-777777777777')
},
'prop_tags': {
'金箍棒-new': UUID('019d1234-5678-7abc-def0-888888888888')
}
}
ScreenplayTagService.store_tags() 实现:
async def store_tags(
self,
screenplay_id: UUID,
parsed_data: Dict[str, Any],
character_id_map: Dict[str, UUID],
scene_id_map: Dict[str, UUID],
prop_id_map: Dict[str, UUID]
) -> Dict[str, Dict[str, UUID]]:
"""存储 AI 解析的标签(供 AI Service 调用)"""
logger.info("存储 AI 解析的标签: 剧本=%s", screenplay_id)
tag_id_maps = {
'character_tags': {},
'scene_tags': {},
'prop_tags': {}
}
# 1. 存储角色标签
for char_name, tags in parsed_data.get('character_tags', {}).items():
character_id = character_id_map.get(char_name)
if not character_id:
continue
for tag_data in tags:
tag = await self.repository.create(ScreenplayElementTag(
screenplay_id=screenplay_id,
element_type=ElementType.CHARACTER,
element_id=character_id,
element_name=char_name,
tag_key=tag_data['tag_key'],
tag_label=tag_data['tag_label'],
description=tag_data.get('description'),
meta_data=tag_data.get('meta_data', {})
))
map_key = f"{char_name}-{tag_data['tag_key']}"
tag_id_maps['character_tags'][map_key] = tag.tag_id
# 更新角色的 has_tags 标志
await self._update_element_has_tags(ElementType.CHARACTER, character_id, True)
# 2. 存储场景标签(逻辑类似)
for scene_name, tags in parsed_data.get('scene_tags', {}).items():
scene_id = scene_id_map.get(scene_name)
if not scene_id:
continue
for tag_data in tags:
tag = await self.repository.create(ScreenplayElementTag(
screenplay_id=screenplay_id,
element_type=ElementType.SCENE,
element_id=scene_id,
element_name=scene_name,
tag_key=tag_data['tag_key'],
tag_label=tag_data['tag_label'],
description=tag_data.get('description'),
meta_data=tag_data.get('meta_data', {})
))
map_key = f"{scene_name}-{tag_data['tag_key']}"
tag_id_maps['scene_tags'][map_key] = tag.tag_id
await self._update_element_has_tags(ElementType.SCENE, scene_id, True)
# 3. 存储道具标签(逻辑类似)
for prop_name, tags in parsed_data.get('prop_tags', {}).items():
prop_id = prop_id_map.get(prop_name)
if not prop_id:
continue
for tag_data in tags:
tag = await self.repository.create(ScreenplayElementTag(
screenplay_id=screenplay_id,
element_type=ElementType.PROP,
element_id=prop_id,
element_name=prop_name,
tag_key=tag_data['tag_key'],
tag_label=tag_data['tag_label'],
description=tag_data.get('description'),
meta_data=tag_data.get('meta_data', {})
))
map_key = f"{prop_name}-{tag_data['tag_key']}"
tag_id_maps['prop_tags'][map_key] = tag.tag_id
await self._update_element_has_tags(ElementType.PROP, prop_id, True)
logger.info(
"标签存储完成: 角色标签=%d, 场景标签=%d, 道具标签=%d",
len(tag_id_maps['character_tags']),
len(tag_id_maps['scene_tags']),
len(tag_id_maps['prop_tags'])
)
return tag_id_maps
自动关联逻辑:
分镜创建时,根据 AI 返回的名称数组和标签映射,自动查找对应的数据库 ID:
# 示例:角色关联
for char_name in storyboard_data['characters']:
character_id = character_id_map.get(char_name)
if character_id:
screenplay_character_ids.append(character_id)
# 检查标签
tag_key = storyboard_data['character_tags'].get(char_name)
if tag_key:
tag_id = tag_id_maps['character_tags'].get(f"{char_name}-{tag_key}")
if tag_id:
screenplay_character_tag_ids.append(tag_id)
涉及的服务:
- AI Service:调用 AI 模型进行解析
- Screenplay Service:管理剧本和剧本元素
- Screenplay Tag Service:管理标签(新增)
- Storyboard Service:管理分镜
- Credit Service:管理积分扣除
与 Credit Service 集成
集成架构
AI 服务与积分服务采用同步调用 + 事务保证的集成模式:
AI Service (依赖) → Credit Service (被依赖)
积分扣除流程
1. 任务创建时预扣积分
# AI Service 创建任务时
async def generate_image(self, user_id: UUID, ...):
# 1. 计算所需积分
model_config = await self._get_model(model_name, 'image')
credits_needed = model_config.credits_per_unit
# 2. 调用 Credit Service 预扣积分(同步,事务保证)
try:
consumption_log = await self.credit_service.consume_credits(
user_id=user_id,
amount=credits_needed,
feature_type=1, # FeatureType.IMAGE_GENERATION
ai_job_id=None, # 稍后更新
task_params={...}
)
except InsufficientCreditsError:
raise ValidationError("积分不足")
# 3. 创建 AI 任务,关联 consumption_log
job = await self.job_repository.create({
'consumption_log_id': consumption_log.consumption_id,
...
})
# 4. 更新 consumption_log 的 ai_job_id
consumption_log.ai_job_id = job.ai_job_id
await self.db.commit()
2. 任务完成时确认扣除
# Celery Worker 任务完成
import asyncio

@celery_app.task
def generate_image_task(job_id: UUID):
    # Celery 任务是同步函数,内部通过 asyncio.run 执行异步逻辑
    asyncio.run(_run_generate_image(job_id))

async def _run_generate_image(job_id: UUID):
    # 先取任务,确保失败分支也能拿到 consumption_log_id
    job = await ai_service.get_job(job_id)
    try:
        # 执行 AI 生成
        result = call_ai_provider(...)
        # 更新任务状态
        await ai_service.update_job(job_id, {'status': 3, ...})  # AIJobStatus.COMPLETED
        # 确认积分消耗
        await credit_service.confirm_consumption(
            consumption_id=job.consumption_log_id,
            resource_id=result.resource_id
        )
    except Exception as e:
        # 任务失败,退还积分
        await ai_service.update_job(job_id, {'status': 4, ...})  # AIJobStatus.FAILED
        await credit_service.refund_credits(
            consumption_id=job.consumption_log_id,
            reason=str(e)
        )
数据关联
-- AI Service 表关联 Credit Service(无外键约束,应用层保证引用完整性)
ALTER TABLE ai_jobs
ADD COLUMN consumption_log_id UUID;
CREATE INDEX idx_ai_jobs_consumption_log_id ON ai_jobs (consumption_log_id)
WHERE consumption_log_id IS NOT NULL;
-- Credit Service 表关联 AI Service(无外键约束,应用层保证引用完整性)
ALTER TABLE credit_consumption_logs
ADD COLUMN ai_job_id UUID;
CREATE INDEX idx_credit_consumption_logs_ai_job_id ON credit_consumption_logs (ai_job_id)
WHERE ai_job_id IS NOT NULL;
COMMENT ON COLUMN ai_jobs.consumption_log_id IS '积分消耗日志 ID - 应用层验证';
COMMENT ON COLUMN credit_consumption_logs.ai_job_id IS 'AI 任务 ID - 应用层验证';
事务保证
使用数据库事务确保积分扣除和任务创建的原子性:
async with self.db.begin():
# 1. 扣除积分
consumption_log = await credit_service.consume_credits(...)
# 2. 创建任务
job = await ai_service.create_job(...)
# 3. 关联
consumption_log.ai_job_id = job.ai_job_id
# 提交事务(失败自动回滚)
数据库设计
AI 服务需要以下数据表支撑任务管理、成本控制和配额管理。
3.1 ai_jobs(AI 任务表)
核心表,记录所有 AI 生成任务的状态和结果。
-- Python 枚举定义(app/models/ai_job.py)
-- class AIJobType(IntEnum):
-- IMAGE = 1 # 图片生成
-- VIDEO = 2 # 视频生成
-- SOUND = 3 # 音效生成
-- VOICE = 4 # 配音生成
-- SUBTITLE = 5 # 字幕生成
-- TEXT_PROCESSING = 6 # 文本处理(剧本拆解等)
-- RESOURCE = 7 # 资源生成
-- STORYBOARD_SCRIPT = 8 # 分镜脚本生成
-- SCRIPT_GENERATION = 9 # 剧本生成
-- class AIJobStatus(IntEnum):
-- PENDING = 1 # 等待处理
-- PROCESSING = 2 # 处理中
-- COMPLETED = 3 # 已完成
-- FAILED = 4 # 失败
-- CANCELLED = 5 # 已取消
CREATE TABLE ai_jobs (
ai_job_id UUID PRIMARY KEY, -- AI 任务唯一标识
-- 关联信息(无外键约束,应用层保证引用完整性)
user_id UUID NOT NULL, -- 用户 ID
project_id UUID, -- 项目 ID(可选)
storyboard_id UUID, -- 分镜 ID(可选)
-- 积分关联(无外键约束,应用层保证引用完整性)
consumption_log_id UUID, -- 积分消耗日志 ID
-- 任务信息
job_type SMALLINT NOT NULL, -- 任务类型(1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成)
status SMALLINT NOT NULL DEFAULT 1, -- 任务状态(1=等待处理 2=处理中 3=已完成 4=失败 5=已取消)
-- 输入输出
input_data JSONB NOT NULL DEFAULT '{}', -- 输入参数(prompt, 配置等)
output_data JSONB, -- 输出结果(URL, 元数据等)
-- AI 模型信息(无外键约束,应用层保证引用完整性)
model_id UUID, -- AI 模型 ID
model_name TEXT, -- 使用的模型名称(冗余字段)
-- 任务状态
progress INTEGER NOT NULL DEFAULT 0 CHECK (progress >= 0 AND progress <= 100), -- 任务进度(0-100)
error_message TEXT, -- 错误信息
task_id TEXT, -- Celery 异步任务 ID
-- 时间信息
estimated_completion_at TIMESTAMPTZ, -- 预计完成时间
started_at TIMESTAMPTZ, -- 开始处理时间
completed_at TIMESTAMPTZ, -- 完成时间
created_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 创建时间
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 更新时间
-- 成本信息
cost NUMERIC(10, 4), -- 实际成本(元)
credits_used INTEGER NOT NULL DEFAULT 0 -- 使用的积分
);
-- 字段注释
COMMENT ON TABLE ai_jobs IS 'AI 任务表 - 应用层保证引用完整性';
COMMENT ON COLUMN ai_jobs.ai_job_id IS 'AI 任务唯一标识';
COMMENT ON COLUMN ai_jobs.user_id IS '用户 ID - 应用层验证';
COMMENT ON COLUMN ai_jobs.project_id IS '项目 ID(可选)- 应用层验证';
COMMENT ON COLUMN ai_jobs.storyboard_id IS '分镜 ID(可选)- 应用层验证';
COMMENT ON COLUMN ai_jobs.consumption_log_id IS '积分消耗日志 ID - 应用层验证';
COMMENT ON COLUMN ai_jobs.job_type IS '任务类型(1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成)';
COMMENT ON COLUMN ai_jobs.status IS '任务状态(1=等待处理 2=处理中 3=已完成 4=失败 5=已取消)';
COMMENT ON COLUMN ai_jobs.input_data IS '输入参数(prompt, 配置等)';
COMMENT ON COLUMN ai_jobs.output_data IS '输出结果(URL, 元数据等)';
COMMENT ON COLUMN ai_jobs.model_id IS 'AI 模型 ID - 应用层验证';
COMMENT ON COLUMN ai_jobs.model_name IS '使用的模型名称(冗余字段)';
COMMENT ON COLUMN ai_jobs.progress IS '任务进度(0-100)';
COMMENT ON COLUMN ai_jobs.error_message IS '错误信息';
COMMENT ON COLUMN ai_jobs.task_id IS 'Celery 异步任务 ID';
COMMENT ON COLUMN ai_jobs.estimated_completion_at IS '预计完成时间';
COMMENT ON COLUMN ai_jobs.started_at IS '开始处理时间';
COMMENT ON COLUMN ai_jobs.completed_at IS '完成时间';
COMMENT ON COLUMN ai_jobs.created_at IS '创建时间';
COMMENT ON COLUMN ai_jobs.updated_at IS '更新时间';
COMMENT ON COLUMN ai_jobs.cost IS '实际成本(元)';
COMMENT ON COLUMN ai_jobs.credits_used IS '使用的积分';
-- 索引
CREATE INDEX idx_ai_jobs_user_id ON ai_jobs (user_id);
CREATE INDEX idx_ai_jobs_project_id ON ai_jobs (project_id) WHERE project_id IS NOT NULL;
CREATE INDEX idx_ai_jobs_storyboard_id ON ai_jobs (storyboard_id) WHERE storyboard_id IS NOT NULL;
CREATE INDEX idx_ai_jobs_type ON ai_jobs (job_type);
CREATE INDEX idx_ai_jobs_status ON ai_jobs (status);
CREATE INDEX idx_ai_jobs_model_id ON ai_jobs (model_id) WHERE model_id IS NOT NULL;
CREATE INDEX idx_ai_jobs_consumption_log_id ON ai_jobs (consumption_log_id) WHERE consumption_log_id IS NOT NULL;
CREATE INDEX idx_ai_jobs_created_at ON ai_jobs (created_at);
CREATE INDEX idx_ai_jobs_status_created_at ON ai_jobs (status, created_at)
WHERE status IN (1, 2); -- PENDING, PROCESSING
CREATE INDEX idx_ai_jobs_input_data_gin ON ai_jobs USING GIN (input_data);
CREATE INDEX idx_ai_jobs_output_data_gin ON ai_jobs USING GIN (output_data)
WHERE output_data IS NOT NULL;
-- 触发器
CREATE TRIGGER update_ai_jobs_updated_at
BEFORE UPDATE ON ai_jobs
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
字段说明:
- input_data:存储任务输入参数,例如:
{
  "prompt": "一只可爱的猫咪",
  "width": 1024,
  "height": 1024,
  "style": "realistic",
  "temperature": 0.7
}
- output_data:存储任务输出结果,例如:
{
  "file_url": "https://storage.jointo.ai/ai-generated/images/abc123def456.png",
  "file_size": 1024000,
  "checksum": "abc123def456...",
  "mime_type": "image/png",
  "original_url": "https://ai-provider.com/temp/xyz789.png",
  "meta_data": {
    "width": 1024,
    "height": 1024,
    "format": "png"
  }
}
output_data 字段说明:
- file_url:自有 OSS 存储的文件 URL(永久有效)
- file_size:文件大小(字节)
- checksum:文件 SHA256 校验和(用于去重)
- mime_type:文件 MIME 类型
- original_url:AI 提供商返回的原始 URL(可选,用于调试)
- meta_data:文件元数据(宽度、高度、格式等)
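其中 checksum 用于文件去重,可用标准库 hashlib 对文件内容计算 SHA256(示意;大文件建议分块读取,逐块更新摘要):

```python
import hashlib
from typing import Iterable


def file_checksum(data: bytes) -> str:
    """一次性计算内容的 SHA256 十六进制摘要,用于 output_data.checksum 去重。"""
    return hashlib.sha256(data).hexdigest()


def stream_checksum(chunks: Iterable[bytes]) -> str:
    """对分块读取的文件流逐块更新摘要,结果与一次性计算一致。"""
    h = hashlib.sha256()
    for chunk in chunks:
        h.update(chunk)
    return h.hexdigest()
```

上传到 OSS 前先计算摘要,若库中已存在相同 checksum 的记录即可复用已有文件。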
3.2 ai_models(AI 模型配置表)
推荐表,管理可用的 AI 模型及其定价配置。
-- Python 枚举定义(app/models/ai_model.py)
-- class AIModelType(IntEnum):
-- TEXT = 1 # 文本模型(GPT, Claude 等)
-- IMAGE = 2 # 图片模型(DALL-E, Stable Diffusion 等)
-- VIDEO = 3 # 视频模型(Runway, Pika 等)
-- AUDIO = 4 # 音频模型(TTS, STT 等)
-- class AIProvider(IntEnum):
-- OPENAI = 1
-- ANTHROPIC = 2
-- GOOGLE = 3 # Google Gemini
-- STABILITY = 4
-- RUNWAY = 5
-- PIKA = 6
-- ELEVENLABS = 7
-- AZURE = 8
-- BAIDU = 9
-- ALIYUN = 10
-- CUSTOM = 99 # 自定义提供商
-- class UnitType(IntEnum):
-- TOKEN = 1 # Token(文本模型)
-- IMAGE = 2 # 图片(图片模型)
-- SECOND = 3 # 秒(视频/音频模型)
-- REQUEST = 4 # 请求(通用计费单位)
CREATE TABLE ai_models (
model_id UUID PRIMARY KEY, -- AI 模型唯一标识
-- 模型信息
model_name TEXT NOT NULL UNIQUE, -- 模型名称(如:gpt-4, dall-e-3, runway-gen2)
display_name TEXT NOT NULL, -- 显示名称
description TEXT, -- 模型描述
-- 分类
provider SMALLINT NOT NULL, -- 提供商(1=OpenAI 2=Anthropic 3=Google 4=Stability 5=Runway 6=Pika 7=ElevenLabs 8=Azure 9=Baidu 10=Aliyun 99=自定义)
model_type SMALLINT NOT NULL, -- 模型类型(1=文本 2=图片 3=视频 4=音频)
-- 定价(按使用量计费)
cost_per_unit NUMERIC(10, 4) NOT NULL, -- 单位成本(元)
unit_type SMALLINT NOT NULL, -- 单位类型(1=Token 2=图片 3=秒 4=请求)
credits_per_unit INTEGER NOT NULL, -- 每单位消耗积分
-- 配置
config JSONB NOT NULL DEFAULT '{}', -- 模型特定配置
-- 限制
rate_limit INTEGER, -- 速率限制(每分钟请求数)
daily_quota INTEGER, -- 每日配额
-- 状态
is_active BOOLEAN NOT NULL DEFAULT true, -- 是否启用
is_beta BOOLEAN NOT NULL DEFAULT false, -- 是否为测试版
-- 审计
created_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 创建时间
updated_at TIMESTAMPTZ NOT NULL DEFAULT now() -- 更新时间
);
-- 字段注释
COMMENT ON COLUMN ai_models.model_id IS 'AI 模型唯一标识';
COMMENT ON COLUMN ai_models.model_name IS '模型名称(如:gpt-4, dall-e-3, runway-gen2)';
COMMENT ON COLUMN ai_models.display_name IS '显示名称';
COMMENT ON COLUMN ai_models.description IS '模型描述';
COMMENT ON COLUMN ai_models.provider IS '提供商(1=OpenAI 2=Anthropic 3=Google 4=Stability 5=Runway 6=Pika 7=ElevenLabs 8=Azure 9=Baidu 10=Aliyun 99=自定义)';
COMMENT ON COLUMN ai_models.model_type IS '模型类型(1=文本 2=图片 3=视频 4=音频)';
COMMENT ON COLUMN ai_models.cost_per_unit IS '单位成本(元)';
COMMENT ON COLUMN ai_models.unit_type IS '单位类型(1=Token 2=图片 3=秒 4=请求)';
COMMENT ON COLUMN ai_models.credits_per_unit IS '每单位消耗积分';
COMMENT ON COLUMN ai_models.config IS '模型特定配置';
COMMENT ON COLUMN ai_models.rate_limit IS '速率限制(每分钟请求数)';
COMMENT ON COLUMN ai_models.daily_quota IS '每日配额';
COMMENT ON COLUMN ai_models.is_active IS '是否启用';
COMMENT ON COLUMN ai_models.is_beta IS '是否为测试版';
COMMENT ON COLUMN ai_models.created_at IS '创建时间';
COMMENT ON COLUMN ai_models.updated_at IS '更新时间';
-- 索引
CREATE INDEX idx_ai_models_provider ON ai_models (provider) WHERE is_active = true;
CREATE INDEX idx_ai_models_type ON ai_models (model_type) WHERE is_active = true;
CREATE INDEX idx_ai_models_is_active ON ai_models (is_active);
CREATE INDEX idx_ai_models_config_gin ON ai_models USING GIN (config);
-- 触发器
CREATE TRIGGER update_ai_models_updated_at
BEFORE UPDATE ON ai_models
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
设计说明:
- 动态配置模型,无需修改代码
- 支持多种计费单位(token、图片、秒、请求)
- 使用 JSONB 存储模型特定配置
- 支持速率限制和配额管理
- 新增提供商:
  - google:支持 Gemini Pro 和 Gemini 1.5 Pro
  - custom:支持自定义 AI 提供商(用户自建模型、第三方 API 等)
Custom 提供商使用场景:
- 用户自建的 AI 模型服务
- 第三方 AI API 集成
- 企业内部 AI 服务
- 测试和开发环境
Custom 提供商配置示例:
{
"api_endpoint": "https://custom-ai.example.com/v1/generate",
"auth_type": "bearer",
"headers": {
"Authorization": "Bearer ${API_KEY}"
},
"request_format": "json",
"response_format": "json"
}
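配置中 headers 里的 ${API_KEY} 这类占位符,可在调用前用密钥字典替换(示意,基于标准库 string.Template;密钥来源为假设,实际应从安全配置/密钥管理服务读取):

```python
from string import Template


def render_headers(headers: dict[str, str],
                   secrets: dict[str, str]) -> dict[str, str]:
    """将 headers 中的 ${VAR} 占位符替换为实际密钥;缺失变量时抛出 KeyError。"""
    return {k: Template(v).substitute(secrets) for k, v in headers.items()}
```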
3.3 ai_usage_logs(AI 使用日志表)
计费必需表,记录每次 AI 调用的详细使用情况。
CREATE TABLE ai_usage_logs (
usage_log_id UUID PRIMARY KEY, -- 使用日志唯一标识
-- 关联信息(无外键约束,应用层保证引用完整性)
user_id UUID NOT NULL, -- 用户 ID
ai_job_id UUID NOT NULL, -- AI 任务 ID
model_id UUID, -- AI 模型 ID
-- 使用量
units_used NUMERIC(10, 2) NOT NULL, -- 使用的单位数(tokens, 秒数等)
unit_type SMALLINT NOT NULL CHECK (unit_type BETWEEN 1 AND 4), -- 单位类型(1=Token 2=图片 3=秒 4=请求)
-- 成本
cost NUMERIC(10, 4) NOT NULL, -- 实际成本(元)
credits_used INTEGER NOT NULL, -- 消耗的积分
-- 元数据
meta_data JSONB NOT NULL DEFAULT '{}', -- 使用元数据
-- 时间
created_at TIMESTAMPTZ NOT NULL DEFAULT now() -- 创建时间
);
-- 字段注释
COMMENT ON TABLE ai_usage_logs IS 'AI 使用日志表 - 应用层保证引用完整性';
COMMENT ON COLUMN ai_usage_logs.usage_log_id IS '使用日志唯一标识';
COMMENT ON COLUMN ai_usage_logs.user_id IS '用户 ID - 应用层验证';
COMMENT ON COLUMN ai_usage_logs.ai_job_id IS 'AI 任务 ID - 应用层验证';
COMMENT ON COLUMN ai_usage_logs.model_id IS 'AI 模型 ID - 应用层验证';
COMMENT ON COLUMN ai_usage_logs.units_used IS '使用的单位数(tokens, 秒数等)';
COMMENT ON COLUMN ai_usage_logs.unit_type IS '单位类型(1=Token 2=图片 3=秒 4=请求)';
COMMENT ON COLUMN ai_usage_logs.cost IS '实际成本(元)';
COMMENT ON COLUMN ai_usage_logs.credits_used IS '消耗的积分';
COMMENT ON COLUMN ai_usage_logs.meta_data IS '使用元数据';
COMMENT ON COLUMN ai_usage_logs.created_at IS '创建时间';
-- 索引
CREATE INDEX idx_ai_usage_logs_user_id ON ai_usage_logs (user_id);
CREATE INDEX idx_ai_usage_logs_ai_job_id ON ai_usage_logs (ai_job_id);
CREATE INDEX idx_ai_usage_logs_model_id ON ai_usage_logs (model_id) WHERE model_id IS NOT NULL;
CREATE INDEX idx_ai_usage_logs_created_at ON ai_usage_logs (created_at);
CREATE INDEX idx_ai_usage_logs_user_created ON ai_usage_logs (user_id, created_at);
CREATE INDEX idx_ai_usage_logs_meta_data_gin ON ai_usage_logs USING GIN (meta_data);
-- 分区策略(按月分区,便于归档)
-- CREATE TABLE ai_usage_logs_2026_01 PARTITION OF ai_usage_logs
-- FOR VALUES FROM ('2026-01-01') TO ('2026-02-01');
设计说明:
- 每次 AI 调用都记录详细日志
- 支持成本分析和用户账单生成
- 建议按月分区,便于历史数据归档
- 使用 JSONB 存储详细的使用元数据
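基于该表做成本分析时,典型的按模型聚合逻辑如下(Python 侧示意,字段名与上表一致;生产环境通常直接用 SQL GROUP BY 完成):

```python
from collections import defaultdict


def aggregate_usage(logs: list[dict]) -> dict:
    """按 model_id 聚合成本、积分与请求数,可用于生成用户账单。"""
    stats = defaultdict(lambda: {"cost": 0.0, "credits_used": 0, "requests": 0})
    for log in logs:
        s = stats[log["model_id"]]
        s["cost"] += log["cost"]
        s["credits_used"] += log["credits_used"]
        s["requests"] += 1
    return dict(stats)
```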
3.4 ai_quotas(AI 配额表)
推荐表,管理用户的 AI 使用配额和限流。
-- Python 枚举定义(app/models/ai_quota.py)
-- class QuotaPeriod(IntEnum):
-- DAILY = 1 # 每日配额
-- MONTHLY = 2 # 每月配额
-- TOTAL = 3 # 总配额
CREATE TABLE ai_quotas (
quota_id UUID PRIMARY KEY, -- 配额唯一标识
-- 用户信息(无外键约束,应用层保证引用完整性)
user_id UUID NOT NULL, -- 用户 ID
-- 配额类型
quota_type TEXT NOT NULL, -- 配额类型(如:image_generation, video_generation, text_processing)
period SMALLINT NOT NULL CHECK (period BETWEEN 1 AND 3), -- 配额周期(1=每日 2=每月 3=总计)
-- 配额限制
total_quota INTEGER NOT NULL, -- 总配额
used_quota INTEGER NOT NULL DEFAULT 0, -- 已使用配额
-- 重置时间
reset_at TIMESTAMPTZ NOT NULL, -- 重置时间
-- 审计
created_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 创建时间
updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), -- 更新时间
-- 唯一约束
CONSTRAINT ai_quotas_unique UNIQUE NULLS NOT DISTINCT (user_id, quota_type, period),
-- 检查约束
CONSTRAINT ai_quotas_used_check CHECK (used_quota >= 0 AND used_quota <= total_quota)
);
-- 字段注释
COMMENT ON TABLE ai_quotas IS 'AI 配额表 - 应用层保证引用完整性';
COMMENT ON COLUMN ai_quotas.quota_id IS '配额唯一标识';
COMMENT ON COLUMN ai_quotas.user_id IS '用户 ID - 应用层验证';
COMMENT ON COLUMN ai_quotas.quota_type IS '配额类型(如:image_generation, video_generation, text_processing)';
COMMENT ON COLUMN ai_quotas.period IS '配额周期(1=每日 2=每月 3=总计)';
COMMENT ON COLUMN ai_quotas.total_quota IS '总配额';
COMMENT ON COLUMN ai_quotas.used_quota IS '已使用配额';
COMMENT ON COLUMN ai_quotas.reset_at IS '重置时间';
COMMENT ON COLUMN ai_quotas.created_at IS '创建时间';
COMMENT ON COLUMN ai_quotas.updated_at IS '更新时间';
-- 索引
CREATE INDEX idx_ai_quotas_user_id ON ai_quotas (user_id);
CREATE INDEX idx_ai_quotas_type ON ai_quotas (quota_type);
CREATE INDEX idx_ai_quotas_reset_at ON ai_quotas (reset_at);
CREATE INDEX idx_ai_quotas_user_type ON ai_quotas (user_id, quota_type);
-- 触发器
CREATE TRIGGER update_ai_quotas_updated_at
BEFORE UPDATE ON ai_quotas
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
-- 自动重置配额的函数
CREATE OR REPLACE FUNCTION reset_expired_quotas()
RETURNS void AS $$
BEGIN
-- 重置每日配额 (period = 1)
UPDATE ai_quotas
SET used_quota = 0,
reset_at = reset_at + INTERVAL '1 day',
updated_at = now()
WHERE period = 1 AND reset_at <= now();
-- 重置每月配额 (period = 2)
UPDATE ai_quotas
SET used_quota = 0,
reset_at = reset_at + INTERVAL '1 month',
updated_at = now()
WHERE period = 2 AND reset_at <= now();
END;
$$ LANGUAGE plpgsql;
-- 定时任务(需要配合 pg_cron 或应用层定时任务)
-- SELECT cron.schedule('reset-ai-quotas', '0 0 * * *', 'SELECT reset_expired_quotas()');
设计说明:
- 支持每日、每月、总配额三种类型
- 自动重置过期配额
- 防止配额超限(CHECK 约束)
- 支持不同类型的配额管理
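应用层的配额扣减可按"先检查后累加"实现,并由表上的 CHECK 约束兜底(示意;并发场景下实际实现应配合 SELECT ... FOR UPDATE 或单条原子 UPDATE):

```python
class QuotaExceededError(Exception):
    """配额不足时抛出(与文档中的异常语义一致)。"""


def consume_quota(quota: dict, amount: int = 1) -> dict:
    """检查并扣减配额;扣减后 used_quota 超过 total_quota 时拒绝。"""
    if quota["used_quota"] + amount > quota["total_quota"]:
        raise QuotaExceededError(f"配额不足: {quota['quota_type']}")
    quota["used_quota"] += amount
    return quota
```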
3.5 数据表关系图
┌─────────────┐
│ users │
└──────┬──────┘
│
├──────────────────────────────────┐
│ │
▼ ▼
┌─────────────┐ ┌─────────────┐
│ ai_jobs │───────────────────>│ ai_models │
└──────┬──────┘ └─────────────┘
│ ▲
│ │
▼ │
┌─────────────┐ │
│ai_usage_logs│──────────────────────────┘
└─────────────┘
┌─────────────┐
│ ai_quotas │
└─────────────┘
▲
│
└──────── users
Pydantic Schema 定义
请求 Schema
# app/schemas/ai_job.py
from pydantic import BaseModel, Field, field_validator
from uuid import UUID
from typing import Optional, Dict, Any
from enum import Enum, IntEnum
class AIJobType(IntEnum):
    """AI 任务类型"""
    IMAGE = 1
    VIDEO = 2
    SOUND = 3
    VOICE = 4
    SUBTITLE = 5
    TEXT_PROCESSING = 6
    RESOURCE = 7
    STORYBOARD_SCRIPT = 8
    SCRIPT_GENERATION = 9
class VideoType(str, Enum):
    """视频生成类型(str 子类需同时继承 Enum 才是枚举)"""
    TEXT2VIDEO = "text2video"
    IMG2VIDEO = "img2video"
    KEYFRAME = "keyframe"
    FUSION = "fusion"
    REPLACE = "replace"
class TextProcessingTaskType(str, Enum):
    """文本处理任务类型"""
    SCREENPLAY_PARSE = "screenplay_parse"
    CONTENT_ANALYSIS = "content_analysis"
    STYLE_TRANSFORM = "style_transform"
    PROMPT_GENERATION = "prompt_generation"
# ==================== 图片生成 ====================
class ImageGenerationRequest(BaseModel):
"""图片生成请求"""
prompt: str = Field(..., description="提示词", min_length=1, max_length=2000)
model: Optional[str] = Field(None, description="模型名称")
width: int = Field(1024, description="宽度", ge=256, le=2048)
height: int = Field(1024, description="高度", ge=256, le=2048)
style: Optional[str] = Field(None, description="风格")
class Config:
json_schema_extra = {
"example": {
"prompt": "一只可爱的猫咪在花园里玩耍",
"model": "stable_diffusion",
"width": 1024,
"height": 1024,
"style": "realistic"
}
}
# ==================== 视频生成 ====================
class VideoGenerationRequest(BaseModel):
"""视频生成请求"""
video_type: str = Field(..., description="视频类型: text2video, img2video, keyframe, fusion, replace")
prompt: Optional[str] = Field(None, description="提示词(text2video 必需)")
image_url: Optional[str] = Field(None, description="图片 URL(img2video 必需)")
duration: int = Field(5, description="时长(秒)", ge=1, le=60)
fps: int = Field(30, description="帧率", ge=15, le=60)
model: Optional[str] = Field(None, description="模型名称")
@field_validator('video_type')
@classmethod
def validate_video_type(cls, v):
valid_types = ['text2video', 'img2video', 'keyframe', 'fusion', 'replace']
if v not in valid_types:
raise ValueError(f"video_type 必须是以下之一: {', '.join(valid_types)}")
return v
@field_validator('prompt')
@classmethod
def validate_prompt(cls, v, info):
if info.data.get('video_type') == 'text2video' and not v:
raise ValueError("text2video 类型必须提供 prompt")
return v
@field_validator('image_url')
@classmethod
def validate_image_url(cls, v, info):
if info.data.get('video_type') == 'img2video' and not v:
raise ValueError("img2video 类型必须提供 image_url")
return v
# ==================== 音效生成 ====================
class SoundGenerationRequest(BaseModel):
"""音效生成请求"""
description: str = Field(..., description="音效描述", min_length=1, max_length=500)
duration: int = Field(5, description="时长(秒)", ge=1, le=30)
sound_type: Optional[str] = Field(None, description="音效类型")
model: Optional[str] = Field(None, description="模型名称")
# ==================== 配音生成 ====================
class VoiceGenerationRequest(BaseModel):
"""配音生成请求"""
text: str = Field(..., description="文本内容", min_length=1, max_length=5000)
voice_type: str = Field("alloy", description="语音类型")
speed: float = Field(1.0, description="语速", ge=0.5, le=2.0)
language: str = Field("zh-CN", description="语言")
model: Optional[str] = Field(None, description="模型名称")
# ==================== 字幕生成 ====================
class SubtitleGenerationRequest(BaseModel):
"""字幕生成请求"""
audio_url: str = Field(..., description="音频 URL")
language: str = Field("zh", description="语言")
model: Optional[str] = Field(None, description="模型名称")
# ==================== 文本处理 ====================
class TextProcessingRequest(BaseModel):
"""文本处理请求"""
task_type: str = Field(..., description="任务类型: screenplay_parse, content_analysis, style_transform, prompt_generation")
text: str = Field(..., description="待处理文本", min_length=1, max_length=50000)
model: Optional[str] = Field(None, description="模型名称")
output_format: str = Field("json", description="输出格式: json, text")
temperature: float = Field(0.7, description="生成温度", ge=0.0, le=2.0)
max_tokens: int = Field(4000, description="最大 token 数", ge=100, le=8000)
@field_validator('task_type')
@classmethod
def validate_task_type(cls, v):
valid_types = ['screenplay_parse', 'content_analysis', 'style_transform', 'prompt_generation']
if v not in valid_types:
raise ValueError(f"task_type 必须是以下之一: {', '.join(valid_types)}")
return v
# ==================== 统一创建请求 ====================
class AIJobCreateRequest(BaseModel):
"""AI 任务创建请求(统一入口)"""
job_type: str = Field(..., description="任务类型: image, video, sound, voice, subtitle, textProcessing")
params: Dict[str, Any] = Field(..., description="任务参数")
@field_validator('job_type')
@classmethod
def validate_job_type(cls, v):
valid_types = ['image', 'video', 'sound', 'voice', 'subtitle', 'textProcessing']
if v not in valid_types:
raise ValueError(f"job_type 必须是以下之一: {', '.join(valid_types)}")
return v
响应 Schema
# app/schemas/ai_job.py (续)
from datetime import datetime
class AIJobResponse(BaseModel):
"""AI 任务响应"""
job_id: UUID = Field(..., description="任务 ID")
task_id: str = Field(..., description="Celery 任务 ID")
status: str = Field(..., description="任务状态")
estimated_credits: int = Field(..., description="预估积分消耗")
class Config:
json_schema_extra = {
"example": {
"job_id": "019d1234-5678-7abc-def0-111111111111",
"task_id": "abc-123-def",
"status": "pending",
"estimated_credits": 10
}
}
class AIJobStatusResponse(BaseModel):
"""AI 任务状态响应"""
job_id: UUID
job_type: int
status: int
progress: int
input_data: Dict[str, Any]
output_data: Optional[Dict[str, Any]]
error_message: Optional[str]
model_name: Optional[str]
cost: Optional[float]
credits_used: int
created_at: datetime
updated_at: datetime
started_at: Optional[datetime]
completed_at: Optional[datetime]
class AIJobListResponse(BaseModel):
"""AI 任务列表响应"""
items: list[AIJobStatusResponse]
total: int
page: int
page_size: int
total_pages: int
class AIModelResponse(BaseModel):
"""AI 模型响应"""
model_id: UUID
model_name: str
display_name: str
description: Optional[str]
provider: int
model_type: int
cost_per_unit: float
unit_type: int
credits_per_unit: int
is_beta: bool
config: Dict[str, Any]
class AIUsageStatsResponse(BaseModel):
"""AI 使用统计响应"""
total_cost: float
total_credits_used: int
total_requests: int
by_model: Dict[str, Dict[str, Any]]
by_type: Dict[str, Dict[str, Any]]
quotas: list[Dict[str, Any]]
服务实现
AIService 类
# app/services/ai_service.py
from typing import Dict, Any, Optional
from uuid import UUID
from app.tasks.ai_tasks import (
generate_image_task,
generate_video_task,
generate_sound_task,
generate_voice_task,
generate_subtitle_task,
process_text_task
)
from app.repositories.ai_job_repository import AIJobRepository
from app.repositories.ai_model_repository import AIModelRepository
from app.repositories.ai_usage_log_repository import AIUsageLogRepository
from app.repositories.ai_quota_repository import AIQuotaRepository
from app.services.credit_service import CreditService
from app.core.exceptions import InsufficientCreditsError, ValidationError
from sqlalchemy.ext.asyncio import AsyncSession
class AIService:
def __init__(self, db: AsyncSession):  # 服务内均为 async 方法,应使用异步会话
self.db = db
self.job_repository = AIJobRepository(db)
self.model_repository = AIModelRepository(db)
self.usage_log_repository = AIUsageLogRepository(db)
self.quota_repository = AIQuotaRepository(db)
self.credit_service = CreditService(db) # 注入 Credit Service
# ==================== 枚举验证方法 ====================
def _validate_job_type(self, job_type: int) -> None:
"""验证任务类型枚举值"""
from app.models.ai_job import AIJobType
if job_type not in AIJobType.to_dict().keys():
raise ValidationError(f"无效的任务类型: {job_type}")
def _validate_job_status(self, job_status: int) -> None:
"""验证任务状态枚举值"""
from app.models.ai_job import AIJobStatus
if job_status not in AIJobStatus.to_dict().keys():
raise ValidationError(f"无效的任务状态: {job_status}")
def _validate_provider(self, provider: int) -> None:
"""验证提供商枚举值"""
from app.models.ai_model import AIModelProvider
if provider not in AIModelProvider.to_dict().keys():
raise ValidationError(f"无效的提供商: {provider}")
def _validate_model_type(self, model_type: int) -> None:
"""验证模型类型枚举值"""
from app.models.ai_model import AIModelType
if model_type not in AIModelType.to_dict().keys():
raise ValidationError(f"无效的模型类型: {model_type}")
def _validate_unit_type(self, unit_type: int) -> None:
"""验证单位类型枚举值"""
from app.models.ai_model import AIModelUnitType
if unit_type not in AIModelUnitType.to_dict().keys():
raise ValidationError(f"无效的单位类型: {unit_type}")
# ==================== 私有方法 ====================
async def _check_quota(self, user_id: UUID, quota_type: str) -> bool:
"""检查用户配额是否充足"""
quota = await self.quota_repository.get_user_quota(user_id, quota_type)
if not quota:
return True # 没有配额限制
if quota.used_quota >= quota.total_quota:
from app.core.exceptions import QuotaExceededError
raise QuotaExceededError(f"配额不足:{quota_type}")
return True
def _get_feature_type(self, job_type: int) -> int:
"""将 AIJobType 映射到 Credit Service 的 feature_type
Args:
job_type: AIJobType 枚举值
Returns:
Credit Service 的 feature_type 枚举值(SMALLINT)
"""
from app.models.ai_job import AIJobType
# AIJobType 与 FeatureType 的取值并不一一对应,需要显式映射:
# AIJobType: 1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成
# FeatureType: 1=图片生成 2=视频生成 3=文本处理 4=音频生成 5=音效生成 6=配音生成 7=字幕生成 8=资源生成 9=分镜脚本生成 10=剧本生成
feature_type_map = {
AIJobType.IMAGE: 1, # 图片生成
AIJobType.VIDEO: 2, # 视频生成
AIJobType.TEXT_PROCESSING: 3, # 文本处理
AIJobType.SOUND: 5, # 音效生成
AIJobType.VOICE: 6, # 配音生成
AIJobType.SUBTITLE: 7, # 字幕生成
AIJobType.RESOURCE: 8, # 资源生成
AIJobType.STORYBOARD_SCRIPT: 9, # 分镜脚本生成
AIJobType.SCRIPT_GENERATION: 10 # 剧本生成
}
return feature_type_map.get(job_type, 1) # 默认返回图片生成
async def _get_model(self, model_name: Optional[str], model_type: int) -> Dict[str, Any]:
"""获取模型配置
Args:
model_name: 模型名称
model_type: 模型类型(AIModelType 枚举值)
"""
# 验证模型类型枚举值
self._validate_model_type(model_type)
if model_name:
model = await self.model_repository.get_by_name(model_name)
else:
# 获取默认模型
model = await self.model_repository.get_default_model(model_type)
if not model or not model.is_active:
from app.core.exceptions import NotFoundError
raise NotFoundError(f"模型不可用: {model_name or f'类型 {model_type} 的默认模型'}")
return model
async def _record_usage(
self,
user_id: UUID,
job_id: UUID,
model_id: UUID,
units_used: float,
unit_type: int, # UnitType 枚举值
cost: float,
credits_used: int,
meta_data: Dict[str, Any]
) -> None:
"""记录使用日志"""
# 验证单位类型枚举值
self._validate_unit_type(unit_type)
await self.usage_log_repository.create({
'user_id': user_id,
'ai_job_id': job_id,
'model_id': model_id,
'units_used': units_used,
'unit_type': unit_type,
'cost': cost,
'credits_used': credits_used,
'meta_data': meta_data
})
async def generate_image(
self,
user_id: UUID,
prompt: str,
model: Optional[str] = None,
width: int = 1024,
height: int = 1024,
style: Optional[str] = None,
**kwargs
) -> Dict[str, Any]:
"""生成图片(异步)
应用层引用完整性保证:
1. 验证 user_id 是否存在
2. 验证 model_id 是否存在
3. 使用事务确保原子性
"""
# 1. 验证用户是否存在(应用层引用完整性保证)
from app.repositories.user_repository import UserRepository
user_repo = UserRepository(self.db)
if not await user_repo.exists(user_id):
raise ValidationError("用户不存在")
# 2. 检查配额
await self._check_quota(user_id, 'image_generation')
# 3. 获取模型配置(AIModelType.IMAGE = 2)
from app.models.ai_model import AIModelType
model_config = await self._get_model(model, AIModelType.IMAGE)
# 4. 验证模型是否存在(应用层引用完整性保证)
if not await self.model_repository.exists(model_config.model_id):
raise ValidationError("AI 模型不存在")
# 5. 计算所需积分
credits_needed = model_config.credits_per_unit
# 6. 使用事务确保原子性
async with self.db.begin():
# 预扣积分(同步,事务保证)
try:
consumption_log = await self.credit_service.consume_credits(
user_id=user_id,
amount=credits_needed,
feature_type=1, # FeatureType.IMAGE_GENERATION
ai_job_id=None, # 稍后更新
task_params={
'prompt': prompt,
'model': model_config.model_name,
'width': width,
'height': height,
'style': style,
**kwargs
}
)
except InsufficientCreditsError as e:
raise ValidationError(f"积分不足: {str(e)}")
# 创建任务记录
from app.models.ai_job import AIJobType, AIJobStatus
job = await self.job_repository.create({
'job_type': AIJobType.IMAGE,
'status': AIJobStatus.PENDING,
'user_id': user_id,
'model_id': model_config.model_id,
'model_name': model_config.model_name,
'consumption_log_id': consumption_log.consumption_id,
'input_data': {
'prompt': prompt,
'width': width,
'height': height,
'style': style,
**kwargs
}
})
# 更新 consumption_log 的 ai_job_id
consumption_log.ai_job_id = job.ai_job_id
consumption_log.task_id = str(job.ai_job_id)
# 注意:async with self.db.begin() 上下文退出时会自动提交,这里无需再显式 commit
# 7. 提交异步任务
task = generate_image_task.delay(
job_id=job.ai_job_id,
prompt=prompt,
model=model_config.model_name,
width=width,
height=height,
style=style,
**kwargs
)
# 8. 更新任务 ID
await self.job_repository.update(job.ai_job_id, {'task_id': task.id})
return {
'job_id': str(job.ai_job_id),
'task_id': task.id,
'status': 'pending',
'estimated_cost': float(model_config.cost_per_unit),
'estimated_credits': credits_needed
}
async def generate_video(
self,
user_id: UUID,
video_type: str,
prompt: Optional[str] = None,
image_url: Optional[str] = None,
duration: int = 5,
fps: int = 30,
model: Optional[str] = None,
**kwargs
) -> Dict[str, Any]:
"""生成视频(异步)"""
# 验证参数
if video_type == 'text2video' and not prompt:
raise ValidationError("文本转视频需要提供 prompt")
if video_type == 'img2video' and not image_url:
raise ValidationError("图片转视频需要提供 image_url")
# 检查配额
await self._check_quota(user_id, 'video_generation')
# 获取模型配置(AIModelType.VIDEO = 3)
from app.models.ai_model import AIModelType
model_config = await self._get_model(model, AIModelType.VIDEO)
# 计算所需积分
from app.models.ai_job import AIJobType  # AIJobType 在此处首次使用,需先导入
feature_type = self._get_feature_type(AIJobType.VIDEO)
credits_needed = await self.credit_service.calculate_credits(
feature_type=feature_type,
params={
'model_name': model_config.model_name,
'video_type': video_type,
'duration': duration,
'quality': kwargs.get('quality', 'sd')
}
)
# 预扣积分(同步,事务保证)
try:
consumption_log = await self.credit_service.consume_credits(
user_id=user_id,
amount=credits_needed,
feature_type=feature_type,
ai_job_id=None,
task_params={
'video_type': video_type,
'model': model_config.model_name,
'duration': duration,
'fps': fps,
**kwargs
}
)
except InsufficientCreditsError as e:
raise ValidationError(f"积分不足: {str(e)}")
# 创建任务记录
from app.models.ai_job import AIJobType, AIJobStatus
job = await self.job_repository.create({
'job_type': AIJobType.VIDEO,
'status': AIJobStatus.PENDING,
'user_id': user_id,
'model_id': model_config.model_id,
'model_name': model_config.model_name,
'consumption_log_id': consumption_log.consumption_id,
'credits_used': credits_needed,
'input_data': {
'video_type': video_type,
'prompt': prompt,
'image_url': image_url,
'duration': duration,
'fps': fps,
**kwargs
}
})
# 更新 consumption_log 的 ai_job_id
consumption_log.ai_job_id = job.ai_job_id
consumption_log.task_id = str(job.ai_job_id)
await self.db.commit()
# 提交异步任务
task = generate_video_task.delay(
job_id=job.ai_job_id,
video_type=video_type,
prompt=prompt,
image_url=image_url,
duration=duration,
fps=fps,
**kwargs
)
# 更新任务 ID
await self.job_repository.update(job.ai_job_id, {'task_id': task.id})
return {
'job_id': str(job.ai_job_id),
'task_id': task.id,
'status': 'pending',
'estimated_credits': credits_needed
}
async def generate_sound(
self,
user_id: UUID,
description: str,
duration: int = 5,
sound_type: Optional[str] = None,
model: Optional[str] = None,
**kwargs
) -> Dict[str, Any]:
"""生成音效(异步)"""
# 检查配额
await self._check_quota(user_id, 'sound_generation')
# 获取模型配置(AIModelType.AUDIO = 4)
from app.models.ai_model import AIModelType
model_config = await self._get_model(model, AIModelType.AUDIO)
# 计算所需积分
from app.models.ai_job import AIJobType  # AIJobType 在此处首次使用,需先导入
feature_type = self._get_feature_type(AIJobType.SOUND)
credits_needed = await self.credit_service.calculate_credits(
feature_type=feature_type,
params={
'model_name': model_config.model_name,
'duration': duration
}
)
# 预扣积分(同步,事务保证)
try:
consumption_log = await self.credit_service.consume_credits(
user_id=user_id,
amount=credits_needed,
feature_type=feature_type,
ai_job_id=None,
task_params={
'description': description,
'model': model_config.model_name,
'duration': duration,
'sound_type': sound_type,
**kwargs
}
)
except InsufficientCreditsError as e:
raise ValidationError(f"积分不足: {str(e)}")
# 创建任务记录
from app.models.ai_job import AIJobType, AIJobStatus
job = await self.job_repository.create({
'job_type': AIJobType.SOUND,
'status': AIJobStatus.PENDING,
'user_id': user_id,
'model_id': model_config.model_id,
'model_name': model_config.model_name,
'consumption_log_id': consumption_log.consumption_id,
'credits_used': credits_needed,
'input_data': {
'description': description,
'duration': duration,
'sound_type': sound_type,
**kwargs
}
})
# 更新 consumption_log 的 ai_job_id
consumption_log.ai_job_id = job.ai_job_id
consumption_log.task_id = str(job.ai_job_id)
await self.db.commit()
# 提交异步任务
task = generate_sound_task.delay(
job_id=job.ai_job_id,
description=description,
duration=duration,
sound_type=sound_type,
**kwargs
)
# 更新任务 ID
await self.job_repository.update(job.ai_job_id, {'task_id': task.id})
return {
'job_id': str(job.ai_job_id),
'task_id': task.id,
'status': 'pending',
'estimated_credits': credits_needed
}
async def generate_voice(
self,
user_id: UUID,
text: str,
voice_type: str = 'alloy',
speed: float = 1.0,
language: str = 'zh-CN',
model: Optional[str] = None,
**kwargs
) -> Dict[str, Any]:
"""生成配音(异步)"""
# 检查配额
await self._check_quota(user_id, 'voice_generation')
# 获取模型配置(AIModelType.AUDIO = 4)
from app.models.ai_model import AIModelType
model_config = await self._get_model(model, AIModelType.AUDIO)
# 计算所需积分
from app.models.ai_job import AIJobType  # AIJobType 在此处首次使用,需先导入
feature_type = self._get_feature_type(AIJobType.VOICE)
credits_needed = await self.credit_service.calculate_credits(
feature_type=feature_type,
params={
'model_name': model_config.model_name,
'char_count': len(text)
}
)
# 预扣积分(同步,事务保证)
try:
consumption_log = await self.credit_service.consume_credits(
user_id=user_id,
amount=credits_needed,
feature_type=feature_type,
ai_job_id=None,
task_params={
'text': text,
'model': model_config.model_name,
'voice_type': voice_type,
'speed': speed,
'language': language,
**kwargs
}
)
except InsufficientCreditsError as e:
raise ValidationError(f"积分不足: {str(e)}")
# 创建任务记录
from app.models.ai_job import AIJobType, AIJobStatus
job = await self.job_repository.create({
'job_type': AIJobType.VOICE,
'status': AIJobStatus.PENDING,
'user_id': user_id,
'model_id': model_config.model_id,
'model_name': model_config.model_name,
'consumption_log_id': consumption_log.consumption_id,
'credits_used': credits_needed,
'input_data': {
'text': text,
'voice_type': voice_type,
'speed': speed,
'language': language,
**kwargs
}
})
# 更新 consumption_log 的 ai_job_id
consumption_log.ai_job_id = job.ai_job_id
consumption_log.task_id = str(job.ai_job_id)
await self.db.commit()
# 提交异步任务
task = generate_voice_task.delay(
job_id=job.ai_job_id,
text=text,
voice_type=voice_type,
speed=speed,
language=language,
**kwargs
)
# 更新任务 ID
await self.job_repository.update(job.ai_job_id, {'task_id': task.id})
return {
'job_id': str(job.ai_job_id),
'task_id': task.id,
'status': 'pending',
'estimated_credits': credits_needed
}
async def generate_subtitle(
self,
user_id: UUID,
audio_url: str,
language: str = 'zh',
model: Optional[str] = None,
**kwargs
) -> Dict[str, Any]:
"""生成字幕(异步)"""
# 检查配额
await self._check_quota(user_id, 'subtitle_generation')
# 获取模型配置(AIModelType.AUDIO = 4)
from app.models.ai_model import AIModelType
model_config = await self._get_model(model, AIModelType.AUDIO)
# 计算所需积分(假设音频时长为 60 秒,实际应从音频文件获取)
duration = kwargs.get('duration', 60)
from app.models.ai_job import AIJobType  # AIJobType 在此处首次使用,需先导入
feature_type = self._get_feature_type(AIJobType.SUBTITLE)
credits_needed = await self.credit_service.calculate_credits(
feature_type=feature_type,
params={
'model_name': model_config.model_name,
'duration': duration
}
)
# 预扣积分(同步,事务保证)
try:
consumption_log = await self.credit_service.consume_credits(
user_id=user_id,
amount=credits_needed,
feature_type=feature_type,
ai_job_id=None,
task_params={
'audio_url': audio_url,
'model': model_config.model_name,
'language': language,
'duration': duration,
**kwargs
}
)
except InsufficientCreditsError as e:
raise ValidationError(f"积分不足: {str(e)}")
# 创建任务记录
from app.models.ai_job import AIJobType, AIJobStatus
job = await self.job_repository.create({
'job_type': AIJobType.SUBTITLE,
'status': AIJobStatus.PENDING,
'user_id': user_id,
'model_id': model_config.model_id,
'model_name': model_config.model_name,
'consumption_log_id': consumption_log.consumption_id,
'credits_used': credits_needed,
'input_data': {
'audio_url': audio_url,
'language': language,
'duration': duration,
**kwargs
}
})
# 更新 consumption_log 的 ai_job_id
consumption_log.ai_job_id = job.ai_job_id
consumption_log.task_id = str(job.ai_job_id)
await self.db.commit()
# 提交异步任务
task = generate_subtitle_task.delay(
job_id=job.ai_job_id,
audio_url=audio_url,
language=language,
**kwargs
)
# 更新任务 ID
await self.job_repository.update(job.ai_job_id, {'task_id': task.id})
return {
'job_id': str(job.ai_job_id),
'task_id': task.id,
'status': 'pending',
'estimated_credits': credits_needed
}
async def process_text(
self,
user_id: UUID,
task_type: str,
text: str,
model: Optional[str] = None,
**kwargs
) -> Dict[str, Any]:
"""处理文本(异步)
Args:
user_id: 用户 ID
task_type: 任务类型
- screenplay_parse: 剧本拆解
- content_analysis: 内容分析
- style_transform: 风格转换
- prompt_generation: 提示词生成
text: 待处理的文本
model: 使用的模型(默认使用配置的模型)
**kwargs: 其他参数
- output_format: 输出格式 (json/text)
- temperature: 生成温度 (0.0-2.0)
- max_tokens: 最大 token 数
"""
# 验证任务类型
valid_task_types = [
'screenplay_parse',
'content_analysis',
'style_transform',
'prompt_generation'
]
if task_type not in valid_task_types:
raise ValidationError(f"不支持的任务类型: {task_type}")
# 创建任务记录
from app.models.ai_job import AIJobType, AIJobStatus
job = await self.job_repository.create({
'job_type': AIJobType.TEXT_PROCESSING,
'status': AIJobStatus.PENDING,
'user_id': user_id,
'input_data': {
'task_type': task_type,
'text': text,
'model': model or 'gpt-4',
'output_format': kwargs.get('output_format', 'json'),
'temperature': kwargs.get('temperature', 0.7),
'max_tokens': kwargs.get('max_tokens', 4000),
**kwargs
}
})
# 提交异步任务
from app.tasks.ai_tasks import process_text_task
task = process_text_task.delay(
job_id=job.ai_job_id,
task_type=task_type,
text=text,
model=model or 'gpt-4',
**kwargs
)
return {
'job_id': str(job.ai_job_id),
'task_id': task.id,
'status': 'pending'
}
async def get_job_status(self, job_id: UUID) -> Dict[str, Any]:
"""查询任务状态"""
job = await self.job_repository.get_by_id(job_id)
if not job:
from app.core.exceptions import NotFoundError
raise NotFoundError("任务不存在")
return {
'job_id': str(job.ai_job_id),
'job_type': job.job_type,
'status': job.status,
'progress': job.progress,
'input_data': job.input_data,
'output_data': job.output_data,
'error_message': job.error_message,
'model_name': job.model_name,
'cost': float(job.cost) if job.cost else None,
'credits_used': job.credits_used,
'created_at': job.created_at.isoformat(),
'updated_at': job.updated_at.isoformat(),
'started_at': job.started_at.isoformat() if job.started_at else None,
'completed_at': job.completed_at.isoformat() if job.completed_at else None
}
async def cancel_job(self, user_id: UUID, job_id: UUID) -> None:
"""取消任务"""
job = await self.job_repository.get_by_id(job_id)
if not job:
from app.core.exceptions import NotFoundError
raise NotFoundError("任务不存在")
if job.user_id != user_id:
from app.core.exceptions import PermissionDeniedError
raise PermissionDeniedError("没有权限取消此任务")
from app.models.ai_job import AIJobStatus
if job.status in [AIJobStatus.COMPLETED, AIJobStatus.FAILED]:
raise ValidationError("任务已完成或失败,无法取消")
# 取消 Celery 任务
from app.tasks.celery_app import celery_app
if job.task_id:
celery_app.control.revoke(job.task_id, terminate=True)
# 更新任务状态(AIJobStatus 已在上方导入)
await self.job_repository.update(job_id, {
'status': AIJobStatus.CANCELLED,
'error_message': '用户取消'
})
async def get_user_usage_stats(
self,
user_id: UUID,
start_date: Optional[str] = None,
end_date: Optional[str] = None
) -> Dict[str, Any]:
"""获取用户使用统计"""
stats = await self.usage_log_repository.get_user_stats(
user_id, start_date, end_date
)
quotas = await self.quota_repository.get_user_quotas(user_id)
return {
'total_cost': float(stats.get('total_cost', 0)),
'total_credits_used': stats.get('total_credits_used', 0),
'total_requests': stats.get('total_requests', 0),
'by_model': stats.get('by_model', {}),
'by_type': stats.get('by_type', {}),
'quotas': [
{
'quota_type': q.quota_type,
'period': q.period,
'total_quota': q.total_quota,
'used_quota': q.used_quota,
'remaining_quota': q.total_quota - q.used_quota,
'reset_at': q.reset_at.isoformat()
}
for q in quotas
]
}
async def get_available_models(
self,
model_type: Optional[int] = None
) -> list[Dict[str, Any]]:
"""获取可用的 AI 模型列表
Args:
model_type: 模型类型(AIModelType 枚举值,可选)
"""
models = await self.model_repository.get_active_models(model_type)
return [
{
'model_id': str(m.model_id),
'model_name': m.model_name,
'display_name': m.display_name,
'description': m.description,
'provider': m.provider,
'model_type': m.model_type,
'cost_per_unit': float(m.cost_per_unit),
'unit_type': m.unit_type,
'credits_per_unit': m.credits_per_unit,
'is_beta': m.is_beta,
'config': m.config
}
for m in models
]
API 接口
1. 创建 AI 任务(统一入口)
POST /api/v1/ai/jobs
请求体(图片生成):
{
"jobType": "image",
"prompt": "一只可爱的猫咪在花园里玩耍",
"model": "stable_diffusion",
"width": 1024,
"height": 1024,
"style": "realistic"
}
请求体(视频生成 - 文本转视频):
{
"jobType": "video",
"videoType": "text2video",
"prompt": "一只猫咪在花园里奔跑",
"duration": 5,
"fps": 30
}
请求体(视频生成 - 图片转视频):
{
"jobType": "video",
"videoType": "img2video",
"imageUrl": "https://example.com/image.jpg",
"duration": 5,
"fps": 30
}
请求体(音效生成):
{
"jobType": "sound",
"description": "雨声",
"duration": 10,
"soundType": "ambient"
}
请求体(配音生成):
{
"jobType": "voice",
"text": "欢迎来到Jointo平台",
"voiceType": "alloy",
"speed": 1.0,
"language": "zh-CN"
}
请求体(字幕生成):
{
"jobType": "subtitle",
"audioUrl": "https://example.com/audio.mp3",
"language": "zh"
}
请求体(文本处理 - 剧本拆解):
{
"jobType": "textProcessing",
"taskType": "screenplay_parse",
"text": "场景1:咖啡厅 - 白天\n小明走进咖啡厅,看到小红坐在窗边...",
"model": "gpt-4",
"outputFormat": "json",
"temperature": 0.7
}
响应:
{
"code": 200,
"message": "Success",
"data": {
"jobId": "019d1234-5678-7abc-def0-111111111111",
"taskId": "abc-123-def",
"status": "pending",
"estimatedCredits": 10
}
}
错误响应:
// 400 - 参数错误
{
"code": 400,
"message": "参数验证失败: prompt 不能为空",
"data": null
}
// 402 - 积分不足
{
"code": 402,
"message": "积分不足,当前余额: 5,需要: 10",
"data": null
}
// 429 - 配额超限
{
"code": 429,
"message": "每日配额已用完,请明天再试",
"data": null
}
// 500 - 服务器错误
{
"code": 500,
"message": "AI 服务暂时不可用,请稍后重试",
"data": null
}
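统一入口根据请求体中的 jobType 字段分发到前文各个 generate_* / process_text 方法。下面用一个极简的分发映射示意这一路由逻辑(纯演示代码,映射表与 resolve_handler 函数均为本文档之外的假设,路由框架部分从略):

```python
# jobType -> AIService 方法名 的分发映射(示意)
JOB_TYPE_DISPATCH = {
    "image": "generate_image",
    "video": "generate_video",
    "sound": "generate_sound",
    "voice": "generate_voice",
    "subtitle": "generate_subtitle",
    "textProcessing": "process_text",
}


def resolve_handler(job_type: str) -> str:
    """根据请求体中的 jobType 解析出要调用的 Service 方法名"""
    try:
        return JOB_TYPE_DISPATCH[job_type]
    except KeyError:
        # 未知类型返回 400 参数错误,与上文错误响应一致
        raise ValueError(f"不支持的任务类型: {job_type}")
```

路由层拿到方法名后,将其余字段作为关键字参数转交给对应方法即可。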
2. 查询任务状态
GET /api/v1/ai/jobs/{job_id}
响应:
{
"code": 200,
"message": "Success",
"data": {
"jobId": "019d1234-5678-7abc-def0-111111111111",
"jobType": "image",
"status": "completed",
"progress": 100,
"outputData": {
"fileUrl": "https://storage.jointo.ai/ai-generated/images/abc123def456.png",
"fileSize": 1024000,
"checksum": "abc123def456...",
"mimeType": "image/png",
"meta_data": {
"width": 1024,
"height": 1024,
"format": "png"
}
},
"createdAt": "2026-01-27T10:00:00Z",
"completedAt": "2026-01-27T10:00:30Z"
}
}
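由于任务是异步执行的,调用方通常需要轮询该接口直到任务进入终态(completed/failed/cancelled)。下面是一个带指数退避的轮询示意(假设性示例,未绑定具体 HTTP 客户端;fetch_status 可包装对 GET /api/v1/ai/jobs/{job_id} 的实际请求):

```python
import time

# 终态集合,与本文档的任务状态定义一致
TERMINAL_STATUSES = {"completed", "failed", "cancelled"}


def next_poll_delay(attempt: int, base: float = 1.0, cap: float = 30.0) -> float:
    """指数退避:1s、2s、4s……封顶 30s,避免高频轮询状态接口"""
    return min(base * (2 ** attempt), cap)


def poll_job(fetch_status, max_attempts: int = 10, sleep=time.sleep) -> str:
    """轮询任务状态直到进入终态;fetch_status 为返回 status 字符串的可调用对象"""
    for attempt in range(max_attempts):
        status = fetch_status()
        if status in TERMINAL_STATUSES:
            return status
        sleep(next_poll_delay(attempt))
    return "pending"  # 超过最大尝试次数仍未完成
```

生产环境中也可以用 WebSocket 或回调通知替代轮询,以降低接口压力。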
3. 获取任务列表
GET /api/v1/ai/jobs
查询参数:
- jobType(可选):任务类型(image/video/sound/voice/subtitle/textProcessing)
- status(可选):任务状态(pending/processing/completed/failed/cancelled)
- page(可选):页码(默认 1)
- pageSize(可选):每页数量(默认 20)
响应:
{
"code": 200,
"message": "Success",
"data": {
"items": [
{
"jobId": "019d1234-5678-7abc-def0-111111111111",
"jobType": "image",
"status": "completed",
"createdAt": "2026-01-27T10:00:00Z"
}
],
"total": 100,
"page": 1,
"pageSize": 20,
"totalPages": 5
}
}
4. 取消任务
POST /api/v1/ai/jobs/{job_id}/cancel
响应:
{
"code": 200,
"message": "任务已取消",
"data": null
}
5. 获取用户使用统计
GET /api/v1/ai/usage/stats
查询参数:
- startDate(可选):开始日期(YYYY-MM-DD)
- endDate(可选):结束日期(YYYY-MM-DD)
响应:
{
"code": 200,
"message": "Success",
"data": {
"totalCost": 12.50,
"totalCreditsUsed": 250,
"totalRequests": 45,
"byModel": {
"gpt-4": {
"requests": 20,
"cost": 8.00,
"creditsUsed": 160
},
"dall-e-3": {
"requests": 15,
"cost": 3.50,
"creditsUsed": 75
}
},
"byType": {
"textProcessing": {
"requests": 20,
"cost": 8.00
},
"image": {
"requests": 15,
"cost": 3.50
}
},
"quotas": [
{
"quotaType": "imageGeneration",
"period": "daily",
"totalQuota": 50,
"usedQuota": 15,
"remainingQuota": 35,
"resetAt": "2026-01-28T00:00:00Z"
},
{
"quotaType": "textProcessing",
"period": "monthly",
"totalQuota": 1000,
"usedQuota": 250,
"remainingQuota": 750,
"resetAt": "2026-02-01T00:00:00Z"
}
]
}
}
6. 获取可用模型列表
GET /api/v1/ai/models
查询参数:
type(可选):模型类型(text/image/video/audio)
响应:
{
"code": 200,
"message": "Success",
"data": {
"items": [
{
"modelId": 1,
"modelName": "gpt-4",
"displayName": "GPT-4",
"description": "OpenAI 最强大的语言模型",
"provider": "openai",
"modelType": "text",
"costPerUnit": 0.03,
"unitType": "token",
"creditsPerUnit": 1,
"isBeta": false,
"config": {
"maxTokens": 8000
}
},
{
"modelId": 4,
"modelName": "dall-e-3",
"displayName": "DALL-E 3",
"description": "OpenAI 图片生成模型",
"provider": "openai",
"modelType": "image",
"costPerUnit": 0.04,
"unitType": "image",
"creditsPerUnit": 10,
"isBeta": false,
"config": {
"supportedResolutions": ["1024x1024", "1792x1024", "1024x1792"]
}
}
],
"total": 2,
"page": 1,
"pageSize": 20,
"totalPages": 1
}
}
文件存储集成
对象存储服务
AI 生成的文件需要存储到对象存储服务(OSS),系统通过 FileStorageService 提供统一的存储接口。
存储策略:
- 开发环境:使用 MinIO(本地对象存储)
- 生产环境:可切换到云服务商 OSS
- 阿里云 OSS
- AWS S3
- 腾讯云 COS
- 七牛云 Kodo
配置方式:
通过环境变量 STORAGE_PROVIDER 切换存储服务:
# .env
STORAGE_PROVIDER=minio # 开发环境
# STORAGE_PROVIDER=aliyun # 生产环境(阿里云 OSS)
# STORAGE_PROVIDER=aws # 生产环境(AWS S3)
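基于该环境变量的后端选择逻辑可以用一个简单工厂示意(演示代码;MinioStorage、AliyunOSSStorage 等类名为假设,实际实现应复用 FileStorageService 内部的存储适配层):

```python
import os
from typing import Optional


class MinioStorage:
    """开发环境:本地 MinIO(示意占位类)"""


class AliyunOSSStorage:
    """生产环境:阿里云 OSS(示意占位类)"""


# provider 标识 -> 存储后端实现 的注册表
STORAGE_BACKENDS = {
    "minio": MinioStorage,
    "aliyun": AliyunOSSStorage,
}


def create_storage_backend(provider: Optional[str] = None):
    """优先使用显式传入的 provider,否则读取 STORAGE_PROVIDER 环境变量"""
    name = provider or os.environ.get("STORAGE_PROVIDER", "minio")
    try:
        return STORAGE_BACKENDS[name]()
    except KeyError:
        raise ValueError(f"未知的存储提供商: {name}")
```

新增云服务商时只需向注册表追加一个实现类,调用方代码无需改动。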
文件存储流程
AI 任务完成后,需要将 AI 提供商返回的临时文件下载并存储到自有 OSS:
# AI Task 工作流(集成文件存储)
async def generate_image_task(job_id: str, ...):
try:
# 1. 调用 AI Provider 生成图片
result = await provider.generate_image(...)
# result = {"image_url": "https://ai-provider.com/temp/abc123.png"}
# 2. 下载 AI 生成的文件
async with httpx.AsyncClient() as client:
response = await client.get(result['image_url'])
image_data = response.content
# 3. 上传到 OSS(带去重)
async with async_session_maker() as session:
file_storage = FileStorageService(session)
file_meta_data = await file_storage.upload_file(
file_content=image_data,
filename=f"{job_id}.png",
content_type="image/png",
category="ai-generated/images",
user_id=user_id
)
# 4. 更新任务结果(使用自有 URL)
output_data = {
"file_url": file_meta_data.file_url, # 自有 OSS URL
"file_size": file_meta_data.file_size,
"checksum": file_meta_data.checksum,
"mime_type": file_meta_data.mime_type,
"original_url": result['image_url'], # 保留原始 URL(可选)
"meta_data": {
"width": result.get('width'),
"height": result.get('height'),
"format": "png"
}
}
await _update_job_status(
job_id,
AIJobStatus.COMPLETED,
progress=100,
output_data=output_data
)
# 5. 确认积分消耗
await _confirm_or_refund_credits(...)
except Exception as e:
# 任务失败,退还积分
await _update_job_status(job_id, AIJobStatus.FAILED, error_message=str(e))
await _confirm_or_refund_credits(..., success=False, error_message=str(e))
优势:
- 数据主权:AI 生成的内容存储在自有系统中
- 稳定性:不依赖第三方临时 URL(如 OpenAI 的临时 URL 只有 1 小时有效期)
- 去重优化:通过文件校验和(SHA256)自动去重,节省存储成本
- 灵活切换:开发环境使用 MinIO,生产环境可无缝切换到云服务商 OSS
- 统一管理:所有文件通过 FileStorageService 统一管理,支持引用计数和自动清理
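其中基于 SHA-256 校验和的去重思路可以最小化示意如下(演示代码;FileStorageService 的真实实现以本文档其他章节为准,existing_checksums 字典用于模拟数据库中的校验和索引):

```python
import hashlib


def compute_checksum(file_content: bytes) -> str:
    """计算文件内容的 SHA-256 校验和,作为去重键"""
    return hashlib.sha256(file_content).hexdigest()


def upload_with_dedup(file_content: bytes, existing_checksums: dict) -> str:
    """校验和已存在则复用已有文件(引用计数 +1),否则按新文件登记"""
    checksum = compute_checksum(file_content)
    if checksum in existing_checksums:
        existing_checksums[checksum] += 1  # 命中去重:仅增加引用计数,不重复上传
    else:
        existing_checksums[checksum] = 1   # 新文件:实际上传到 OSS 后登记
    return checksum
```

删除文件时对应地将引用计数减一,归零后才真正从 OSS 清理对象。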
FileStorageService 集成
AI Service 需要注入 FileStorageService 依赖:
# app/services/ai_service.py
from app.services.file_storage_service import FileStorageService
class AIService:
def __init__(self, db: Session):
self.db = db
self.job_repository = AIJobRepository(db)
self.model_repository = AIModelRepository(db)
self.credit_service = CreditService(db)
self.file_storage = FileStorageService(db) # 注入文件存储服务
AI 提供商集成
提供商抽象层
# app/services/ai_providers/base.py
from abc import ABC, abstractmethod
from typing import Dict, Any
class AIProvider(ABC):
@abstractmethod
async def generate_image(self, prompt: str, **kwargs) -> Dict[str, Any]:
"""生成图片"""
pass
@abstractmethod
async def generate_video(self, prompt: str, **kwargs) -> Dict[str, Any]:
"""生成视频"""
pass
Stable Diffusion 提供商
# app/services/ai_providers/stable_diffusion.py
from typing import Dict, Any
import httpx
from app.services.ai_providers.base import AIProvider
class StableDiffusionProvider(AIProvider):
def __init__(self, api_key: str, api_url: str):
self.api_key = api_key
self.api_url = api_url
async def generate_image(self, prompt: str, **kwargs) -> Dict[str, Any]:
"""调用 Stable Diffusion API 生成图片"""
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.api_url}/generate",
json={
"prompt": prompt,
"width": kwargs.get('width', 1024),
"height": kwargs.get('height', 1024),
"steps": kwargs.get('steps', 50),
"cfg_scale": kwargs.get('cfg_scale', 7.5)
},
headers={"Authorization": f"Bearer {self.api_key}"},
timeout=300
)
return response.json()
async def generate_video(self, prompt: str, **kwargs) -> Dict[str, Any]:
"""Stable Diffusion 不支持视频生成"""
raise NotImplementedError("Stable Diffusion 不支持视频生成")
OpenAI 提供商
# app/services/ai_providers/openai.py
from typing import Dict, Any
from openai import AsyncOpenAI  # openai >= 1.0 客户端
from app.services.ai_providers.base import AIProvider
class OpenAIProvider(AIProvider):
def __init__(self, api_key: str):
self.client = AsyncOpenAI(api_key=api_key)
async def generate_image(self, prompt: str, **kwargs) -> Dict[str, Any]:
"""调用 DALL-E API 生成图片"""
response = await self.client.images.generate(
model="dall-e-3",
prompt=prompt,
n=1,
size=f"{kwargs.get('width', 1024)}x{kwargs.get('height', 1024)}"
)
return {
'image_url': response.data[0].url
}
async def generate_video(self, prompt: str, **kwargs) -> Dict[str, Any]:
"""OpenAI 暂不提供视频生成接口"""
raise NotImplementedError("OpenAI 不支持视频生成")
async def generate_voice(self, text: str, **kwargs) -> Dict[str, Any]:
"""调用 OpenAI TTS API 生成配音"""
response = await self.client.audio.speech.create(
model="tts-1",
voice=kwargs.get('voice_type', 'alloy'),
input=text,
speed=kwargs.get('speed', 1.0)
)
# TTS 接口返回音频二进制内容(而非 URL),需上传到 OSS 后再生成访问地址
return {
'audio_content': response.content
}
async def process_text(self, task_type: str, text: str, **kwargs) -> Dict[str, Any]:
"""调用 GPT API 处理文本"""
# 根据任务类型构建系统提示词
system_prompts = {
'screenplay_parse': """你是一个专业的剧本分析助手。请将输入的剧本文本拆解为结构化的场景和镜头信息。
输出 JSON 格式,包含:scenes (场景列表),每个场景包含 scene_number, location, time, description, characters, shots。
每个 shot 包含 shot_number, shot_size, camera_movement, description, duration。""",
'content_analysis': """你是一个内容分析专家。请分析输入文本,提取关键信息。
输出 JSON 格式,包含:characters (人物列表), locations (地点列表), timeline (时间线), themes (主题), emotions (情感)。""",
'style_transform': """你是一个文本风格转换专家。请根据要求调整文本的风格和语气。
保持原意,但改变表达方式。""",
'prompt_generation': """你是一个 AI 绘画提示词专家。请将输入的场景描述转换为详细的 AI 绘画提示词。
包含:主体、环境、光线、色调、风格、镜头角度等要素。输出英文提示词。"""
}
messages = [
{"role": "system", "content": system_prompts.get(task_type, "")},
{"role": "user", "content": text}
]
response = await self.client.chat.completions.create(
model=kwargs.get('model', 'gpt-4'),
messages=messages,
temperature=kwargs.get('temperature', 0.7),
max_tokens=kwargs.get('max_tokens', 4000)
)
content = response.choices[0].message.content
# 如果要求 JSON 格式,尝试解析
if kwargs.get('output_format') == 'json':
import json
try:
content = json.loads(content)
except json.JSONDecodeError:
# 如果解析失败,尝试提取 JSON 部分
import re
json_match = re.search(r'\{.*\}', content, re.DOTALL)
if json_match:
content = json.loads(json_match.group())
return {
'result': content,
'model': response.model,
'usage': {
'prompt_tokens': response.usage.prompt_tokens,
'completion_tokens': response.usage.completion_tokens,
'total_tokens': response.usage.total_tokens
}
}
提供商工厂
# app/services/ai_service_factory.py
from app.services.ai_providers.stable_diffusion import StableDiffusionProvider
from app.services.ai_providers.openai import OpenAIProvider
from app.config import settings
class AIServiceFactory:
@staticmethod
def create_provider(provider_type: str):
"""创建 AI 提供商实例"""
if provider_type == 'stable_diffusion':
return StableDiffusionProvider(
api_key=settings.STABLE_DIFFUSION_API_KEY,
api_url=settings.STABLE_DIFFUSION_API_URL
)
elif provider_type == 'openai':
return OpenAIProvider(
api_key=settings.OPENAI_API_KEY
)
else:
raise ValueError(f"Unknown provider: {provider_type}")
相关文档
附录:Repository 实现示例
基础 Repository(引用完整性保证)
# app/repositories/base_repository.py
from typing import Optional, TypeVar, Generic
from uuid import UUID
from sqlalchemy import inspect as sa_inspect
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
T = TypeVar('T')
class BaseRepository(Generic[T]):
"""基础 Repository - 提供引用完整性验证"""
def __init__(self, db: AsyncSession, model: type[T]):
self.db = db
self.model = model
def _pk_column(self):
"""获取模型的主键列(各模型主键名不同,如 ai_job_id、model_id)"""
return sa_inspect(self.model).primary_key[0]
async def exists(self, id: UUID) -> bool:
"""检查记录是否存在(排除软删除)
这是应用层保证引用完整性的核心方法
"""
pk = self._pk_column()
query = select(pk).where(pk == id).limit(1)
# 如果模型有 deleted_at 字段,排除软删除记录
if hasattr(self.model, 'deleted_at'):
query = query.where(self.model.deleted_at.is_(None))
result = await self.db.execute(query)
return result.scalar_one_or_none() is not None
async def get_by_id(self, id: UUID) -> Optional[T]:
"""根据 ID 获取记录"""
pk = self._pk_column()
query = select(self.model).where(pk == id)
if hasattr(self.model, 'deleted_at'):
query = query.where(self.model.deleted_at.is_(None))
result = await self.db.execute(query)
return result.scalar_one_or_none()
AIJobRepository
# app/repositories/ai_job_repository.py
from typing import List, Optional
from uuid import UUID
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.ai_job import AIJob, AIJobType, AIJobStatus
from app.repositories.base_repository import BaseRepository
class AIJobRepository(BaseRepository[AIJob]):
"""AI 任务数据访问层"""
def __init__(self, db: AsyncSession):
super().__init__(db, AIJob)
async def create(self, job_data: dict) -> AIJob:
"""创建 AI 任务"""
job = AIJob(**job_data)
self.db.add(job)
await self.db.commit()
await self.db.refresh(job)
return job
async def update(self, job_id: UUID, update_data: dict) -> AIJob:
"""更新 AI 任务"""
job = await self.get_by_id(job_id)
if not job:
raise ValueError("任务不存在")
for key, value in update_data.items():
setattr(job, key, value)
await self.db.commit()
await self.db.refresh(job)
return job
async def get_user_jobs(
self,
user_id: UUID,
job_type: Optional[int] = None,
status: Optional[int] = None,
page: int = 1,
page_size: int = 20
) -> List[AIJob]:
"""获取用户的 AI 任务列表"""
query = select(AIJob).where(AIJob.user_id == user_id)
if job_type:
query = query.where(AIJob.job_type == job_type)
if status:
query = query.where(AIJob.status == status)
query = query.order_by(AIJob.created_at.desc())
query = query.offset((page - 1) * page_size).limit(page_size)
result = await self.db.execute(query)
return result.scalars().all()
AIModelRepository
# app/repositories/ai_model_repository.py
from typing import Optional, List
from uuid import UUID
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.ai_model import AIModel, AIModelType
from app.repositories.base_repository import BaseRepository
from app.core.exceptions import ValidationError
class AIModelRepository(BaseRepository[AIModel]):
"""AI 模型数据访问层"""
def __init__(self, db: AsyncSession):
super().__init__(db, AIModel)
async def get_by_name(self, model_name: str) -> Optional[AIModel]:
"""根据模型名称获取模型"""
query = select(AIModel).where(AIModel.model_name == model_name)
result = await self.db.execute(query)
return result.scalar_one_or_none()
async def get_default_model(self, model_type: int) -> Optional[AIModel]:
"""获取指定类型的默认模型(最便宜的活跃模型)
Args:
model_type: 模型类型(AIModelType 枚举值)
"""
# 验证模型类型枚举值
if model_type not in AIModelType.__members__.values():
raise ValidationError(f"无效的模型类型: {model_type}")
query = select(AIModel).where(
AIModel.model_type == model_type,
AIModel.is_active == True
).order_by(AIModel.cost_per_unit.asc())
result = await self.db.execute(query)
return result.scalar_one_or_none()
async def get_active_models(
self,
model_type: Optional[int] = None
) -> List[AIModel]:
"""获取所有活跃模型
Args:
model_type: 模型类型(AIModelType 枚举值,可选)
"""
# 如果指定了模型类型,验证枚举值
if model_type is not None and model_type not in AIModelType.__members__.values():
raise ValidationError(f"无效的模型类型: {model_type}")
query = select(AIModel).where(AIModel.is_active == True)
if model_type:
query = query.where(AIModel.model_type == model_type)
query = query.order_by(AIModel.model_type, AIModel.cost_per_unit)
result = await self.db.execute(query)
return result.scalars().all()
AIQuotaRepository
# app/repositories/ai_quota_repository.py
from typing import Optional, List
from uuid import UUID
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from app.models.ai_quota import AIQuota, QuotaPeriod
from app.repositories.base_repository import BaseRepository
from datetime import datetime, timedelta, timezone
class AIQuotaRepository(BaseRepository[AIQuota]):
"""AI 配额数据访问层"""
def __init__(self, db: AsyncSession):
super().__init__(db, AIQuota)
async def get_user_quota(
self,
user_id: UUID,
quota_type: str,
period: int = QuotaPeriod.DAILY
) -> Optional[AIQuota]:
"""获取用户指定类型的配额
Args:
user_id: 用户 ID
quota_type: 配额类型
period: 配额周期(QuotaPeriod 枚举值,默认 1=DAILY)
"""
query = select(AIQuota).where(
AIQuota.user_id == user_id,
AIQuota.quota_type == quota_type,
AIQuota.period == period
)
result = await self.db.execute(query)
return result.scalar_one_or_none()
async def get_user_quotas(self, user_id: UUID) -> List[AIQuota]:
"""获取用户所有配额"""
query = select(AIQuota).where(AIQuota.user_id == user_id)
result = await self.db.execute(query)
return result.scalars().all()
async def increment_usage(
self,
user_id: UUID,
quota_type: str,
amount: int = 1
) -> None:
"""增加配额使用量"""
# 更新每日配额
daily_quota = await self.get_user_quota(user_id, quota_type, QuotaPeriod.DAILY)
if daily_quota:
daily_quota.used_quota += amount
await self.db.commit()
# 更新每月配额
monthly_quota = await self.get_user_quota(user_id, quota_type, QuotaPeriod.MONTHLY)
if monthly_quota:
monthly_quota.used_quota += amount
await self.db.commit()
async def create_default_quotas(self, user_id: UUID) -> None:
"""为新用户创建默认配额"""
default_quotas = [
{
'user_id': user_id,
'quota_type': 'image_generation',
'period': QuotaPeriod.DAILY,
'total_quota': 50,
'reset_at': datetime.now(timezone.utc) + timedelta(days=1)
},
{
'user_id': user_id,
'quota_type': 'image_generation',
'period': QuotaPeriod.MONTHLY,
'total_quota': 1000,
'reset_at': datetime.now(timezone.utc) + timedelta(days=30)
},
{
'user_id': user_id,
'quota_type': 'text_processing',
'period': QuotaPeriod.DAILY,
'total_quota': 100,
'reset_at': datetime.now(timezone.utc) + timedelta(days=1)
},
{
'user_id': user_id,
'quota_type': 'text_processing',
'period': QuotaPeriod.MONTHLY,
'total_quota': 2000,
'reset_at': datetime.now(timezone.utc) + timedelta(days=30)
}
]
for quota_data in default_quotas:
quota = AIQuota(**quota_data)
self.db.add(quota)
await self.db.commit()
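配合上述 Repository,Service 层在发起任务前的配额检查可以草拟如下(纯逻辑示意;`QuotaCheck` 与 `check_quota` 为假设的辅助结构,并非仓库中已有实现):

```python
from dataclasses import dataclass


@dataclass
class QuotaCheck:
    """配额检查结果:是否允许本次消耗,以及剩余额度(假设的辅助结构)"""
    allowed: bool
    remaining: int


def check_quota(total_quota: int, used_quota: int, amount: int = 1) -> QuotaCheck:
    """应用层配额检查:used + amount 不得超过 total(与表上的 CHECK 约束语义一致)"""
    remaining = max(total_quota - used_quota, 0)
    return QuotaCheck(allowed=amount <= remaining, remaining=remaining)
```

Service 层可在每日、每月两个周期的配额上分别调用该检查,任一不通过即拒绝任务。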
附录:Python 枚举定义
app/models/ai_job.py
from enum import IntEnum
from sqlmodel import SQLModel, Field
from sqlalchemy import Column, SmallInteger, Index, text
from sqlalchemy.dialects.postgresql import UUID as PG_UUID, JSONB
from datetime import datetime, timezone
from uuid import UUID
from typing import Optional
from app.utils.id_generator import generate_uuid
class AIJobType(IntEnum):
"""AI 任务类型"""
IMAGE = 1 # 图片生成
VIDEO = 2 # 视频生成
SOUND = 3 # 音效生成
VOICE = 4 # 配音生成
SUBTITLE = 5 # 字幕生成
TEXT_PROCESSING = 6 # 文本处理(剧本拆解等)
RESOURCE = 7 # 资源生成
STORYBOARD_SCRIPT = 8 # 分镜脚本生成
SCRIPT_GENERATION = 9 # 剧本生成
class AIJobStatus(IntEnum):
"""AI 任务状态"""
PENDING = 1 # 等待处理
PROCESSING = 2 # 处理中
COMPLETED = 3 # 已完成
FAILED = 4 # 失败
CANCELLED = 5 # 已取消
class AIJob(SQLModel, table=True):
"""AI 任务表 - 不使用外键约束,应用层保证引用完整性"""
__tablename__ = "ai_jobs"
# 主键
ai_job_id: UUID = Field(
sa_column=Column(
PG_UUID(as_uuid=True),
primary_key=True,
default=generate_uuid
)
)
# 关联字段(无外键约束,仅索引)
user_id: UUID = Field(
sa_column=Column(PG_UUID(as_uuid=True), nullable=False, index=True),
description="用户 ID - 应用层验证"
)
project_id: Optional[UUID] = Field(
default=None,
sa_column=Column(PG_UUID(as_uuid=True), nullable=True, index=True),
description="项目 ID - 应用层验证"
)
storyboard_id: Optional[UUID] = Field(
default=None,
sa_column=Column(PG_UUID(as_uuid=True), nullable=True, index=True),
description="分镜 ID - 应用层验证"
)
consumption_log_id: Optional[UUID] = Field(
default=None,
sa_column=Column(PG_UUID(as_uuid=True), nullable=True, index=True),
description="积分消耗日志 ID - 应用层验证"
)
model_id: Optional[UUID] = Field(
default=None,
sa_column=Column(PG_UUID(as_uuid=True), nullable=True, index=True),
description="AI 模型 ID - 应用层验证"
)
# 任务信息(使用 SMALLINT 存储枚举)
job_type: int = Field(
sa_column=Column(SmallInteger, nullable=False),
description="任务类型:1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成"
)
status: int = Field(
default=AIJobStatus.PENDING,
sa_column=Column(SmallInteger, nullable=False, default=1, index=True),
description="任务状态:1=等待处理 2=处理中 3=已完成 4=失败 5=已取消"
)
# JSONB 字段
input_data: dict = Field(
default_factory=dict,
sa_column=Column(JSONB, nullable=False, default={}),
description="输入参数"
)
output_data: Optional[dict] = Field(
default=None,
sa_column=Column(JSONB, nullable=True),
description="输出结果"
)
# 其他字段
model_name: Optional[str] = Field(default=None, max_length=255)
progress: int = Field(default=0, ge=0, le=100)
error_message: Optional[str] = Field(default=None)
task_id: Optional[str] = Field(default=None)
cost: Optional[float] = Field(default=None)
credits_used: int = Field(default=0)
# 时间戳
estimated_completion_at: Optional[datetime] = Field(default=None)
started_at: Optional[datetime] = Field(default=None)
completed_at: Optional[datetime] = Field(default=None)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
# 表级索引
__table_args__ = (
Index('idx_ai_jobs_status_created_at', 'status', 'created_at',
postgresql_where=text('status IN (1, 2)')),
Index('idx_ai_jobs_input_data_gin', 'input_data', postgresql_using='gin'),
Index('idx_ai_jobs_output_data_gin', 'output_data',
postgresql_using='gin',
postgresql_where=text('output_data IS NOT NULL')),
)
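状态字段以 SMALLINT 存储,读取后可还原为 `AIJobStatus` 枚举;下面是一个合法状态流转校验的示意(流转规则为笔者根据枚举语义假设,并非源码中已有实现):

```python
from enum import IntEnum


class AIJobStatus(IntEnum):
    PENDING = 1
    PROCESSING = 2
    COMPLETED = 3
    FAILED = 4
    CANCELLED = 5


# 假设的合法流转表:等待 -> 处理中/取消;处理中 -> 完成/失败/取消;终态不可再变更
ALLOWED_TRANSITIONS = {
    AIJobStatus.PENDING: {AIJobStatus.PROCESSING, AIJobStatus.CANCELLED},
    AIJobStatus.PROCESSING: {AIJobStatus.COMPLETED, AIJobStatus.FAILED, AIJobStatus.CANCELLED},
}


def can_transition(current: int, target: int) -> bool:
    """校验数据库中的 SMALLINT 状态值是否允许流转到目标状态"""
    return AIJobStatus(target) in ALLOWED_TRANSITIONS.get(AIJobStatus(current), set())
```

Service 层更新任务状态前先做此校验,可避免把已完成或已取消的任务改回处理中。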
app/models/ai_model.py
from enum import IntEnum
class AIModelType(IntEnum):
"""AI 模型类型"""
TEXT = 1 # 文本模型(GPT, Claude 等)
IMAGE = 2 # 图片模型(DALL-E, Stable Diffusion 等)
VIDEO = 3 # 视频模型(Runway, Pika 等)
AUDIO = 4 # 音频模型(TTS, STT 等)
class AIProvider(IntEnum):
"""AI 提供商"""
OPENAI = 1
ANTHROPIC = 2
GOOGLE = 3 # Google Gemini
STABILITY = 4
RUNWAY = 5
PIKA = 6
ELEVENLABS = 7
AZURE = 8
BAIDU = 9
ALIYUN = 10
CUSTOM = 99 # 自定义提供商
class UnitType(IntEnum):
"""计费单位类型"""
TOKEN = 1 # Token(文本模型)
IMAGE = 2 # 图片(图片模型)
SECOND = 3 # 秒(视频/音频模型)
REQUEST = 4 # 请求(通用计费单位)
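结合 `UnitType` 与模型配置中的 `cost_per_unit` / `credits_per_unit`,单次调用的成本与积分可以这样估算(示意;字段名取自上文模型定义,取整策略为假设):

```python
import math
from decimal import Decimal


def estimate_usage(
    units_used: Decimal,
    cost_per_unit: Decimal,
    credits_per_unit: int,
) -> tuple[Decimal, int]:
    """按计费单位估算:cost = units * cost_per_unit,积分按使用量向上取整"""
    cost = units_used * cost_per_unit
    credits = math.ceil(units_used * credits_per_unit)
    return cost, credits
```

与表结构保持一致地使用 `Decimal` 而非 `float`,可避免成本金额的浮点误差。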
app/models/ai_quota.py
from enum import IntEnum
class QuotaPeriod(IntEnum):
"""配额周期"""
DAILY = 1 # 每日配额
MONTHLY = 2 # 每月配额
TOTAL = 3 # 总配额
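`QuotaPeriod` 决定 `reset_at` 的推进方式,与迁移脚本中 `reset_expired_quotas` 函数的语义一致(每日 +1 天、每月 +1 个自然月);Python 侧的等价计算可草拟如下(示意,按自然月推进而非固定 30 天,月末日期自动收敛):

```python
import calendar
from datetime import datetime, timedelta

DAILY, MONTHLY = 1, 2  # 对应 QuotaPeriod.DAILY / QuotaPeriod.MONTHLY


def next_reset_at(current: datetime, period: int) -> datetime:
    """计算下一次配额重置时间:每日 +1 天;每月 +1 个自然月(日期超出月末时取月末)"""
    if period == DAILY:
        return current + timedelta(days=1)
    if period == MONTHLY:
        year = current.year + (current.month // 12)
        month = current.month % 12 + 1
        day = min(current.day, calendar.monthrange(year, month)[1])
        return current.replace(year=year, month=month, day=day)
    raise ValueError(f"不支持的配额周期: {period}")
```

注意 `create_default_quotas` 中每月配额使用了 `timedelta(days=30)` 的简化写法,与 SQL 函数的 `INTERVAL '1 month'` 略有差异,落地时应统一其中一种语义。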
测试规范
单元测试
# server/tests/unit/test_ai_service.py
import pytest
from uuid import UUID
from unittest.mock import AsyncMock, MagicMock, patch
from app.services.ai_service import AIService
from app.models.ai_job import AIJobType, AIJobStatus
from app.models.ai_model import AIModelType
from app.core.exceptions import ValidationError, InsufficientCreditsError, NotFoundError
@pytest.fixture
def mock_db():
"""模拟数据库会话"""
db = AsyncMock()
db.begin = AsyncMock()
db.commit = AsyncMock()
db.rollback = AsyncMock()
return db
@pytest.fixture
def ai_service(mock_db):
"""创建 AI Service 实例"""
return AIService(mock_db)
class TestAIService:
"""AI Service 单元测试"""
@pytest.mark.asyncio
async def test_generate_image_success(self, ai_service, mock_db):
"""测试图片生成成功"""
# 准备测试数据
user_id = UUID('019d1234-5678-7abc-def0-111111111111')
prompt = "一只可爱的猫咪"
# Mock 依赖
with patch.object(ai_service, '_check_quota', return_value=True), \
patch.object(ai_service, '_get_model') as mock_get_model, \
patch.object(ai_service.job_repository, 'create') as mock_create_job, \
patch.object(ai_service.credit_service, 'consume_credits') as mock_consume:
# 设置 mock 返回值
mock_model = MagicMock()
mock_model.model_id = UUID('019d1234-5678-7abc-def0-222222222222')
mock_model.model_name = 'stable_diffusion'
mock_model.credits_per_unit = 10
mock_model.cost_per_unit = 0.5
mock_get_model.return_value = mock_model
mock_consumption = MagicMock()
mock_consumption.consumption_id = UUID('019d1234-5678-7abc-def0-333333333333')
mock_consume.return_value = mock_consumption
mock_job = MagicMock()
mock_job.ai_job_id = UUID('019d1234-5678-7abc-def0-444444444444')
mock_create_job.return_value = mock_job
# 执行测试
result = await ai_service.generate_image(
user_id=user_id,
prompt=prompt,
width=1024,
height=1024
)
# 验证结果
assert 'job_id' in result
assert 'task_id' in result
assert result['status'] == 'pending'
assert result['estimated_credits'] == 10
# 验证调用
mock_consume.assert_called_once()
mock_create_job.assert_called_once()
@pytest.mark.asyncio
async def test_generate_image_insufficient_credits(self, ai_service):
"""测试积分不足"""
user_id = UUID('019d1234-5678-7abc-def0-111111111111')
with patch.object(ai_service, '_check_quota', return_value=True), \
patch.object(ai_service, '_get_model') as mock_get_model, \
patch.object(ai_service.credit_service, 'consume_credits') as mock_consume:
mock_model = MagicMock()
mock_model.credits_per_unit = 10
mock_get_model.return_value = mock_model
# 模拟积分不足
mock_consume.side_effect = InsufficientCreditsError("积分不足")
            # 验证抛出异常(Service 透传 InsufficientCreditsError)
            with pytest.raises(InsufficientCreditsError, match="积分不足"):
await ai_service.generate_image(
user_id=user_id,
prompt="test"
)
@pytest.mark.asyncio
async def test_get_job_status_not_found(self, ai_service):
"""测试查询不存在的任务"""
job_id = UUID('019d1234-5678-7abc-def0-111111111111')
with patch.object(ai_service.job_repository, 'get_by_id', return_value=None):
with pytest.raises(NotFoundError, match="任务不存在"):
await ai_service.get_job_status(job_id)
@pytest.mark.asyncio
async def test_cancel_job_success(self, ai_service):
"""测试取消任务成功"""
user_id = UUID('019d1234-5678-7abc-def0-111111111111')
job_id = UUID('019d1234-5678-7abc-def0-222222222222')
mock_job = MagicMock()
mock_job.user_id = user_id
mock_job.status = AIJobStatus.PENDING
mock_job.task_id = 'celery-task-123'
with patch.object(ai_service.job_repository, 'get_by_id', return_value=mock_job), \
patch.object(ai_service.job_repository, 'update') as mock_update, \
patch('app.tasks.celery_app.celery_app.control.revoke') as mock_revoke:
await ai_service.cancel_job(user_id, job_id)
mock_revoke.assert_called_once_with('celery-task-123', terminate=True)
mock_update.assert_called_once()
集成测试
# server/tests/integration/test_ai_api.py
import pytest
from httpx import AsyncClient
from uuid import UUID
from app.main import app
from app.core.database import get_db
from app.models.ai_job import AIJobType, AIJobStatus
@pytest.fixture
async def client():
"""创建测试客户端"""
async with AsyncClient(app=app, base_url="http://test") as ac:
yield ac
@pytest.fixture
async def auth_headers(client):
"""获取认证 token"""
# 登录获取 token
response = await client.post("/api/v1/auth/login", json={
"email": "test@example.com",
"password": "testpass123"
})
token = response.json()["data"]["accessToken"]
return {"Authorization": f"Bearer {token}"}
class TestAIAPI:
"""AI API 集成测试"""
@pytest.mark.asyncio
async def test_create_image_job(self, client, auth_headers):
"""测试创建图片生成任务"""
response = await client.post(
"/api/v1/ai/jobs",
headers=auth_headers,
json={
"jobType": "image",
"prompt": "一只可爱的猫咪",
"width": 1024,
"height": 1024
}
)
assert response.status_code == 200
data = response.json()
assert data["code"] == 200
assert "jobId" in data["data"]
assert "taskId" in data["data"]
assert data["data"]["status"] == "pending"
@pytest.mark.asyncio
async def test_create_job_invalid_params(self, client, auth_headers):
"""测试无效参数"""
response = await client.post(
"/api/v1/ai/jobs",
headers=auth_headers,
json={
"jobType": "image",
"prompt": "", # 空提示词
"width": 1024
}
)
assert response.status_code == 400
data = response.json()
assert data["code"] == 400
assert "prompt" in data["message"].lower()
@pytest.mark.asyncio
async def test_get_job_status(self, client, auth_headers):
"""测试查询任务状态"""
# 先创建任务
create_response = await client.post(
"/api/v1/ai/jobs",
headers=auth_headers,
json={
"jobType": "image",
"prompt": "test"
}
)
job_id = create_response.json()["data"]["jobId"]
# 查询状态
response = await client.get(
f"/api/v1/ai/jobs/{job_id}",
headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert data["code"] == 200
assert data["data"]["jobId"] == job_id
assert "status" in data["data"]
@pytest.mark.asyncio
async def test_cancel_job(self, client, auth_headers):
"""测试取消任务"""
# 先创建任务
create_response = await client.post(
"/api/v1/ai/jobs",
headers=auth_headers,
json={
"jobType": "image",
"prompt": "test"
}
)
job_id = create_response.json()["data"]["jobId"]
# 取消任务
response = await client.post(
f"/api/v1/ai/jobs/{job_id}/cancel",
headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert data["code"] == 200
assert "已取消" in data["message"]
@pytest.mark.asyncio
async def test_get_usage_stats(self, client, auth_headers):
"""测试获取使用统计"""
response = await client.get(
"/api/v1/ai/usage/stats",
headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert data["code"] == 200
assert "totalCost" in data["data"]
assert "totalCreditsUsed" in data["data"]
assert "quotas" in data["data"]
@pytest.mark.asyncio
async def test_get_available_models(self, client, auth_headers):
"""测试获取可用模型列表"""
response = await client.get(
"/api/v1/ai/models",
headers=auth_headers
)
assert response.status_code == 200
data = response.json()
assert data["code"] == 200
assert "items" in data["data"]
assert len(data["data"]["items"]) > 0
测试运行
# 运行所有 AI Service 测试
docker exec jointo-server-app pytest tests/unit/test_ai_service.py -v
# 运行集成测试
docker exec jointo-server-app pytest tests/integration/test_ai_api.py -v
# 运行所有测试并生成覆盖率报告
docker exec jointo-server-app pytest tests/ --cov=app/services/ai_service --cov-report=html
# 运行特定测试
docker exec jointo-server-app pytest tests/unit/test_ai_service.py::TestAIService::test_generate_image_success -v
数据库迁移脚本
创建 AI 服务表
# server/alembic/versions/20260129_1700_create_ai_service_tables.py
"""create ai service tables
Revision ID: abc123def456
Revises: 3a3a2a1417de
Create Date: 2026-01-29 17:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'abc123def456'
down_revision = '3a3a2a1417de'
branch_labels = None
depends_on = None
def upgrade() -> None:
"""创建 AI 服务相关表(无外键约束,应用层保证引用完整性)"""
# 1. 创建 ai_models 表
op.create_table(
'ai_models',
sa.Column('model_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('model_name', sa.Text(), nullable=False),
sa.Column('display_name', sa.Text(), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('provider', sa.SmallInteger(), nullable=False),
sa.Column('model_type', sa.SmallInteger(), nullable=False),
sa.Column('cost_per_unit', sa.Numeric(10, 4), nullable=False),
sa.Column('unit_type', sa.SmallInteger(), nullable=False),
sa.Column('credits_per_unit', sa.Integer(), nullable=False),
sa.Column('config', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}'),
sa.Column('rate_limit', sa.Integer(), nullable=True),
sa.Column('daily_quota', sa.Integer(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='true'),
sa.Column('is_beta', sa.Boolean(), nullable=False, server_default='false'),
sa.Column('created_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.Column('updated_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.PrimaryKeyConstraint('model_id'),
sa.UniqueConstraint('model_name'),
comment='AI 模型配置表'
)
# 索引
op.create_index('idx_ai_models_provider', 'ai_models', ['provider'],
postgresql_where=sa.text('is_active = true'))
op.create_index('idx_ai_models_type', 'ai_models', ['model_type'],
postgresql_where=sa.text('is_active = true'))
op.create_index('idx_ai_models_is_active', 'ai_models', ['is_active'])
op.create_index('idx_ai_models_config_gin', 'ai_models', ['config'],
postgresql_using='gin')
# 触发器
op.execute("""
CREATE TRIGGER update_ai_models_updated_at
BEFORE UPDATE ON ai_models
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
""")
# 2. 创建 ai_jobs 表(无外键约束)
op.create_table(
'ai_jobs',
sa.Column('ai_job_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('project_id', postgresql.UUID(as_uuid=True), nullable=True),
sa.Column('storyboard_id', postgresql.UUID(as_uuid=True), nullable=True),
sa.Column('consumption_log_id', postgresql.UUID(as_uuid=True), nullable=True),
sa.Column('job_type', sa.SmallInteger(), nullable=False),
sa.Column('status', sa.SmallInteger(), nullable=False, server_default='1'),
sa.Column('input_data', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}'),
sa.Column('output_data', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('model_id', postgresql.UUID(as_uuid=True), nullable=True),
sa.Column('model_name', sa.Text(), nullable=True),
sa.Column('progress', sa.Integer(), nullable=False, server_default='0'),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('task_id', sa.Text(), nullable=True),
sa.Column('estimated_completion_at', sa.TIMESTAMP(timezone=True), nullable=True),
sa.Column('started_at', sa.TIMESTAMP(timezone=True), nullable=True),
sa.Column('completed_at', sa.TIMESTAMP(timezone=True), nullable=True),
sa.Column('created_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.Column('updated_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.Column('cost', sa.Numeric(10, 4), nullable=True),
sa.Column('credits_used', sa.Integer(), nullable=False, server_default='0'),
sa.PrimaryKeyConstraint('ai_job_id'),
sa.CheckConstraint('progress >= 0 AND progress <= 100', name='ai_jobs_progress_check'),
comment='AI 任务表 - 应用层保证引用完整性'
)
# 索引(关联字段必须有索引)
op.create_index('idx_ai_jobs_user_id', 'ai_jobs', ['user_id'])
op.create_index('idx_ai_jobs_project_id', 'ai_jobs', ['project_id'],
postgresql_where=sa.text('project_id IS NOT NULL'))
op.create_index('idx_ai_jobs_storyboard_id', 'ai_jobs', ['storyboard_id'],
postgresql_where=sa.text('storyboard_id IS NOT NULL'))
op.create_index('idx_ai_jobs_type', 'ai_jobs', ['job_type'])
op.create_index('idx_ai_jobs_status', 'ai_jobs', ['status'])
op.create_index('idx_ai_jobs_model_id', 'ai_jobs', ['model_id'],
postgresql_where=sa.text('model_id IS NOT NULL'))
op.create_index('idx_ai_jobs_consumption_log_id', 'ai_jobs', ['consumption_log_id'],
postgresql_where=sa.text('consumption_log_id IS NOT NULL'))
op.create_index('idx_ai_jobs_created_at', 'ai_jobs', ['created_at'])
op.create_index('idx_ai_jobs_status_created_at', 'ai_jobs', ['status', 'created_at'],
postgresql_where=sa.text('status IN (1, 2)'))
op.create_index('idx_ai_jobs_input_data_gin', 'ai_jobs', ['input_data'],
postgresql_using='gin')
op.create_index('idx_ai_jobs_output_data_gin', 'ai_jobs', ['output_data'],
postgresql_using='gin',
postgresql_where=sa.text('output_data IS NOT NULL'))
# 触发器
op.execute("""
CREATE TRIGGER update_ai_jobs_updated_at
BEFORE UPDATE ON ai_jobs
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
""")
# 字段注释
op.execute("""
COMMENT ON COLUMN ai_jobs.user_id IS '用户 ID - 应用层验证';
COMMENT ON COLUMN ai_jobs.project_id IS '项目 ID - 应用层验证';
COMMENT ON COLUMN ai_jobs.storyboard_id IS '分镜 ID - 应用层验证';
COMMENT ON COLUMN ai_jobs.consumption_log_id IS '积分消耗日志 ID - 应用层验证';
COMMENT ON COLUMN ai_jobs.model_id IS 'AI 模型 ID - 应用层验证';
COMMENT ON COLUMN ai_jobs.job_type IS '任务类型(1=图片 2=视频 3=音效 4=配音 5=字幕 6=文本处理 7=资源 8=分镜脚本 9=剧本生成)';
COMMENT ON COLUMN ai_jobs.status IS '任务状态(1=等待处理 2=处理中 3=已完成 4=失败 5=已取消)';
""")
# 3. 创建 ai_usage_logs 表(无外键约束)
op.create_table(
'ai_usage_logs',
sa.Column('usage_log_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('ai_job_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('model_id', postgresql.UUID(as_uuid=True), nullable=True),
sa.Column('units_used', sa.Numeric(10, 2), nullable=False),
sa.Column('unit_type', sa.SmallInteger(), nullable=False),
sa.Column('cost', sa.Numeric(10, 4), nullable=False),
sa.Column('credits_used', sa.Integer(), nullable=False),
sa.Column('meta_data', postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default='{}'),
sa.Column('created_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.PrimaryKeyConstraint('usage_log_id'),
sa.CheckConstraint('unit_type BETWEEN 1 AND 4', name='ai_usage_logs_unit_type_check'),
comment='AI 使用日志表 - 应用层保证引用完整性'
)
# 索引
op.create_index('idx_ai_usage_logs_user_id', 'ai_usage_logs', ['user_id'])
op.create_index('idx_ai_usage_logs_ai_job_id', 'ai_usage_logs', ['ai_job_id'])
op.create_index('idx_ai_usage_logs_model_id', 'ai_usage_logs', ['model_id'],
postgresql_where=sa.text('model_id IS NOT NULL'))
op.create_index('idx_ai_usage_logs_created_at', 'ai_usage_logs', ['created_at'])
op.create_index('idx_ai_usage_logs_user_created', 'ai_usage_logs', ['user_id', 'created_at'])
op.create_index('idx_ai_usage_logs_meta_data_gin', 'ai_usage_logs', ['meta_data'],
postgresql_using='gin')
# 字段注释
op.execute("""
COMMENT ON COLUMN ai_usage_logs.user_id IS '用户 ID - 应用层验证';
COMMENT ON COLUMN ai_usage_logs.ai_job_id IS 'AI 任务 ID - 应用层验证';
COMMENT ON COLUMN ai_usage_logs.model_id IS 'AI 模型 ID - 应用层验证';
COMMENT ON COLUMN ai_usage_logs.unit_type IS '单位类型(1=Token 2=图片 3=秒 4=请求)';
""")
# 4. 创建 ai_quotas 表(无外键约束)
op.create_table(
'ai_quotas',
sa.Column('quota_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('quota_type', sa.Text(), nullable=False),
sa.Column('period', sa.SmallInteger(), nullable=False),
sa.Column('total_quota', sa.Integer(), nullable=False),
sa.Column('used_quota', sa.Integer(), nullable=False, server_default='0'),
sa.Column('reset_at', sa.TIMESTAMP(timezone=True), nullable=False),
sa.Column('created_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.Column('updated_at', sa.TIMESTAMP(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.PrimaryKeyConstraint('quota_id'),
sa.UniqueConstraint('user_id', 'quota_type', 'period', name='ai_quotas_unique'),
sa.CheckConstraint('period BETWEEN 1 AND 3', name='ai_quotas_period_check'),
sa.CheckConstraint('used_quota >= 0 AND used_quota <= total_quota', name='ai_quotas_used_check'),
comment='AI 配额表 - 应用层保证引用完整性'
)
# 索引
op.create_index('idx_ai_quotas_user_id', 'ai_quotas', ['user_id'])
op.create_index('idx_ai_quotas_type', 'ai_quotas', ['quota_type'])
op.create_index('idx_ai_quotas_reset_at', 'ai_quotas', ['reset_at'])
op.create_index('idx_ai_quotas_user_type', 'ai_quotas', ['user_id', 'quota_type'])
# 触发器
op.execute("""
CREATE TRIGGER update_ai_quotas_updated_at
BEFORE UPDATE ON ai_quotas
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
""")
# 字段注释
op.execute("""
COMMENT ON COLUMN ai_quotas.user_id IS '用户 ID - 应用层验证';
COMMENT ON COLUMN ai_quotas.period IS '配额周期(1=每日 2=每月 3=总计)';
""")
# 5. 创建配额重置函数
op.execute("""
CREATE OR REPLACE FUNCTION reset_expired_quotas()
RETURNS void AS $$
BEGIN
-- 重置每日配额 (period = 1)
UPDATE ai_quotas
SET used_quota = 0,
reset_at = reset_at + INTERVAL '1 day',
updated_at = now()
WHERE period = 1 AND reset_at <= now();
-- 重置每月配额 (period = 2)
UPDATE ai_quotas
SET used_quota = 0,
reset_at = reset_at + INTERVAL '1 month',
updated_at = now()
WHERE period = 2 AND reset_at <= now();
END;
$$ LANGUAGE plpgsql;
""")
def downgrade() -> None:
"""删除 AI 服务相关表"""
op.execute("DROP FUNCTION IF EXISTS reset_expired_quotas();")
op.drop_table('ai_quotas')
op.drop_table('ai_usage_logs')
op.drop_table('ai_jobs')
op.drop_table('ai_models')
运行迁移
# 生成迁移文件
docker exec jointo-server-app alembic revision --autogenerate -m "create ai service tables"
# 应用迁移
docker exec jointo-server-app alembic upgrade head
# 查看当前版本
docker exec jointo-server-app alembic current
# 回滚迁移
docker exec jointo-server-app alembic downgrade -1