feat: add unified generation entrypoint
This commit is contained in:
@@ -17,6 +17,8 @@ from app.schemas.story_schemas import (
|
||||
AchievementItem,
|
||||
FullStoryResponse,
|
||||
GenerateRequest,
|
||||
GenerationRequest,
|
||||
GenerationResponse,
|
||||
StoryAssetRetryRequest,
|
||||
StorybookRequest,
|
||||
StorybookResponse,
|
||||
@@ -37,13 +39,45 @@ logger = get_logger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
RATE_LIMIT_WINDOW = 60 # seconds
|
||||
RATE_LIMIT_REQUESTS = 10
|
||||
|
||||
|
||||
@router.post("/stories/generate", response_model=StoryResponse)
|
||||
async def generate_story(
|
||||
request: GenerateRequest,
|
||||
user: User = Depends(require_user),
|
||||
RATE_LIMIT_REQUESTS = 10
|
||||
|
||||
|
||||
@router.post("/generations", response_model=GenerationResponse)
|
||||
async def create_generation(
|
||||
request: GenerationRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Create a story or storybook through the unified generation workflow."""
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
return await story_service.generate_generation_service(request, user.id, db)
|
||||
|
||||
|
||||
@router.get("/generations/{story_id}", response_model=StoryDetailResponse)
|
||||
async def get_generation(
|
||||
story_id: int,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Get a generated story/storybook through the unified generation API."""
|
||||
return await story_service.get_story_detail(story_id, user.id, db)
|
||||
|
||||
|
||||
@router.post("/generations/{story_id}/retry-assets", response_model=StoryDetailResponse)
|
||||
async def retry_generation_assets(
|
||||
story_id: int,
|
||||
payload: StoryAssetRetryRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Retry generated assets through the unified generation API."""
|
||||
return await story_service.retry_story_assets(story_id, user.id, payload.assets, db)
|
||||
|
||||
|
||||
@router.post("/stories/generate", response_model=StoryResponse)
|
||||
async def generate_story(
|
||||
request: GenerateRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Generate or enhance a story."""
|
||||
|
||||
@@ -29,6 +29,19 @@ class GenerateRequest(BaseModel):
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
class GenerationRequest(BaseModel):
|
||||
"""Unified generation request for story and storybook outputs."""
|
||||
|
||||
output_mode: Literal["story", "storybook"] = Field(default="story")
|
||||
type: Literal["keywords", "full_story"] = Field(default="keywords")
|
||||
data: str = Field(..., min_length=1, max_length=MAX_DATA_LENGTH)
|
||||
education_theme: str | None = Field(default=None, max_length=MAX_EDU_THEME_LENGTH)
|
||||
generate_images: bool = Field(default=True)
|
||||
page_count: int = Field(default=6, ge=4, le=12)
|
||||
child_profile_id: str | None = None
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
class StoryResponse(StoryStatusMixin):
|
||||
"""Story generation response."""
|
||||
|
||||
@@ -99,6 +112,25 @@ class StorybookResponse(StoryStatusMixin):
|
||||
cover_url: str | None = None
|
||||
|
||||
|
||||
class GenerationResponse(StoryStatusMixin):
|
||||
"""Unified generation response for the target workflow API."""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
mode: str
|
||||
story_text: str | None = None
|
||||
pages: list[StorybookPageResponse] | None = None
|
||||
cover_prompt: str | None = None
|
||||
image_url: str | None = None
|
||||
cover_url: str | None = None
|
||||
audio_ready: bool = False
|
||||
errors: dict[str, str | None] = Field(default_factory=dict)
|
||||
main_character: str | None = None
|
||||
art_style: str | None = None
|
||||
child_profile_id: str | None = None
|
||||
universe_id: str | None = None
|
||||
|
||||
|
||||
class StoryDetailResponse(StoryStatusMixin):
|
||||
"""Story detail response for both stories and storybooks."""
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@ from app.schemas.story_schemas import (
|
||||
AchievementItem,
|
||||
FullStoryResponse,
|
||||
GenerateRequest,
|
||||
GenerationRequest,
|
||||
GenerationResponse,
|
||||
StorybookPageResponse,
|
||||
StorybookRequest,
|
||||
StorybookResponse,
|
||||
@@ -385,6 +387,94 @@ async def generate_storybook_service(
|
||||
audio_status=story.audio_status,
|
||||
last_error=story.last_error,
|
||||
)
|
||||
|
||||
|
||||
async def generate_generation_service(
|
||||
request: GenerationRequest,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> GenerationResponse:
|
||||
"""Unified generation workflow entry point for stories and storybooks."""
|
||||
|
||||
if request.output_mode == "storybook":
|
||||
storybook = await generate_storybook_service(
|
||||
StorybookRequest(
|
||||
keywords=request.data,
|
||||
page_count=request.page_count,
|
||||
education_theme=request.education_theme,
|
||||
generate_images=request.generate_images,
|
||||
child_profile_id=request.child_profile_id,
|
||||
universe_id=request.universe_id,
|
||||
),
|
||||
user_id,
|
||||
db,
|
||||
)
|
||||
if storybook.id is None:
|
||||
raise HTTPException(status_code=500, detail="Storybook generation did not persist.")
|
||||
|
||||
saved_story = await get_story_detail(storybook.id, user_id, db)
|
||||
return GenerationResponse(
|
||||
id=storybook.id,
|
||||
title=storybook.title,
|
||||
mode="storybook",
|
||||
pages=storybook.pages,
|
||||
cover_prompt=storybook.cover_prompt,
|
||||
image_url=storybook.cover_url,
|
||||
cover_url=storybook.cover_url,
|
||||
main_character=storybook.main_character,
|
||||
art_style=storybook.art_style,
|
||||
generation_status=storybook.generation_status,
|
||||
image_status=storybook.image_status,
|
||||
audio_status=storybook.audio_status,
|
||||
last_error=storybook.last_error,
|
||||
child_profile_id=saved_story.child_profile_id,
|
||||
universe_id=saved_story.universe_id,
|
||||
)
|
||||
|
||||
generate_request = GenerateRequest(
|
||||
type=request.type,
|
||||
data=request.data,
|
||||
education_theme=request.education_theme,
|
||||
child_profile_id=request.child_profile_id,
|
||||
universe_id=request.universe_id,
|
||||
)
|
||||
|
||||
if request.generate_images:
|
||||
story = await generate_full_story_service(generate_request, user_id, db)
|
||||
return GenerationResponse(
|
||||
id=story.id,
|
||||
title=story.title,
|
||||
mode=story.mode,
|
||||
story_text=story.story_text,
|
||||
cover_prompt=story.cover_prompt,
|
||||
image_url=story.image_url,
|
||||
cover_url=story.image_url,
|
||||
audio_ready=story.audio_ready,
|
||||
errors=story.errors,
|
||||
generation_status=story.generation_status,
|
||||
image_status=story.image_status,
|
||||
audio_status=story.audio_status,
|
||||
last_error=story.last_error,
|
||||
child_profile_id=story.child_profile_id,
|
||||
universe_id=story.universe_id,
|
||||
)
|
||||
|
||||
story = await generate_and_save_story(generate_request, user_id, db)
|
||||
return GenerationResponse(
|
||||
id=story.id,
|
||||
title=story.title,
|
||||
mode=story.mode,
|
||||
story_text=story.story_text,
|
||||
cover_prompt=story.cover_prompt,
|
||||
image_url=story.image_url,
|
||||
cover_url=story.image_url,
|
||||
generation_status=story.generation_status,
|
||||
image_status=story.image_status,
|
||||
audio_status=story.audio_status,
|
||||
last_error=story.last_error,
|
||||
child_profile_id=story.child_profile_id,
|
||||
universe_id=story.universe_id,
|
||||
)
|
||||
|
||||
|
||||
# ==================== Missing Endpoints Logic (for Issue #5) ====================
|
||||
|
||||
@@ -367,6 +367,135 @@ class TestGenerateFull:
|
||||
assert call_kwargs["education_theme"] == "勇气与友谊"
|
||||
|
||||
|
||||
class TestUnifiedGenerations:
|
||||
"""Tests for the target unified generation API."""
|
||||
|
||||
def test_create_generation_without_auth(self, client: TestClient):
|
||||
response = client.post(
|
||||
"/api/generations",
|
||||
json={"output_mode": "story", "type": "keywords", "data": "小兔子, 森林"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_create_story_generation_success(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
mock_text_provider,
|
||||
mock_image_provider,
|
||||
):
|
||||
response = auth_client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "story",
|
||||
"type": "keywords",
|
||||
"data": "小兔子, 森林, 勇气",
|
||||
"generate_images": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] is not None
|
||||
assert data["mode"] == "generated"
|
||||
assert data["story_text"] == "从前有一只小兔子。"
|
||||
assert data["image_url"] == "https://example.com/image.png"
|
||||
assert data["cover_url"] == "https://example.com/image.png"
|
||||
assert data["pages"] is None
|
||||
assert data["generation_status"] == "completed"
|
||||
assert data["image_status"] == "ready"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert data["errors"] == {}
|
||||
|
||||
def test_create_story_generation_without_assets(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
mock_text_provider,
|
||||
):
|
||||
response = auth_client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "story",
|
||||
"type": "keywords",
|
||||
"data": "小兔子, 森林",
|
||||
"generate_images": False,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["mode"] == "generated"
|
||||
assert data["image_url"] is None
|
||||
assert data["generation_status"] == "narrative_ready"
|
||||
assert data["image_status"] == "not_requested"
|
||||
|
||||
def test_create_storybook_generation_success(self, auth_client: TestClient):
|
||||
with patch(
|
||||
"app.services.story_service.generate_storybook",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_storybook:
|
||||
with patch(
|
||||
"app.services.story_service.generate_image",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_image:
|
||||
mock_storybook.return_value = build_storybook_output()
|
||||
mock_image.side_effect = [
|
||||
"https://example.com/storybook-cover.png",
|
||||
"https://example.com/storybook-page-1.png",
|
||||
"https://example.com/storybook-page-2.png",
|
||||
]
|
||||
|
||||
response = auth_client.post(
|
||||
"/api/generations",
|
||||
json={
|
||||
"output_mode": "storybook",
|
||||
"type": "keywords",
|
||||
"data": "森林, 发光, 友情",
|
||||
"page_count": 6,
|
||||
"generate_images": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] is not None
|
||||
assert data["mode"] == "storybook"
|
||||
assert data["story_text"] is None
|
||||
assert len(data["pages"]) == 2
|
||||
assert data["cover_url"] == "https://example.com/storybook-cover.png"
|
||||
assert data["image_url"] == "https://example.com/storybook-cover.png"
|
||||
assert data["main_character"] == "小兔子露露"
|
||||
assert data["art_style"] == "温暖水彩"
|
||||
assert data["generation_status"] == "completed"
|
||||
assert data["image_status"] == "ready"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
|
||||
def test_get_generation_alias(self, auth_client: TestClient, test_story):
|
||||
response = auth_client.get(f"/api/generations/{test_story.id}")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == test_story.id
|
||||
assert data["title"] == test_story.title
|
||||
assert data["mode"] == "generated"
|
||||
|
||||
def test_retry_generation_assets_alias(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
degraded_story_with_text,
|
||||
mock_image_provider,
|
||||
):
|
||||
response = auth_client.post(
|
||||
f"/api/generations/{degraded_story_with_text.id}/retry-assets",
|
||||
json={"assets": ["image"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["image_url"] == "https://example.com/image.png"
|
||||
assert data["generation_status"] == "completed"
|
||||
assert data["image_status"] == "ready"
|
||||
|
||||
|
||||
class TestImageGenerateSuccess:
|
||||
"""Tests for successful cover generation."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user