fix: stabilize auth and generation workflows

This commit is contained in:
2026-04-23 22:31:14 +08:00
parent 4db04e61e9
commit 7e450aa5fc
16 changed files with 335 additions and 127 deletions

View File

@@ -1,8 +1,8 @@
import secrets
from urllib.parse import urlencode
from urllib.parse import quote, unquote, urlencode, urlparse
import httpx
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query
from fastapi import APIRouter, Cookie, Depends, HTTPException, Query, Response
from fastapi.responses import RedirectResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -26,6 +26,8 @@ GOOGLE_USER_URL = "https://www.googleapis.com/oauth2/v2/userinfo"
STATE_COOKIE = "oauth_state"
STATE_MAX_AGE = 600 # 10 minutes
NEXT_COOKIE = "oauth_next"
NEXT_MAX_AGE = 600 # 10 minutes
def _set_state_cookie(response: RedirectResponse, provider: str, state: str) -> None:
@@ -39,6 +41,53 @@ def _set_state_cookie(response: RedirectResponse, provider: str, state: str) ->
)
def _is_allowed_frontend_redirect(url: str | None) -> bool:
if not url:
return False
parsed = urlparse(url)
if not parsed.scheme or not parsed.netloc:
return False
origin = f"{parsed.scheme}://{parsed.netloc}"
return origin in settings.cors_origins
def _set_next_cookie(response: RedirectResponse, next_url: str | None) -> None:
if not _is_allowed_frontend_redirect(next_url):
return
response.set_cookie(
key=NEXT_COOKIE,
value=quote(next_url or "", safe=""),
httponly=True,
secure=not settings.debug,
samesite="lax",
max_age=NEXT_MAX_AGE,
)
def _decode_next_cookie(next_cookie: str | None) -> str | None:
if not next_cookie:
return None
return unquote(next_cookie)
def _build_default_frontend_redirect(path: str = "/my-stories") -> str:
frontend_origin = settings.cors_origins[0] if settings.cors_origins else "http://localhost:5173"
return f"{frontend_origin.rstrip('/')}{path}"
def _resolve_frontend_redirect(
next_url: str | None,
*,
fallback_path: str = "/my-stories",
) -> str:
if _is_allowed_frontend_redirect(next_url):
return str(next_url)
return _build_default_frontend_redirect(fallback_path)
def _validate_state(state_from_query: str | None, state_cookie: str | None, provider: str):
if not state_from_query or not state_cookie:
raise HTTPException(status_code=400, detail="Missing OAuth state")
@@ -51,7 +100,7 @@ def _validate_state(state_from_query: str | None, state_cookie: str | None, prov
@router.get("/github/signin")
async def github_signin():
async def github_signin(next: str | None = Query(default=None)):
"""Start GitHub OAuth with state protection."""
state = secrets.token_urlsafe(16)
params = {
@@ -63,6 +112,7 @@ async def github_signin():
url = f"{GITHUB_AUTHORIZE_URL}?{urlencode(params)}"
response = RedirectResponse(url=url)
_set_state_cookie(response, "github", state)
_set_next_cookie(response, next)
return response
@@ -71,6 +121,7 @@ async def github_callback(
code: str,
state: str | None = Query(default=None),
state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE),
next_cookie: str | None = Cookie(default=None, alias=NEXT_COOKIE),
db: AsyncSession = Depends(get_db),
):
"""Handle GitHub OAuth callback."""
@@ -112,11 +163,12 @@ async def github_callback(
user_id=str(github_id),
name=user_data.get("name") or user_data.get("login") or "GitHub User",
avatar_url=user_data.get("avatar_url"),
next_url=_decode_next_cookie(next_cookie),
)
@router.get("/google/signin")
async def google_signin():
async def google_signin(next: str | None = Query(default=None)):
"""Start Google OAuth with state protection."""
state = secrets.token_urlsafe(16)
params = {
@@ -129,6 +181,7 @@ async def google_signin():
url = f"{GOOGLE_AUTHORIZE_URL}?{urlencode(params)}"
response = RedirectResponse(url=url)
_set_state_cookie(response, "google", state)
_set_next_cookie(response, next)
return response
@@ -137,6 +190,7 @@ async def google_callback(
code: str,
state: str | None = Query(default=None),
state_cookie: str | None = Cookie(default=None, alias=STATE_COOKIE),
next_cookie: str | None = Cookie(default=None, alias=NEXT_COOKIE),
db: AsyncSession = Depends(get_db),
):
"""Handle Google OAuth callback."""
@@ -179,6 +233,7 @@ async def google_callback(
user_id=str(google_id),
name=user_data.get("name") or user_data.get("email") or "Google User",
avatar_url=user_data.get("picture"),
next_url=_decode_next_cookie(next_cookie),
)
@@ -188,6 +243,7 @@ async def _handle_oauth_user(
user_id: str,
name: str,
avatar_url: str | None,
next_url: str | None = None,
) -> RedirectResponse:
"""Create/update user and issue session cookie."""
full_id = f"{provider}:{user_id}"
@@ -211,11 +267,10 @@ async def _handle_oauth_user(
token = create_access_token({"sub": user.id})
frontend_url = "http://localhost:5173"
if settings.cors_origins and len(settings.cors_origins) > 0:
frontend_url = settings.cors_origins[0]
response = RedirectResponse(url=f"{frontend_url}/my-stories", status_code=302)
response = RedirectResponse(
url=_resolve_frontend_redirect(next_url, fallback_path="/my-stories"),
status_code=302,
)
response.set_cookie(
key="access_token",
value=token,
@@ -225,15 +280,17 @@ async def _handle_oauth_user(
max_age=60 * 60 * 24 * 7, # align with ACCESS_TOKEN_EXPIRE_DAYS
)
response.delete_cookie(STATE_COOKIE)
response.delete_cookie(NEXT_COOKIE)
return response
@router.post("/signout")
@router.post("/signout", status_code=204)
async def signout():
"""Sign out and clear cookies."""
response = RedirectResponse(url=settings.cors_origins[0], status_code=302)
response = Response(status_code=204)
response.delete_cookie("access_token", samesite="lax", secure=not settings.debug)
response.delete_cookie(STATE_COOKIE, samesite="lax", secure=not settings.debug)
response.delete_cookie(NEXT_COOKIE, samesite="lax", secure=not settings.debug)
return response
@@ -253,7 +310,10 @@ async def get_session(user: User | None = Depends(get_current_user)):
@router.get("/dev/signin")
async def dev_signin(db: AsyncSession = Depends(get_db)):
async def dev_signin(
next: str | None = Query(default=None),
db: AsyncSession = Depends(get_db),
):
"""Developer backdoor login. Only works in DEBUG mode."""
if not settings.debug:
raise HTTPException(status_code=403, detail="Developer login disabled")
@@ -264,7 +324,8 @@ async def dev_signin(db: AsyncSession = Depends(get_db)):
provider="github",
user_id="dev_user_001",
name="Developer",
avatar_url="https://api.dicebear.com/7.x/avataaars/svg?seed=Developer"
avatar_url="https://api.dicebear.com/7.x/avataaars/svg?seed=Developer",
next_url=next,
)
except Exception as e:
import traceback

View File

@@ -1,8 +1,9 @@
"""认证相关测试。"""
"""认证相关测试。"""
from fastapi.testclient import TestClient
from app.core.security import create_access_token, decode_access_token
from app.core.config import settings
from app.core.security import create_access_token, decode_access_token
class TestJWT:
@@ -55,10 +56,43 @@ class TestSession:
assert data["user"] is None
class TestSignout:
"""登出测试。"""
def test_signout(self, auth_client: TestClient):
"""测试登出。"""
response = auth_client.post("/auth/signout", follow_redirects=False)
assert response.status_code == 302
class TestSignout:
"""登出测试。"""
def test_signout(self, auth_client: TestClient):
"""测试登出。"""
response = auth_client.post("/auth/signout")
assert response.status_code == 204
assert response.content == b""
set_cookie_headers = response.headers.get_list("set-cookie")
assert any("access_token=" in value for value in set_cookie_headers)
class TestDevSigninRedirect:
"""开发登录重定向测试。"""
def test_dev_signin_uses_allowed_next_url(self, client: TestClient, monkeypatch):
"""允许的 next 参数应作为登录完成后的回跳地址。"""
monkeypatch.setattr(settings, "cors_origins", ["http://localhost:5173", "http://localhost:5174"])
response = client.get(
"/auth/dev/signin",
params={"next": "http://localhost:5174/console/providers"},
follow_redirects=False,
)
assert response.status_code == 302
assert response.headers["location"] == "http://localhost:5174/console/providers"
def test_dev_signin_rejects_untrusted_next_url(self, client: TestClient, monkeypatch):
"""不可信的 next 参数应回退到默认前端地址,避免开放重定向。"""
monkeypatch.setattr(settings, "cors_origins", ["http://localhost:5173", "http://localhost:5174"])
response = client.get(
"/auth/dev/signin",
params={"next": "https://evil.example/steal"},
follow_redirects=False,
)
assert response.status_code == 302
assert response.headers["location"] == "http://localhost:5173/my-stories"