feat: polish generation demo workflow

This commit is contained in:
2026-04-18 14:06:38 +08:00
parent 5d8fb1ed50
commit 0f260f649c
15 changed files with 569 additions and 74 deletions

View File

@@ -38,10 +38,22 @@ from app.services.story_status import StoryAssetStatus, sync_story_status
logger = get_logger(__name__)
router = APIRouter()
RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_WINDOW = 60 # seconds
RATE_LIMIT_REQUESTS = 10
def _legacy_generation_headers(successor: str) -> dict[str, str]:
return {
"Deprecation": "true",
"Link": f"<{successor}>; rel=\"successor-version\"",
"X-DreamWeaver-Successor-Endpoint": successor,
}
def _mark_legacy_generation_endpoint(response: Response, successor: str) -> None:
response.headers.update(_legacy_generation_headers(successor))
@router.post("/generations", response_model=GenerationResponse)
async def create_generation(
request: GenerationRequest,
@@ -77,23 +89,27 @@ async def retry_generation_assets(
@router.post("/stories/generate", response_model=StoryResponse)
async def generate_story(
request: GenerateRequest,
response: Response,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate or enhance a story."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
return await story_service.generate_and_save_story(request, user.id, db)
db: AsyncSession = Depends(get_db),
):
"""Generate or enhance a story."""
_mark_legacy_generation_endpoint(response, "/api/generations")
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
return await story_service.generate_and_save_story(request, user.id, db)
@router.post("/stories/generate/full", response_model=FullStoryResponse)
async def generate_story_full(
request: GenerateRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate complete story (story + parallel image/audio generation)."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
return await story_service.generate_full_story_service(request, user.id, db)
async def generate_story_full(
request: GenerateRequest,
response: Response,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate complete story (story + parallel image/audio generation)."""
_mark_legacy_generation_endpoint(response, "/api/generations")
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
return await story_service.generate_full_story_service(request, user.id, db)
@router.post("/stories/generate/stream")
@@ -212,18 +228,23 @@ async def generate_story_stream(
),
}
return EventSourceResponse(event_generator())
return EventSourceResponse(
event_generator(),
headers=_legacy_generation_headers("/api/generations"),
)
@router.post("/storybook/generate", response_model=StorybookResponse)
async def generate_storybook_api(
request: StorybookRequest,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate storybook."""
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
return await story_service.generate_storybook_service(request, user.id, db)
async def generate_storybook_api(
request: StorybookRequest,
response: Response,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate storybook."""
_mark_legacy_generation_endpoint(response, "/api/generations")
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
return await story_service.generate_storybook_service(request, user.id, db)
# ==================== Missing Endpoints (Issue #5) ====================
@@ -263,10 +284,15 @@ async def delete_story(
@router.post("/image/generate/{story_id}", response_model=StoryImageResponse)
async def generate_story_image(
story_id: int,
response: Response,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Generate cover image for story."""
_mark_legacy_generation_endpoint(
response,
f"/api/generations/{story_id}/retry-assets",
)
url = await story_service.generate_story_cover(story_id, user.id, db)
story = await story_service.get_story_detail(story_id, user.id, db)
return {
@@ -282,10 +308,15 @@ async def generate_story_image(
async def retry_story_assets(
story_id: int,
payload: StoryAssetRetryRequest,
response: Response,
user: User = Depends(require_user),
db: AsyncSession = Depends(get_db),
):
"""Retry selected generated assets for a story."""
_mark_legacy_generation_endpoint(
response,
f"/api/generations/{story_id}/retry-assets",
)
return await story_service.retry_story_assets(story_id, user.id, payload.assets, db)

View File

@@ -9,6 +9,14 @@ from app.core.config import settings
from app.services.adapters.storybook.primary import Storybook, StorybookPage
def assert_legacy_generation_headers(response, successor: str) -> None:
"""Assert that compatibility generation endpoints point callers to the unified API."""
assert response.headers["Deprecation"] == "true"
assert response.headers["X-DreamWeaver-Successor-Endpoint"] == successor
assert response.headers["Link"] == f'<{successor}>; rel="successor-version"'
def build_storybook_output() -> Storybook:
"""Create a reusable mocked storybook payload."""
@@ -62,6 +70,7 @@ class TestStoryGenerate:
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
)
assert response.status_code == 200
assert_legacy_generation_headers(response, "/api/generations")
data = response.json()
assert "id" in data
assert "title" in data
@@ -317,6 +326,7 @@ class TestGenerateFull:
json={"type": "keywords", "data": "小兔子, 森林, 勇气"},
)
assert response.status_code == 200
assert_legacy_generation_headers(response, "/api/generations")
data = response.json()
assert "id" in data
assert "title" in data
@@ -394,6 +404,7 @@ class TestUnifiedGenerations:
)
assert response.status_code == 200
assert "Deprecation" not in response.headers
data = response.json()
assert data["id"] is not None
assert data["mode"] == "generated"
@@ -534,6 +545,10 @@ class TestImageGenerateSuccess:
):
response = auth_client.post(f"/api/image/generate/{test_story.id}")
assert response.status_code == 200
assert_legacy_generation_headers(
response,
f"/api/generations/{test_story.id}/retry-assets",
)
data = response.json()
assert data["image_url"] == "https://example.com/image.png"
assert data["generation_status"] == "completed"
@@ -557,6 +572,10 @@ class TestAssetRetry:
)
assert response.status_code == 200
assert_legacy_generation_headers(
response,
f"/api/generations/{degraded_story_with_text.id}/retry-assets",
)
data = response.json()
assert data["image_url"] == "https://example.com/image.png"
assert data["generation_status"] == "completed"
@@ -585,6 +604,10 @@ class TestAssetRetry:
)
assert response.status_code == 200
assert_legacy_generation_headers(
response,
f"/api/generations/{storybook_story.id}/retry-assets",
)
data = response.json()
assert data["generation_status"] == "completed"
assert data["image_status"] == "ready"