feat: add unified asset retry endpoint
Some checks failed
Build and Push Docker Images / changes (push) Has been cancelled
Build and Push Docker Images / build-backend (push) Has been cancelled
Build and Push Docker Images / build-frontend (push) Has been cancelled
Build and Push Docker Images / build-admin-frontend (push) Has been cancelled
Some checks failed
Build and Push Docker Images / changes (push) Has been cancelled
Build and Push Docker Images / build-backend (push) Has been cancelled
Build and Push Docker Images / build-frontend (push) Has been cancelled
Build and Push Docker Images / build-admin-frontend (push) Has been cancelled
This commit is contained in:
@@ -5,33 +5,34 @@ import uuid
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter, Depends, Response
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.core.logging import get_logger
|
||||
from app.core.rate_limiter import check_rate_limit
|
||||
from app.db.database import get_db
|
||||
from app.db.models import User
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from app.core.deps import require_user
|
||||
from app.core.logging import get_logger
|
||||
from app.core.rate_limiter import check_rate_limit
|
||||
from app.db.database import get_db
|
||||
from app.db.models import User
|
||||
from app.schemas.story_schemas import (
|
||||
AchievementItem,
|
||||
FullStoryResponse,
|
||||
GenerateRequest,
|
||||
StoryAssetRetryRequest,
|
||||
StorybookRequest,
|
||||
StorybookResponse,
|
||||
StoryDetailResponse,
|
||||
StoryImageResponse,
|
||||
StoryListItem,
|
||||
StoryResponse,
|
||||
StorybookRequest,
|
||||
StorybookResponse,
|
||||
)
|
||||
from app.services import story_service
|
||||
from app.services.memory_service import build_enhanced_memory_context
|
||||
from app.services.provider_router import (
|
||||
generate_story_content,
|
||||
generate_image,
|
||||
generate_story_content,
|
||||
)
|
||||
from app.services.story_status import StoryAssetStatus, sync_story_status
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
@@ -69,12 +70,12 @@ async def generate_story_stream(
|
||||
):
|
||||
"""流式生成故事(SSE)。"""
|
||||
await check_rate_limit(f"story:{user.id}", RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW)
|
||||
|
||||
|
||||
# Validation
|
||||
profile_id, universe_id = await story_service.validate_profile_and_universe(
|
||||
request.child_profile_id, request.universe_id, user.id, db
|
||||
)
|
||||
|
||||
|
||||
# Build Context
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
@@ -241,10 +242,21 @@ async def generate_story_image(
|
||||
"audio_status": story.audio_status,
|
||||
"last_error": story.last_error,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/audio/{story_id}")
|
||||
async def get_story_audio(
|
||||
|
||||
|
||||
@router.post("/stories/{story_id}/assets/retry", response_model=StoryDetailResponse)
|
||||
async def retry_story_assets(
|
||||
story_id: int,
|
||||
payload: StoryAssetRetryRequest,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Retry selected generated assets for a story."""
|
||||
return await story_service.retry_story_assets(story_id, user.id, payload.assets, db)
|
||||
|
||||
|
||||
@router.get("/audio/{story_id}")
|
||||
async def get_story_audio(
|
||||
story_id: int,
|
||||
user: User = Depends(require_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
MAX_DATA_LENGTH = 2000
|
||||
MAX_EDU_THEME_LENGTH = 200
|
||||
MAX_TTS_LENGTH = 4000
|
||||
@@ -120,6 +119,12 @@ class StoryImageResponse(StoryStatusMixin):
|
||||
image_url: str | None
|
||||
|
||||
|
||||
class StoryAssetRetryRequest(BaseModel):
|
||||
"""Retry selected generated assets for a story."""
|
||||
|
||||
assets: list[Literal["image", "audio"]] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class AchievementItem(BaseModel):
|
||||
"""Achievement item returned for a story."""
|
||||
|
||||
|
||||
@@ -10,12 +10,12 @@ from sqlalchemy.orm import joinedload
|
||||
from app.core.logging import get_logger
|
||||
from app.db.models import ChildProfile, Story, StoryUniverse
|
||||
from app.schemas.story_schemas import (
|
||||
GenerateRequest,
|
||||
StorybookRequest,
|
||||
FullStoryResponse,
|
||||
StorybookResponse,
|
||||
StorybookPageResponse,
|
||||
AchievementItem,
|
||||
FullStoryResponse,
|
||||
GenerateRequest,
|
||||
StorybookPageResponse,
|
||||
StorybookRequest,
|
||||
StorybookResponse,
|
||||
)
|
||||
from app.services.audio_storage import (
|
||||
audio_cache_exists,
|
||||
@@ -24,8 +24,8 @@ from app.services.audio_storage import (
|
||||
)
|
||||
from app.services.memory_service import build_enhanced_memory_context
|
||||
from app.services.provider_router import (
|
||||
generate_story_content,
|
||||
generate_image,
|
||||
generate_story_content,
|
||||
generate_storybook,
|
||||
)
|
||||
from app.services.story_status import (
|
||||
@@ -140,7 +140,7 @@ async def generate_and_save_story(
|
||||
profile_id, universe_id = await validate_profile_and_universe(
|
||||
request.child_profile_id, request.universe_id, user_id, db
|
||||
)
|
||||
|
||||
|
||||
# 2. Build Context
|
||||
memory_context = await build_enhanced_memory_context(profile_id, universe_id, db)
|
||||
|
||||
@@ -153,8 +153,11 @@ async def generate_and_save_story(
|
||||
memory_context=memory_context,
|
||||
db=db,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail="Story generation failed, please try again.") from exc
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Story generation failed, please try again.",
|
||||
) from exc
|
||||
|
||||
# 4. Save
|
||||
story = Story(
|
||||
@@ -247,7 +250,7 @@ async def generate_storybook_service(
|
||||
profile_id, universe_id = await validate_profile_and_universe(
|
||||
request.child_profile_id, request.universe_id, user_id, db
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
"storybook_request",
|
||||
user_id=user_id,
|
||||
@@ -418,11 +421,11 @@ async def get_story_detail(
|
||||
return story
|
||||
|
||||
|
||||
async def delete_story(
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
async def delete_story(
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Delete a story."""
|
||||
story = await get_story_detail(story_id, user_id, db)
|
||||
await db.delete(story)
|
||||
@@ -456,12 +459,131 @@ async def create_story_from_result(
|
||||
await db.commit()
|
||||
await db.refresh(story)
|
||||
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
return story
|
||||
|
||||
|
||||
if universe_id:
|
||||
extract_story_achievements.delay(story.id, universe_id)
|
||||
|
||||
return story
|
||||
|
||||
|
||||
async def _retry_cover_image_asset(story: Story, db: AsyncSession) -> None:
|
||||
"""Retry cover generation for a text story."""
|
||||
|
||||
if not story.cover_prompt:
|
||||
raise HTTPException(status_code=400, detail="Story has no cover prompt")
|
||||
|
||||
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
|
||||
try:
|
||||
story.image_url = await generate_image(story.cover_prompt, db=db)
|
||||
sync_story_status(story, image_status=StoryAssetStatus.READY)
|
||||
except Exception as exc:
|
||||
sync_story_status(
|
||||
story,
|
||||
image_status=StoryAssetStatus.FAILED,
|
||||
last_error=f"封面生成失败: {exc}",
|
||||
)
|
||||
logger.error("cover_asset_retry_failed", story_id=story.id, error=str(exc))
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _retry_storybook_image_assets(story: Story, db: AsyncSession) -> None:
|
||||
"""Retry missing storybook cover/page images."""
|
||||
|
||||
pages_data = [dict(page) for page in story.pages or [] if isinstance(page, dict)]
|
||||
has_image_prompt = bool(story.cover_prompt) or any(
|
||||
page.get("image_prompt") for page in pages_data
|
||||
)
|
||||
if not has_image_prompt:
|
||||
raise HTTPException(status_code=400, detail="Storybook has no image prompts")
|
||||
|
||||
sync_story_status(story, image_status=StoryAssetStatus.GENERATING)
|
||||
await db.commit()
|
||||
|
||||
cover_failed = False
|
||||
failed_pages: list[int] = []
|
||||
|
||||
if story.cover_prompt and not story.image_url:
|
||||
try:
|
||||
story.image_url = await generate_image(story.cover_prompt, db=db)
|
||||
except Exception as exc:
|
||||
cover_failed = True
|
||||
logger.warning(
|
||||
"storybook_cover_asset_retry_failed",
|
||||
story_id=story.id,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
for page in pages_data:
|
||||
if not page.get("image_prompt") or page.get("image_url"):
|
||||
continue
|
||||
|
||||
try:
|
||||
page["image_url"] = await generate_image(page["image_prompt"], db=db)
|
||||
except Exception as exc:
|
||||
page_number = page.get("page_number")
|
||||
if isinstance(page_number, int):
|
||||
failed_pages.append(page_number)
|
||||
logger.warning(
|
||||
"storybook_page_asset_retry_failed",
|
||||
story_id=story.id,
|
||||
page=page_number,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
story.pages = pages_data
|
||||
sync_story_status(
|
||||
story,
|
||||
image_status=_resolve_storybook_image_status(
|
||||
generate_images=True,
|
||||
cover_prompt=story.cover_prompt,
|
||||
cover_url=story.image_url,
|
||||
pages_data=pages_data,
|
||||
),
|
||||
last_error=_build_storybook_error_message(
|
||||
cover_failed=cover_failed,
|
||||
failed_pages=failed_pages,
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _retry_audio_asset(story_id: int, user_id: str, db: AsyncSession) -> None:
|
||||
"""Retry audio generation while preserving persisted status on provider failure."""
|
||||
|
||||
try:
|
||||
await generate_story_audio(story_id, user_id, db)
|
||||
except HTTPException as exc:
|
||||
if exc.status_code >= 500:
|
||||
logger.warning("audio_asset_retry_failed", story_id=story_id, error=exc.detail)
|
||||
return
|
||||
raise
|
||||
|
||||
|
||||
async def retry_story_assets(
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
assets: list[str],
|
||||
db: AsyncSession,
|
||||
) -> Story:
|
||||
"""Retry selected assets through one workflow-level endpoint."""
|
||||
|
||||
story = await get_story_detail(story_id, user_id, db)
|
||||
requested_assets = list(dict.fromkeys(assets))
|
||||
|
||||
if "image" in requested_assets:
|
||||
if story.mode == "storybook":
|
||||
await _retry_storybook_image_assets(story, db)
|
||||
else:
|
||||
await _retry_cover_image_asset(story, db)
|
||||
|
||||
if "audio" in requested_assets:
|
||||
await _retry_audio_asset(story_id, user_id, db)
|
||||
|
||||
return await get_story_detail(story_id, user_id, db)
|
||||
|
||||
|
||||
async def generate_story_cover(
|
||||
story_id: int,
|
||||
user_id: str,
|
||||
@@ -469,7 +591,7 @@ async def generate_story_cover(
|
||||
) -> str:
|
||||
"""Generate cover image for an existing story."""
|
||||
story = await get_story_detail(story_id, user_id, db)
|
||||
|
||||
|
||||
if not story.cover_prompt:
|
||||
raise HTTPException(status_code=400, detail="Story has no cover prompt")
|
||||
|
||||
@@ -503,7 +625,7 @@ async def generate_story_audio(
|
||||
) -> bytes:
|
||||
"""Generate audio for a story."""
|
||||
story = await get_story_detail(story_id, user_id, db)
|
||||
|
||||
|
||||
if not story.story_text:
|
||||
raise HTTPException(status_code=400, detail="Story has no text")
|
||||
|
||||
|
||||
@@ -264,7 +264,10 @@ class TestAudio:
|
||||
mock_tts_provider.assert_awaited_once()
|
||||
|
||||
def test_get_audio_failure_updates_status(self, auth_client: TestClient, test_story):
|
||||
with patch("app.services.provider_router.text_to_speech", new_callable=AsyncMock) as mock_tts:
|
||||
with patch(
|
||||
"app.services.provider_router.text_to_speech",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tts:
|
||||
mock_tts.side_effect = Exception("TTS provider timeout")
|
||||
response = auth_client.get(f"/api/audio/{test_story.id}")
|
||||
assert response.status_code == 500
|
||||
@@ -383,12 +386,83 @@ class TestImageGenerateSuccess:
|
||||
assert data["last_error"] is None
|
||||
|
||||
|
||||
class TestAssetRetry:
|
||||
"""Tests for unified asset retry endpoint."""
|
||||
|
||||
def test_retry_cover_image_success(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
degraded_story_with_text,
|
||||
mock_image_provider,
|
||||
):
|
||||
response = auth_client.post(
|
||||
f"/api/stories/{degraded_story_with_text.id}/assets/retry",
|
||||
json={"assets": ["image"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["image_url"] == "https://example.com/image.png"
|
||||
assert data["generation_status"] == "completed"
|
||||
assert data["image_status"] == "ready"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert data["last_error"] is None
|
||||
mock_image_provider.assert_awaited_once()
|
||||
|
||||
def test_retry_storybook_missing_page_image_success(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
storybook_story,
|
||||
):
|
||||
async def image_side_effect(prompt: str, **kwargs):
|
||||
return "https://example.com/retried-page.png"
|
||||
|
||||
with patch(
|
||||
"app.services.story_service.generate_image",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_image:
|
||||
mock_image.side_effect = image_side_effect
|
||||
|
||||
response = auth_client.post(
|
||||
f"/api/stories/{storybook_story.id}/assets/retry",
|
||||
json={"assets": ["image"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["generation_status"] == "completed"
|
||||
assert data["image_status"] == "ready"
|
||||
assert data["audio_status"] == "not_requested"
|
||||
assert data["last_error"] is None
|
||||
assert data["pages"][1]["image_url"] == "https://example.com/retried-page.png"
|
||||
mock_image.assert_awaited_once()
|
||||
|
||||
def test_retry_audio_on_storybook_is_rejected(
|
||||
self,
|
||||
auth_client: TestClient,
|
||||
storybook_story,
|
||||
):
|
||||
response = auth_client.post(
|
||||
f"/api/stories/{storybook_story.id}/assets/retry",
|
||||
json={"assets": ["audio"]},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Story has no text"
|
||||
|
||||
|
||||
class TestStorybookGenerate:
|
||||
"""Tests for storybook generation status handling."""
|
||||
|
||||
def test_generate_storybook_success(self, auth_client: TestClient):
|
||||
with patch("app.services.story_service.generate_storybook", new_callable=AsyncMock) as mock_storybook:
|
||||
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_image:
|
||||
with patch(
|
||||
"app.services.story_service.generate_storybook",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_storybook:
|
||||
with patch(
|
||||
"app.services.story_service.generate_image",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_image:
|
||||
mock_storybook.return_value = build_storybook_output()
|
||||
mock_image.side_effect = [
|
||||
"https://example.com/storybook-cover.png",
|
||||
@@ -422,8 +496,14 @@ class TestStorybookGenerate:
|
||||
slug = prompt.split()[0].lower()
|
||||
return f"https://example.com/{slug}.png"
|
||||
|
||||
with patch("app.services.story_service.generate_storybook", new_callable=AsyncMock) as mock_storybook:
|
||||
with patch("app.services.story_service.generate_image", new_callable=AsyncMock) as mock_image:
|
||||
with patch(
|
||||
"app.services.story_service.generate_storybook",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_storybook:
|
||||
with patch(
|
||||
"app.services.story_service.generate_image",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_image:
|
||||
mock_storybook.return_value = build_storybook_output()
|
||||
mock_image.side_effect = image_side_effect
|
||||
|
||||
|
||||
Reference in New Issue
Block a user