fix: stabilize auth and generation workflows

This commit is contained in:
2026-04-23 22:31:14 +08:00
parent 4db04e61e9
commit 7e450aa5fc
16 changed files with 335 additions and 127 deletions

View File

@@ -1,23 +1,35 @@
const BASE_URL = '' const BASE_URL = ''
class ApiClient { class ApiClient {
async request<T>(url: string, options: RequestInit = {}): Promise<T> { async request<T>(url: string, options: RequestInit = {}): Promise<T> {
const response = await fetch(`${BASE_URL}${url}`, { const headers = new Headers(options.headers || {})
...options, const isFormData = options.body instanceof FormData
credentials: 'include', if (!isFormData && !headers.has('Content-Type')) {
headers: { headers.set('Content-Type', 'application/json')
'Content-Type': 'application/json', }
...options.headers,
}, const response = await fetch(`${BASE_URL}${url}`, {
}) ...options,
credentials: 'include',
if (!response.ok) { headers,
const error = await response.json().catch(() => ({ detail: '请求失败' })) })
throw new Error(error.detail || '请求失败')
} if (!response.ok) {
const error = await response.json().catch(() => ({ detail: '请求失败' }))
return response.json() throw new Error(error.detail || '请求失败')
} }
if (response.status === 204 || response.status === 205) {
return undefined as T
}
const contentType = response.headers.get('content-type') || ''
if (!contentType.includes('application/json')) {
return undefined as T
}
return response.json()
}
get<T>(url: string): Promise<T> { get<T>(url: string): Promise<T> {
return this.request<T>(url) return this.request<T>(url)

View File

@@ -145,12 +145,12 @@ function sleep(ms: number) {
async function waitForStoryId(jobId: string) { async function waitForStoryId(jobId: string) {
for (let attempt = 0; attempt < JOB_POLL_MAX_ATTEMPTS; attempt += 1) { for (let attempt = 0; attempt < JOB_POLL_MAX_ATTEMPTS; attempt += 1) {
const detail = await api.get<GenerationJobDetail>(`/api/generations/jobs/${jobId}`) const detail = await api.get<GenerationJobDetail>(`/api/generations/jobs/${jobId}`)
if (detail.status === 'canceled' || detail.current_step === 'generation_canceled') {
return null
}
if (detail.story_id) { if (detail.story_id) {
return detail.story_id return detail.story_id
} }
if (detail.status === 'canceled' || detail.current_step === 'generation_canceled') {
return null
}
if (detail.is_terminal) { if (detail.is_terminal) {
throw new Error(detail.error_message || '生成失败,请稍后重试') throw new Error(detail.error_message || '生成失败,请稍后重试')
} }

View File

@@ -74,11 +74,7 @@ const latestJob = computed(() => jobs.value[0] ?? null)
const activeEvents = computed(() => activeJob.value?.events.slice(-10) ?? []) const activeEvents = computed(() => activeJob.value?.events.slice(-10) ?? [])
const activeProgress = computed(() => activeJob.value?.progress_percent ?? latestJob.value?.progress_percent ?? 0) const activeProgress = computed(() => activeJob.value?.progress_percent ?? latestJob.value?.progress_percent ?? 0)
const activeProgressLabel = computed(() => activeJob.value?.progress_label ?? latestJob.value?.progress_label ?? '暂无进度') const activeProgressLabel = computed(() => activeJob.value?.progress_label ?? latestJob.value?.progress_label ?? '暂无进度')
const shouldAutoRefresh = computed(() => { const shouldAutoRefresh = computed(() => Boolean(latestJob.value && !latestJob.value.is_terminal))
if (activeJob.value) return !activeJob.value.is_terminal
if (latestJob.value) return !latestJob.value.is_terminal
return false
})
const providerSuccessRate = computed(() => { const providerSuccessRate = computed(() => {
if (!providerStats.value?.total_calls) return null if (!providerStats.value?.total_calls) return null
return Math.round((providerStats.value.successful_calls / providerStats.value.total_calls) * 100) return Math.round((providerStats.value.successful_calls / providerStats.value.total_calls) * 100)
@@ -199,6 +195,7 @@ async function refresh() {
} }
error.value = '' error.value = ''
const selectedJobId = activeJob.value?.id ?? null
try { try {
const [nextJobs, stats] = await Promise.all([ const [nextJobs, stats] = await Promise.all([
@@ -207,7 +204,11 @@ async function refresh() {
]) ])
jobs.value = nextJobs jobs.value = nextJobs
providerStats.value = stats providerStats.value = stats
const nextJobId = jobs.value[0]?.id const nextJobId = (
selectedJobId
? jobs.value.find((job) => job.id === selectedJobId)?.id
: null
) ?? jobs.value[0]?.id
if (nextJobId) { if (nextJobId) {
await selectJob(nextJobId) await selectJob(nextJobId)
} else { } else {
@@ -346,7 +347,13 @@ defineExpose({ refresh })
> >
<div class="flex items-center justify-between gap-2"> <div class="flex items-center justify-between gap-2">
<span class="text-sm font-semibold"> <span class="text-sm font-semibold">
{{ job.output_mode === 'asset_retry' ? '资源重试' : '内容生成' }} {{
job.output_mode === 'asset_retry'
? '资源重试'
: job.output_mode === 'asset_generation'
? '资源生成'
: '内容生成'
}}
</span> </span>
<span class="rounded-full border px-2 py-0.5 text-xs" :class="statusClass(job.status)"> <span class="rounded-full border px-2 py-0.5 text-xs" :class="statusClass(job.status)">
{{ statusLabel(job.status) }} {{ statusLabel(job.status) }}
@@ -366,7 +373,13 @@ defineExpose({ refresh })
<div class="flex flex-wrap items-center justify-between gap-3"> <div class="flex flex-wrap items-center justify-between gap-3">
<div> <div>
<div class="text-sm font-semibold"> <div class="text-sm font-semibold">
{{ activeJob.output_mode === 'asset_retry' ? '资源重试事件' : '生成事件' }} {{
activeJob.output_mode === 'asset_retry'
? '资源重试事件'
: activeJob.output_mode === 'asset_generation'
? '资源生成事件'
: '生成事件'
}}
</div> </div>
<div class="mt-1 text-xs" :class="mutedClass"> <div class="mt-1 text-xs" :class="mutedClass">
当前步骤{{ eventLabel(activeJob.current_step) }} 当前步骤{{ eventLabel(activeJob.current_step) }}

View File

@@ -1,5 +1,6 @@
<script setup lang="ts"> <script setup lang="ts">
import { XMarkIcon, CommandLineIcon } from '@heroicons/vue/24/outline' import { XMarkIcon, CommandLineIcon } from '@heroicons/vue/24/outline'
import { buildAuthSigninUrl } from '../../utils/auth'
defineProps<{ defineProps<{
modelValue: boolean modelValue: boolean
@@ -13,18 +14,18 @@ function close() {
emit('update:modelValue', false) emit('update:modelValue', false)
} }
function loginWithGithub() { function loginWithGithub() {
window.location.href = '/auth/github/signin' window.location.href = buildAuthSigninUrl('github')
} }
function loginWithGoogle() { function loginWithGoogle() {
window.location.href = '/auth/google/signin' window.location.href = buildAuthSigninUrl('google')
} }
function loginWithDev() { function loginWithDev() {
window.location.href = '/auth/dev/signin' window.location.href = buildAuthSigninUrl('dev')
} }
</script> </script>
<template> <template>
<Teleport to="body"> <Teleport to="body">

View File

@@ -1,6 +1,7 @@
import { defineStore } from 'pinia' import { defineStore } from 'pinia'
import { ref } from 'vue' import { ref } from 'vue'
import { api } from '../api/client' import { api } from '../api/client'
import { buildAuthSigninUrl } from '../utils/auth'
interface User { interface User {
id: string id: string
@@ -25,13 +26,13 @@ export const useUserStore = defineStore('user', () => {
} }
} }
function loginWithGithub() { function loginWithGithub() {
window.location.href = '/auth/github/signin' window.location.href = buildAuthSigninUrl('github')
} }
function loginWithGoogle() { function loginWithGoogle() {
window.location.href = '/auth/google/signin' window.location.href = buildAuthSigninUrl('google')
} }
async function logout() { async function logout() {
await api.post('/auth/signout') await api.post('/auth/signout')

View File

@@ -0,0 +1,8 @@
type AuthProvider = 'github' | 'google' | 'dev'
const DEFAULT_POST_LOGIN_PATH = '/console/providers'
export function buildAuthSigninUrl(provider: AuthProvider): string {
const next = new URL(DEFAULT_POST_LOGIN_PATH, window.location.origin).toString()
return `/auth/${provider}/signin?next=${encodeURIComponent(next)}`
}

View File

@@ -1,8 +1,8 @@
import secrets import secrets
from urllib.parse import urlencode from urllib.parse import quote, unquote, urlencode, urlparse
import httpx import httpx
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query from fastapi import APIRouter, Cookie, Depends, HTTPException, Query, Response
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -26,6 +26,8 @@ GOOGLE_USER_URL = "https://www.googleapis.com/oauth2/v2/userinfo"
STATE_COOKIE = "oauth_state" STATE_COOKIE = "oauth_state"
STATE_MAX_AGE = 600 # 10 minutes STATE_MAX_AGE = 600 # 10 minutes
NEXT_COOKIE = "oauth_next"
NEXT_MAX_AGE = 600 # 10 minutes
def _set_state_cookie(response: RedirectResponse, provider: str, state: str) -> None: def _set_state_cookie(response: RedirectResponse, provider: str, state: str) -> None:
@@ -39,6 +41,53 @@ def _set_state_cookie(response: RedirectResponse, provider: str, state: str) ->
) )
def _is_allowed_frontend_redirect(url: str | None) -> bool:
if not url:
return False
parsed = urlparse(url)
if not parsed.scheme or not parsed.netloc:
return False
origin = f"{parsed.scheme}://{parsed.netloc}"
return origin in settings.cors_origins
def _set_next_cookie(response: RedirectResponse, next_url: str | None) -> None:
if not _is_allowed_frontend_redirect(next_url):
return
response.set_cookie(
key=NEXT_COOKIE,
value=quote(next_url or "", safe=""),
httponly=True,
secure=not settings.debug,
samesite="lax",
max_age=NEXT_MAX_AGE,
)
def _decode_next_cookie(next_cookie: str | None) -> str | None:
if not next_cookie:
return None
return unquote(next_cookie)
def _build_default_frontend_redirect(path: str = "/my-stories") -> str:
frontend_origin = settings.cors_origins[0] if settings.cors_origins else "http://localhost:5173"
return f"{frontend_origin.rstrip('/')}{path}"
def _resolve_frontend_redirect(
next_url: str | None,
*,
fallback_path: str = "/my-stories",
) -> str:
if _is_allowed_frontend_redirect(next_url):
return str(next_url)
return _build_default_frontend_redirect(fallback_path)
def _validate_state(state_from_query: str | None, state_cookie: str | None, provider: str): def _validate_state(state_from_query: str | None, state_cookie: str | None, provider: str):
if not state_from_query or not state_cookie: if not state_from_query or not state_cookie:
raise HTTPException(status_code=400, detail="Missing OAuth state") raise HTTPException(status_code=400, detail="Missing OAuth state")
@@ -51,7 +100,7 @@ def _validate_state(state_from_query: str | None, state_cookie: str | None, prov
@router.get("/github/signin") @router.get("/github/signin")
async def github_signin(): async def github_signin(next: str | None = Query(default=None)):
"""Start GitHub OAuth with state protection.""" """Start GitHub OAuth with state protection."""
state = secrets.token_urlsafe(16) state = secrets.token_urlsafe(16)
params = { params = {
@@ -63,6 +112,7 @@ async def github_signin():
url = f"{GITHUB_AUTHORIZE_URL}?{urlencode(params)}" url = f"{GITHUB_AUTHORIZE_URL}?{urlencode(params)}"
response = RedirectResponse(url=url) response = RedirectResponse(url=url)
_set_state_cookie(response, "github", state) _set_state_cookie(response, "github", state)
_set_next_cookie(response, next)
return response return response
@@ -71,6 +121,7 @@ async def github_callback(
code: str, code: str,
state: str | None = Query(default=None), state: str | None = Query(default=None),
state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE), state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE),
next_cookie: str | None = Cookie(default=None, alias=NEXT_COOKIE),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""Handle GitHub OAuth callback.""" """Handle GitHub OAuth callback."""
@@ -112,11 +163,12 @@ async def github_callback(
user_id=str(github_id), user_id=str(github_id),
name=user_data.get("name") or user_data.get("login") or "GitHub User", name=user_data.get("name") or user_data.get("login") or "GitHub User",
avatar_url=user_data.get("avatar_url"), avatar_url=user_data.get("avatar_url"),
next_url=_decode_next_cookie(next_cookie),
) )
@router.get("/google/signin") @router.get("/google/signin")
async def google_signin(): async def google_signin(next: str | None = Query(default=None)):
"""Start Google OAuth with state protection.""" """Start Google OAuth with state protection."""
state = secrets.token_urlsafe(16) state = secrets.token_urlsafe(16)
params = { params = {
@@ -129,6 +181,7 @@ async def google_signin():
url = f"{GOOGLE_AUTHORIZE_URL}?{urlencode(params)}" url = f"{GOOGLE_AUTHORIZE_URL}?{urlencode(params)}"
response = RedirectResponse(url=url) response = RedirectResponse(url=url)
_set_state_cookie(response, "google", state) _set_state_cookie(response, "google", state)
_set_next_cookie(response, next)
return response return response
@@ -137,6 +190,7 @@ async def google_callback(
code: str, code: str,
state: str | None = Query(default=None), state: str | None = Query(default=None),
state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE), state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE),
next_cookie: str | None = Cookie(default=None, alias=NEXT_COOKIE),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""Handle Google OAuth callback.""" """Handle Google OAuth callback."""
@@ -179,6 +233,7 @@ async def google_callback(
user_id=str(google_id), user_id=str(google_id),
name=user_data.get("name") or user_data.get("email") or "Google User", name=user_data.get("name") or user_data.get("email") or "Google User",
avatar_url=user_data.get("picture"), avatar_url=user_data.get("picture"),
next_url=_decode_next_cookie(next_cookie),
) )
@@ -188,6 +243,7 @@ async def _handle_oauth_user(
user_id: str, user_id: str,
name: str, name: str,
avatar_url: str | None, avatar_url: str | None,
next_url: str | None = None,
) -> RedirectResponse: ) -> RedirectResponse:
"""Create/update user and issue session cookie.""" """Create/update user and issue session cookie."""
full_id = f"{provider}:{user_id}" full_id = f"{provider}:{user_id}"
@@ -211,11 +267,10 @@ async def _handle_oauth_user(
token = create_access_token({"sub": user.id}) token = create_access_token({"sub": user.id})
frontend_url = "http://localhost:5173" response = RedirectResponse(
if settings.cors_origins and len(settings.cors_origins) > 0: url=_resolve_frontend_redirect(next_url, fallback_path="/my-stories"),
frontend_url = settings.cors_origins[0] status_code=302,
)
response = RedirectResponse(url=f"{frontend_url}/my-stories", status_code=302)
response.set_cookie( response.set_cookie(
key="access_token", key="access_token",
value=token, value=token,
@@ -225,15 +280,17 @@ async def _handle_oauth_user(
max_age=60 * 60 * 24 * 7, # align with ACCESS_TOKEN_EXPIRE_DAYS max_age=60 * 60 * 24 * 7, # align with ACCESS_TOKEN_EXPIRE_DAYS
) )
response.delete_cookie(STATE_COOKIE) response.delete_cookie(STATE_COOKIE)
response.delete_cookie(NEXT_COOKIE)
return response return response
@router.post("/signout") @router.post("/signout", status_code=204)
async def signout(): async def signout():
"""Sign out and clear cookies.""" """Sign out and clear cookies."""
response = RedirectResponse(url=settings.cors_origins[0], status_code=302) response = Response(status_code=204)
response.delete_cookie("access_token", samesite="lax", secure=not settings.debug) response.delete_cookie("access_token", samesite="lax", secure=not settings.debug)
response.delete_cookie(STATE_COOKIE, samesite="lax", secure=not settings.debug) response.delete_cookie(STATE_COOKIE, samesite="lax", secure=not settings.debug)
response.delete_cookie(NEXT_COOKIE, samesite="lax", secure=not settings.debug)
return response return response
@@ -253,7 +310,10 @@ async def get_session(user: User | None = Depends(get_current_user)):
@router.get("/dev/signin") @router.get("/dev/signin")
async def dev_signin(db: AsyncSession = Depends(get_db)): async def dev_signin(
next: str | None = Query(default=None),
db: AsyncSession = Depends(get_db),
):
"""Developer backdoor login. Only works in DEBUG mode.""" """Developer backdoor login. Only works in DEBUG mode."""
if not settings.debug: if not settings.debug:
raise HTTPException(status_code=403, detail="Developer login disabled") raise HTTPException(status_code=403, detail="Developer login disabled")
@@ -264,7 +324,8 @@ async def dev_signin(db: AsyncSession = Depends(get_db)):
provider="github", provider="github",
user_id="dev_user_001", user_id="dev_user_001",
name="Developer", name="Developer",
avatar_url="https://api.dicebear.com/7.x/avataaars/svg?seed=Developer" avatar_url="https://api.dicebear.com/7.x/avataaars/svg?seed=Developer",
next_url=next,
) )
except Exception as e: except Exception as e:
import traceback import traceback

View File

@@ -1,8 +1,9 @@
"""认证相关测试。""" """认证相关测试。"""
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from app.core.security import create_access_token, decode_access_token from app.core.config import settings
from app.core.security import create_access_token, decode_access_token
class TestJWT: class TestJWT:
@@ -55,10 +56,43 @@ class TestSession:
assert data["user"] is None assert data["user"] is None
class TestSignout: class TestSignout:
"""登出测试。""" """登出测试。"""
def test_signout(self, auth_client: TestClient): def test_signout(self, auth_client: TestClient):
"""测试登出。""" """测试登出。"""
response = auth_client.post("/auth/signout", follow_redirects=False) response = auth_client.post("/auth/signout")
assert response.status_code == 302 assert response.status_code == 204
assert response.content == b""
set_cookie_headers = response.headers.get_list("set-cookie")
assert any("access_token=" in value for value in set_cookie_headers)
class TestDevSigninRedirect:
"""开发登录重定向测试。"""
def test_dev_signin_uses_allowed_next_url(self, client: TestClient, monkeypatch):
"""允许的 next 参数应作为登录完成后的回跳地址。"""
monkeypatch.setattr(settings, "cors_origins", ["http://localhost:5173", "http://localhost:5174"])
response = client.get(
"/auth/dev/signin",
params={"next": "http://localhost:5174/console/providers"},
follow_redirects=False,
)
assert response.status_code == 302
assert response.headers["location"] == "http://localhost:5174/console/providers"
def test_dev_signin_rejects_untrusted_next_url(self, client: TestClient, monkeypatch):
"""不可信的 next 参数应回退到默认前端地址,避免开放重定向。"""
monkeypatch.setattr(settings, "cors_origins", ["http://localhost:5173", "http://localhost:5174"])
response = client.get(
"/auth/dev/signin",
params={"next": "https://evil.example/steal"},
follow_redirects=False,
)
assert response.status_code == 302
assert response.headers["location"] == "http://localhost:5173/my-stories"

View File

@@ -17,10 +17,19 @@ class ApiClient {
if (!response.ok) { if (!response.ok) {
const error = await response.json().catch(() => ({ detail: '请求失败' })) const error = await response.json().catch(() => ({ detail: '请求失败' }))
throw new Error(error.detail || '请求失败') throw new Error(error.detail || '请求失败')
} }
return response.json() if (response.status === 204 || response.status === 205) {
} return undefined as T
}
const contentType = response.headers.get('content-type') || ''
if (!contentType.includes('application/json')) {
return undefined as T
}
return response.json()
}
get<T>(url: string): Promise<T> { get<T>(url: string): Promise<T> {
return this.request<T>(url) return this.request<T>(url)

View File

@@ -145,12 +145,12 @@ function sleep(ms: number) {
async function waitForStoryId(jobId: string) { async function waitForStoryId(jobId: string) {
for (let attempt = 0; attempt < JOB_POLL_MAX_ATTEMPTS; attempt += 1) { for (let attempt = 0; attempt < JOB_POLL_MAX_ATTEMPTS; attempt += 1) {
const detail = await api.get<GenerationJobDetail>(`/api/generations/jobs/${jobId}`) const detail = await api.get<GenerationJobDetail>(`/api/generations/jobs/${jobId}`)
if (detail.status === 'canceled' || detail.current_step === 'generation_canceled') {
return null
}
if (detail.story_id) { if (detail.story_id) {
return detail.story_id return detail.story_id
} }
if (detail.status === 'canceled' || detail.current_step === 'generation_canceled') {
return null
}
if (detail.is_terminal) { if (detail.is_terminal) {
throw new Error(detail.error_message || '生成失败,请稍后重试') throw new Error(detail.error_message || '生成失败,请稍后重试')
} }

View File

@@ -37,11 +37,7 @@ const latestJob = computed(() => jobHistory.value[0] ?? null)
const activeJobEvents = computed(() => activeJob.value?.events.slice(-10) ?? []) const activeJobEvents = computed(() => activeJob.value?.events.slice(-10) ?? [])
const activeProgress = computed(() => activeJob.value?.progress_percent ?? latestJob.value?.progress_percent ?? 0) const activeProgress = computed(() => activeJob.value?.progress_percent ?? latestJob.value?.progress_percent ?? 0)
const activeProgressLabel = computed(() => activeJob.value?.progress_label ?? latestJob.value?.progress_label ?? '暂无进度') const activeProgressLabel = computed(() => activeJob.value?.progress_label ?? latestJob.value?.progress_label ?? '暂无进度')
const shouldAutoRefresh = computed(() => { const shouldAutoRefresh = computed(() => Boolean(latestJob.value && !latestJob.value.is_terminal))
if (activeJob.value) return !activeJob.value.is_terminal
if (latestJob.value) return !latestJob.value.is_terminal
return false
})
const providerSuccessRate = computed(() => { const providerSuccessRate = computed(() => {
if (!providerStats.value?.total_calls) return null if (!providerStats.value?.total_calls) return null
return Math.round((providerStats.value.successful_calls / providerStats.value.total_calls) * 100) return Math.round((providerStats.value.successful_calls / providerStats.value.total_calls) * 100)
@@ -186,6 +182,7 @@ async function refresh() {
} }
error.value = '' error.value = ''
const selectedJobId = activeJob.value?.id ?? null
try { try {
const [jobs, stats] = await Promise.all([ const [jobs, stats] = await Promise.all([
@@ -194,7 +191,11 @@ async function refresh() {
]) ])
jobHistory.value = jobs jobHistory.value = jobs
providerStats.value = stats providerStats.value = stats
const nextJobId = jobHistory.value[0]?.id const nextJobId = (
selectedJobId
? jobHistory.value.find((job) => job.id === selectedJobId)?.id
: null
) ?? jobHistory.value[0]?.id
if (nextJobId) { if (nextJobId) {
await selectGenerationJob(nextJobId) await selectGenerationJob(nextJobId)
} else { } else {

View File

@@ -1,5 +1,6 @@
<script setup lang="ts"> <script setup lang="ts">
import { XMarkIcon, CommandLineIcon } from '@heroicons/vue/24/outline' import { XMarkIcon, CommandLineIcon } from '@heroicons/vue/24/outline'
import { buildAuthSigninUrl } from '../../utils/auth'
defineProps<{ defineProps<{
modelValue: boolean modelValue: boolean
@@ -13,18 +14,18 @@ function close() {
emit('update:modelValue', false) emit('update:modelValue', false)
} }
function loginWithGithub() { function loginWithGithub() {
window.location.href = '/auth/github/signin' window.location.href = buildAuthSigninUrl('github')
} }
function loginWithGoogle() { function loginWithGoogle() {
window.location.href = '/auth/google/signin' window.location.href = buildAuthSigninUrl('google')
} }
function loginWithDev() { function loginWithDev() {
window.location.href = '/auth/dev/signin' window.location.href = buildAuthSigninUrl('dev')
} }
</script> </script>
<template> <template>
<Teleport to="body"> <Teleport to="body">

View File

@@ -1,6 +1,7 @@
import { defineStore } from 'pinia' import { defineStore } from 'pinia'
import { ref } from 'vue' import { ref } from 'vue'
import { api } from '../api/client' import { api } from '../api/client'
import { buildAuthSigninUrl } from '../utils/auth'
interface User { interface User {
id: string id: string
@@ -25,13 +26,13 @@ export const useUserStore = defineStore('user', () => {
} }
} }
function loginWithGithub() { function loginWithGithub() {
window.location.href = '/auth/github/signin' window.location.href = buildAuthSigninUrl('github')
} }
function loginWithGoogle() { function loginWithGoogle() {
window.location.href = '/auth/google/signin' window.location.href = buildAuthSigninUrl('google')
} }
async function logout() { async function logout() {
await api.post('/auth/signout') await api.post('/auth/signout')

View File

@@ -0,0 +1,8 @@
type AuthProvider = 'github' | 'google' | 'dev'
const DEFAULT_POST_LOGIN_PATH = '/my-stories'
export function buildAuthSigninUrl(provider: AuthProvider): string {
const next = new URL(DEFAULT_POST_LOGIN_PATH, window.location.origin).toString()
return `/auth/${provider}/signin?next=${encodeURIComponent(next)}`
}

View File

@@ -5,6 +5,7 @@ import { useI18n } from 'vue-i18n'
import { useUserStore } from '../stores/user' import { useUserStore } from '../stores/user'
import { api } from '../api/client' import { api } from '../api/client'
import type { VoiceSessionAnalytics, VoiceSessionSummary } from '../types/voiceSession' import type { VoiceSessionAnalytics, VoiceSessionSummary } from '../types/voiceSession'
import { getVoiceSessionNextAction } from '../utils/voiceSession'
import BaseButton from '../components/ui/BaseButton.vue' import BaseButton from '../components/ui/BaseButton.vue'
import LoginDialog from '../components/ui/LoginDialog.vue' import LoginDialog from '../components/ui/LoginDialog.vue'
import { import {
@@ -30,6 +31,9 @@ const showLoginDialog = ref(false)
const activeVoiceSession = ref<VoiceSessionSummary | null>(null) const activeVoiceSession = ref<VoiceSessionSummary | null>(null)
const voiceAnalytics = ref<VoiceSessionAnalytics | null>(null) const voiceAnalytics = ref<VoiceSessionAnalytics | null>(null)
type VoiceAttentionReason = 'pending_confirmation' | 'safety_intervention' | 'failed_turn'
type VoiceStudioFocusTarget = 'confirmation' | 'safety' | 'failed' | 'text'
// ========== 创作入口 ========== // ========== 创作入口 ==========
// 旧的创作变量已移除,现在只负责跳转 // 旧的创作变量已移除,现在只负责跳转
function openCreateModal() { function openCreateModal() {
@@ -54,7 +58,23 @@ function continueVoiceStudio() {
openVoiceStudio() openVoiceStudio()
return return
} }
router.push('/voice-studio') const action = getVoiceSessionNextAction(activeVoiceSession.value)
if (action.storyId) {
router.push(`/story/${action.storyId}`)
return
}
const query: Record<string, string> = {
session: activeVoiceSession.value.id,
}
if (action.reason) {
query.filter = 'attention'
query.reason = action.reason as VoiceAttentionReason
}
if (action.focus) {
query.focus = action.focus as VoiceStudioFocusTarget
}
router.push({ path: '/voice-studio', query })
} }
async function loadActiveVoiceSession() { async function loadActiveVoiceSession() {

View File

@@ -537,6 +537,10 @@ function resolveDisplayedSessions(sourceSessions: VoiceSessionSummary[]) {
return sortDisplayedSessions(visibleSessions) return sortDisplayedSessions(visibleSessions)
} }
function isSessionVisibleInCurrentFilter(sessionId: string) {
return resolveDisplayedSessions(sessions.value).some((session) => session.id === sessionId)
}
function parseSessionFilter(value: unknown): SessionFilter | null { function parseSessionFilter(value: unknown): SessionFilter | null {
if (value === 'active' || value === 'attention' || value === 'recent') { if (value === 'active' || value === 'attention' || value === 'recent') {
return value return value
@@ -743,10 +747,17 @@ async function loadSessions() {
const previousActiveSession = activeSession.value const previousActiveSession = activeSession.value
sessions.value = await api.get<VoiceSessionSummary[]>(buildVoiceSessionListPath()) sessions.value = await api.get<VoiceSessionSummary[]>(buildVoiceSessionListPath())
const displayedSessions = resolveDisplayedSessions(sessions.value) const displayedSessions = resolveDisplayedSessions(sessions.value)
const hiddenRequestedSession = requestedSessionId.value
? sessions.value.find((item) => item.id === requestedSessionId.value) ?? null
: null
const hiddenCurrentSession = previousActiveSession
? sessions.value.find((item) => item.id === previousActiveSession.id) ?? null
: null
if ( if (
(requestedSessionId.value || pendingFocusTarget.value) (requestedSessionId.value || pendingFocusTarget.value)
&& requestedSessionId.value && requestedSessionId.value
&& !displayedSessions.some((item) => item.id === requestedSessionId.value) && !displayedSessions.some((item) => item.id === requestedSessionId.value)
&& !hiddenRequestedSession
) { ) {
void syncVoiceStudioRouteState() void syncVoiceStudioRouteState()
} }
@@ -754,7 +765,10 @@ async function loadSessions() {
requestedSessionId.value requestedSessionId.value
? displayedSessions.find((item) => item.id === requestedSessionId.value) ? displayedSessions.find((item) => item.id === requestedSessionId.value)
: null : null
) ?? displayedSessions.find((item) => item.can_continue) ?? displayedSessions[0] ) ?? displayedSessions.find((item) => item.can_continue) ?? displayedSessions[0] ?? hiddenRequestedSession ?? hiddenCurrentSession
const currentSessionStillAvailable = activeSession.value
? sessions.value.some((item) => item.id === activeSession.value?.id)
: false
const currentSessionStillVisible = activeSession.value const currentSessionStillVisible = activeSession.value
? displayedSessions.some((item) => item.id === activeSession.value?.id) ? displayedSessions.some((item) => item.id === activeSession.value?.id)
: false : false
@@ -782,6 +796,9 @@ async function loadSessions() {
} }
await loadSessionDetail(preferredSession.id) await loadSessionDetail(preferredSession.id)
} else if (sessionFilter.value !== 'recent') { } else if (sessionFilter.value !== 'recent') {
if (currentSessionStillAvailable) {
return
}
if ( if (
sessionFilter.value === 'attention' sessionFilter.value === 'attention'
&& previousActiveSession && previousActiveSession
@@ -815,6 +832,16 @@ async function loadLatestActiveSession() {
try { try {
const session = await api.get<VoiceSessionSummary | null>('/api/voice-sessions/active') const session = await api.get<VoiceSessionSummary | null>('/api/voice-sessions/active')
if (session) { if (session) {
if (
!requestedSessionId.value
&& !route.query.filter
&& session.attention_reasons.length > 0
) {
const action = getVoiceSessionNextAction(session)
sessionFilter.value = 'attention'
attentionReasonFilter.value = action.reason ?? 'all'
pendingFocusTarget.value = action.focus ?? pendingFocusTarget.value
}
await loadSessionDetail(session.id) await loadSessionDetail(session.id)
} }
} catch { } catch {
@@ -845,13 +872,18 @@ function stopSessionPolling() {
function startSessionPolling() { function startSessionPolling() {
if (!activeSession.value?.id || sessionPollTimer) return if (!activeSession.value?.id || sessionPollTimer) return
sessionPollTimer = window.setInterval(() => { sessionPollTimer = window.setInterval(() => {
if (activeSession.value?.id) { const sessionId = activeSession.value?.id
void loadSessionDetail(activeSession.value.id) if (sessionId) {
void loadSessions() void refreshVisibleSessionState(sessionId)
} }
}, sessionPollIntervalMs) }, sessionPollIntervalMs)
} }
async function refreshVisibleSessionState(sessionId: string) {
await loadSessionDetail(sessionId)
await loadSessions()
}
async function createSession() { async function createSession() {
creatingSession.value = true creatingSession.value = true
error.value = '' error.value = ''
@@ -942,11 +974,12 @@ async function submitRecordedTurn() {
async function finalizeSession() { async function finalizeSession() {
if (!activeSession.value) return if (!activeSession.value) return
const sessionId = activeSession.value.id
finalizing.value = true finalizing.value = true
error.value = '' error.value = ''
try { try {
await api.post<VoiceSessionFinalizeResponse>( const result = await api.post<VoiceSessionFinalizeResponse>(
`/api/voice-sessions/${activeSession.value.id}/finalize`, `/api/voice-sessions/${sessionId}/finalize`,
{ {
save_story: true, save_story: true,
generate_cover: true, generate_cover: true,
@@ -954,7 +987,12 @@ async function finalizeSession() {
}, },
) )
await loadSessions() await loadSessions()
await loadSessionDetail(activeSession.value.id) if (isSessionVisibleInCurrentFilter(sessionId)) {
await loadSessionDetail(sessionId)
} else if (result.story_id) {
router.push(`/story/${result.story_id}`)
return
}
await loadVoiceAnalytics() await loadVoiceAnalytics()
} catch (err) { } catch (err) {
error.value = err instanceof Error ? err.message : '保存语音共创故事失败' error.value = err instanceof Error ? err.message : '保存语音共创故事失败'
@@ -1026,17 +1064,17 @@ async function resolveTurnConfirmation(turn: VoiceTurnSummary, action: 'accept'
async function abandonSession() { async function abandonSession() {
if (!activeSession.value) return if (!activeSession.value) return
const sessionId = activeSession.value.id
abandoning.value = true abandoning.value = true
error.value = '' error.value = ''
try { try {
const summary = await api.post<VoiceSessionSummary>( await api.post<VoiceSessionSummary>(
`/api/voice-sessions/${activeSession.value.id}/abandon`, `/api/voice-sessions/${sessionId}/abandon`,
{ reason: '用户在语音共创页主动结束会话' }, { reason: '用户在语音共创页主动结束会话' },
) )
await loadSessions() await loadSessions()
activeSession.value = { if (isSessionVisibleInCurrentFilter(sessionId)) {
...(activeSession.value as VoiceSessionDetail), await loadSessionDetail(sessionId)
...summary,
} }
} catch (err) { } catch (err) {
error.value = err instanceof Error ? err.message : '放弃会话失败' error.value = err instanceof Error ? err.message : '放弃会话失败'