Skip to content

FastAPI

目录


1. 简介

FastAPI是一个现代、快速(高性能)的Web框架,用于构建API,基于Python 3.6+的类型提示。它具有以下特点:

  • 快速:极高的性能,与NodeJS和Go相当
  • 自动API文档:自动生成交互式API文档(Swagger UI和ReDoc)
  • 类型提示:基于Python类型提示的请求验证和自动文档
  • 依赖注入:内置依赖注入系统
  • 安全:自动处理CORS、CSRF等安全特性
  • 标准化:基于OpenAPI标准

2. 安装

2.1 基本安装

使用pip安装FastAPI及其依赖:

bash
pip install fastapi uvicorn
  • fastapi:FastAPI框架本身
  • uvicorn:ASGI服务器,用于运行FastAPI应用

2.2 可选依赖

根据需要安装其他依赖:

bash
# 用于数据库集成
pip install sqlalchemy pymongo

# 用于认证
pip install python-jose[cryptography] passlib[bcrypt]

# 用于CORS
pip install python-multipart

3. 快速开始

3.1 创建第一个FastAPI应用

功能说明:创建一个简单的FastAPI应用,包含两个路由:根路径和带路径参数的路径。

代码示例

python
from fastapi import FastAPI

# 创建FastAPI应用实例
app = FastAPI()

# 定义路由
@app.get("/")
def read_root():
    return {"message": "Hello World"}

# 定义带路径参数的路由
@app.get("/items/{item_id}")
def read_item(item_id: int, q: str = None):
    return {"item_id": item_id, "q": q}

代码解析

  • from fastapi import FastAPI:导入FastAPI类
  • app = FastAPI():创建FastAPI应用实例
  • @app.get("/"):定义GET请求的根路径路由
  • def read_root():处理根路径请求的函数,返回一个字典
  • @app.get("/items/{item_id}"):定义带路径参数的路由
  • def read_item(item_id: int, q: str = None):处理带路径参数的请求,item_id是路径参数,q是可选的查询参数

使用方法

  1. 将代码保存为main.py文件
  2. 运行uvicorn main:app --reload启动应用
  3. 访问http://localhost:8000/查看根路径响应
  4. 访问http://localhost:8000/items/42查看带路径参数的响应
  5. 访问http://localhost:8000/items/42?q=test查看带查询参数的响应

3.2 运行应用

功能说明:使用Uvicorn ASGI服务器运行FastAPI应用。

运行命令

bash
uvicorn main:app --reload

参数说明

  • main:Python文件名(不含.py扩展名)
  • app:FastAPI应用实例名
  • --reload:开发模式,代码修改后自动重启

注意事项

  • 开发环境使用--reload参数,生产环境不要使用
  • 生产环境应该指定主机和端口,例如:uvicorn main:app --host 0.0.0.0 --port 8000 --workers 4

3.3 访问API文档

功能说明:FastAPI自动生成交互式API文档。

访问地址

使用方法

  1. 启动应用后,打开浏览器访问上述地址
  2. 在Swagger UI中,可以直接测试API端点
  3. 在ReDoc中,可以查看API的详细文档

4. 路由

4.1 基本路由

功能说明:定义不同HTTP方法的路由,包括GET、POST、PUT和DELETE。

代码示例

python
from fastapi import FastAPI

app = FastAPI()

# GET请求
@app.get("/items")
def get_items():
    return {"items": ["item1", "item2"]}

# POST请求
@app.post("/items")
def create_item(item: dict):
    return {"item": item}

# PUT请求
@app.put("/items/{item_id}")
def update_item(item_id: int, item: dict):
    return {"item_id": item_id, "item": item}

# DELETE请求
@app.delete("/items/{item_id}")
def delete_item(item_id: int):
    return {"message": f"Item {item_id} deleted"}

代码解析

  • @app.get("/items"):定义GET请求的路由,用于获取资源
  • @app.post("/items"):定义POST请求的路由,用于创建资源
  • @app.put("/items/{item_id}"):定义PUT请求的路由,用于更新资源
  • @app.delete("/items/{item_id}"):定义DELETE请求的路由,用于删除资源
  • item: dict:请求体参数,类型为字典
  • item_id: int:路径参数,类型为整数

使用方法

  1. GET请求:curl http://localhost:8000/items
  2. POST请求:curl -X POST http://localhost:8000/items -H "Content-Type: application/json" -d '{"name": "item1", "price": 10.0}'
  3. PUT请求:curl -X PUT http://localhost:8000/items/1 -H "Content-Type: application/json" -d '{"name": "item1", "price": 20.0}'
  4. DELETE请求:curl -X DELETE http://localhost:8000/items/1

4.2 路径参数

功能说明:定义带路径参数的路由,并对路径参数进行类型转换和验证。

代码示例

python
from fastapi import FastAPI

app = FastAPI()

# 路径参数类型转换
@app.get("/items/{item_id}")
def read_item(item_id: int):
    return {"item_id": item_id}

# 路径参数验证
from fastapi import Path

@app.get("/items/{item_id}")
def read_item(
    item_id: int = Path(..., title="Item ID", ge=1, le=1000)
):
    return {"item_id": item_id}

代码解析

  • item_id: int:路径参数,类型为整数,FastAPI会自动进行类型转换
  • Path(..., title="Item ID", ge=1, le=1000):使用Path类对路径参数进行验证,ge表示大于等于,le表示小于等于
  • ...:表示该参数是必填的

使用方法

  1. 访问http://localhost:8000/items/42,返回{"item_id": 42}
  2. 访问http://localhost:8000/items/0,会返回422错误,因为item_id必须大于等于1
  3. 访问http://localhost:8000/items/1001,会返回422错误,因为item_id必须小于等于1000

4.3 查询参数

功能说明:定义带查询参数的路由,并对查询参数进行验证。

代码示例

python
from fastapi import FastAPI, Query

app = FastAPI()

# 可选查询参数
@app.get("/items")
def read_items(q: str = None):
    return {"q": q}

# 查询参数验证
@app.get("/items")
def read_items(
    q: str = Query(None, min_length=3, max_length=50, regex="^[a-z]+$"),
    skip: int = Query(0, ge=0),
    limit: int = Query(100, le=100)
):
    return {"q": q, "skip": skip, "limit": limit}

代码解析

  • q: str = None:可选查询参数,默认值为None
  • Query(None, min_length=3, max_length=50, regex="^[a-z]+$"):使用Query类对查询参数进行验证,min_length表示最小长度,max_length表示最大长度,regex表示正则表达式
  • skip: int = Query(0, ge=0):查询参数,默认值为0,必须大于等于0
  • limit: int = Query(100, le=100):查询参数,默认值为100,必须小于等于100

使用方法

  1. 访问http://localhost:8000/items,返回{"q": null, "skip": 0, "limit": 100}
  2. 访问http://localhost:8000/items?q=test&skip=10&limit=50,返回{"q": "test", "skip": 10, "limit": 50}
  3. 访问http://localhost:8000/items?q=te,会返回422错误,因为q的长度必须大于等于3
  4. 访问http://localhost:8000/items?q=TEST,会返回422错误,因为q必须匹配正则表达式^[a-z]+$

5. 请求体

5.1 使用Pydantic模型

功能说明:使用Pydantic模型定义请求体和响应模型,实现数据验证和类型检查。

代码示例

python
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

# 定义请求体模型
class Item(BaseModel):
    name: str
    description: str = None
    price: float
    tax: float = None

# 使用请求体
@app.post("/items")
def create_item(item: Item):
    return item

# 响应模型
@app.post("/items", response_model=Item)
def create_item(item: Item):
    return item

代码解析

  • from pydantic import BaseModel:导入Pydantic的BaseModel类
  • class Item(BaseModel):定义请求体模型,继承自BaseModel
  • name: str:必填字段,类型为字符串
  • description: str = None:可选字段,类型为字符串,默认值为None
  • price: float:必填字段,类型为浮点数
  • tax: float = None:可选字段,类型为浮点数,默认值为None
  • item: Item:请求体参数,类型为Item模型
  • response_model=Item:指定响应模型为Item,FastAPI会自动过滤响应中不在模型中的字段

使用方法

  1. 发送POST请求到http://localhost:8000/items,请求体为:
    json
    {
      "name": "Item 1",
      "description": "This is item 1",
      "price": 10.0,
      "tax": 1.0
    }
  2. 响应为:
    json
    {
      "name": "Item 1",
      "description": "This is item 1",
      "price": 10.0,
      "tax": 1.0
    }
  3. 如果发送的请求体缺少必填字段或类型错误,会返回422错误

5.2 嵌套模型

功能说明:使用嵌套的Pydantic模型定义复杂的请求体结构。

代码示例

python
from pydantic import BaseModel

class User(BaseModel):
    username: str
    full_name: str = None

class Item(BaseModel):
    name: str
    description: str = None
    price: float
    tax: float = None
    user: User  # 嵌套模型

@app.post("/items")
def create_item(item: Item):
    return item

代码解析

  • class User(BaseModel):定义User模型
  • class Item(BaseModel):定义Item模型,包含一个User类型的字段
  • user: User:嵌套模型字段,类型为User

使用方法

  1. 发送POST请求到http://localhost:8000/items,请求体为:
    json
    {
      "name": "Item 1",
      "description": "This is item 1",
      "price": 10.0,
      "tax": 1.0,
      "user": {
        "username": "alice",
        "full_name": "Alice Smith"
      }
    }
  2. 响应为:
    json
    {
      "name": "Item 1",
      "description": "This is item 1",
      "price": 10.0,
      "tax": 1.0,
      "user": {
        "username": "alice",
        "full_name": "Alice Smith"
      }
    }
  3. 如果嵌套模型中的字段不符合要求,会返回422错误

6. 依赖注入

6.1 基本依赖

功能说明:使用依赖注入系统管理共享资源,如数据库连接。

代码示例

python
from fastapi import FastAPI, Depends

app = FastAPI()

# 依赖函数
def get_db():
    db = "Database connection"
    try:
        yield db
    finally:
        print("Closing database connection")

# 使用依赖
@app.get("/items")
def get_items(db: str = Depends(get_db)):
    return {"db": db, "items": ["item1", "item2"]}

代码解析

  • from fastapi import Depends:导入Depends函数
  • def get_db():定义依赖函数,使用yield语句返回资源
  • yield db:返回数据库连接
  • finally:无论是否发生异常,都会执行的代码块,用于关闭数据库连接
  • db: str = Depends(get_db):使用Depends函数注入依赖,参数类型为str

使用方法

  1. 访问http://localhost:8000/items,会调用get_db()函数获取数据库连接
  2. 响应为:{"db": "Database connection", "items": ["item1", "item2"]}
  3. 访问完成后,会执行finally块中的代码,打印"Closing database connection"

6.2 类依赖

功能说明:使用类作为依赖,提供更复杂的依赖管理。

代码示例

python
from fastapi import FastAPI, Depends

app = FastAPI()

class Database:
    def __init__(self):
        self.connection = "Database connection"

    def close(self):
        print("Closing database connection")

def get_db():
    db = Database()
    try:
        yield db
    finally:
        db.close()

@app.get("/items")
def get_items(db: Database = Depends(get_db)):
    return {"db_connection": db.connection, "items": ["item1", "item2"]}

代码解析

  • class Database:定义数据库类,包含connection属性和close方法
  • def __init__(self):初始化方法,设置connection属性
  • def close(self):关闭数据库连接的方法
  • def get_db():依赖函数,创建Database实例并返回
  • db: Database = Depends(get_db):使用Depends函数注入依赖,参数类型为Database

使用方法

  1. 访问http://localhost:8000/items,会调用get_db()函数获取Database实例
  2. 响应为:{"db_connection": "Database connection", "items": ["item1", "item2"]}
  3. 访问完成后,会执行finally块中的代码,调用db.close()方法,打印"Closing database connection"

7. 认证和授权

7.1 JWT认证

功能说明:使用JWT(JSON Web Token)实现认证和授权。

代码示例

python
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from datetime import datetime, timedelta

app = FastAPI()

# 配置
SECRET_KEY = "your-secret-key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# 密码加密
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# OAuth2密码承载令牌
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# 模拟用户数据库
users_db = {
    "alice": {
        "username": "alice",
        "hashed_password": pwd_context.hash("secret"),
        "disabled": False,
    }
}

# 验证密码
def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)

# 获取用户
def get_user(db, username: str):
    if username in db:
        user_dict = db[username]
        return user_dict

# 认证用户
def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    if not verify_password(password, user["hashed_password"]):
        return False
    return user

# 创建访问令牌
def create_access_token(data: dict, expires_delta: timedelta = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

# 获取当前用户
async def get_current_user(token: str = Depends(oauth2_scheme)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    user = get_user(users_db, username=username)
    if user is None:
        raise credentials_exception
    return user

# 登录端点
@app.post("/token")
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
    user = authenticate_user(users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user["username"]}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}

# 需要认证的端点
@app.get("/users/me")
async def read_users_me(current_user: dict = Depends(get_current_user)):
    return current_user

代码解析

  • SECRET_KEY:用于签名JWT令牌的密钥
  • ALGORITHM:使用的JWT算法
  • ACCESS_TOKEN_EXPIRE_MINUTES:访问令牌的过期时间
  • pwd_context:用于密码加密和验证的上下文
  • oauth2_scheme:OAuth2密码承载令牌方案
  • users_db:模拟的用户数据库
  • verify_password:验证密码的函数
  • get_user:获取用户的函数
  • authenticate_user:认证用户的函数
  • create_access_token:创建访问令牌的函数
  • get_current_user:获取当前用户的依赖函数
  • @app.post("/token"):登录端点,返回访问令牌
  • @app.get("/users/me"):需要认证的端点,返回当前用户信息

使用方法

  1. 发送POST请求到http://localhost:8000/token,表单数据为:
    • username: alice
    • password: secret
  2. 响应为:
    json
    {
      "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...",
      "token_type": "bearer"
    }
  3. 发送GET请求到http://localhost:8000/users/me,头部为:
    • Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...
  4. 响应为:
    json
    {
      "username": "alice",
      "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW",
      "disabled": false
    }
  5. 如果未提供令牌或令牌无效,会返回401错误

8. 中间件

8.1 自定义中间件

功能说明:创建自定义中间件,用于处理请求和响应。

代码示例

python
from fastapi import FastAPI, Request
import time

app = FastAPI()

# 自定义中间件
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    start_time = time.time()
    response = await call_next(request)
    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    return response

@app.get("/")
def read_root():
    return {"message": "Hello World"}

代码解析

  • @app.middleware("http"):定义HTTP中间件
  • async def add_process_time_header(request: Request, call_next):中间件函数,接收请求和下一个处理函数
  • start_time = time.time():记录请求开始时间
  • response = await call_next(request):调用下一个处理函数,获取响应
  • process_time = time.time() - start_time:计算处理时间
  • response.headers["X-Process-Time"] = str(process_time):添加处理时间到响应头部
  • return response:返回响应

使用方法

  1. 启动应用
  2. 访问http://localhost:8000/
  3. 查看响应头部,会包含X-Process-Time字段,表示请求处理时间

8.2 CORS中间件

功能说明:配置CORS(跨域资源共享)中间件,允许跨域请求。

代码示例

python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# 配置CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # 在生产环境中应该设置具体的域名
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
def read_root():
    return {"message": "Hello World"}

代码解析

  • from fastapi.middleware.cors import CORSMiddleware:导入CORS中间件
  • app.add_middleware(CORSMiddleware, ...):添加CORS中间件
  • allow_origins=["*"]:允许所有来源,生产环境应该设置具体的域名
  • allow_credentials=True:允许携带凭证
  • allow_methods=["*"]:允许所有HTTP方法
  • allow_headers=["*"]:允许所有HTTP头部

使用方法

  1. 启动应用
  2. 从不同域的前端应用发送请求到FastAPI应用
  3. 不会出现CORS错误,因为已经配置了CORS中间件

注意事项

  • 生产环境中,应该设置具体的allow_origins,而不是使用["*"]
  • 只有当allow_credentialsTrue时,allow_origins不能使用["*"],必须设置具体的域名

9. 数据库集成

9.1 SQLAlchemy集成

功能说明:使用SQLAlchemy ORM集成关系型数据库。

代码示例

python
from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session

# 数据库配置
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

# 创建数据库引擎
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)

# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# 创建基类
Base = declarative_base()

# 定义模型
class Item(Base):
    __tablename__ = "items"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, index=True)
    description = Column(String)

# 创建数据库表
Base.metadata.create_all(bind=engine)

# 依赖项
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

app = FastAPI()

# 路由
@app.post("/items")
def create_item(name: str, description: str, db: Session = Depends(get_db)):
    db_item = Item(name=name, description=description)
    db.add(db_item)
    db.commit()
    db.refresh(db_item)
    return db_item

@app.get("/items")
def get_items(db: Session = Depends(get_db)):
    return db.query(Item).all()

@app.get("/items/{item_id}")
def get_item(item_id: int, db: Session = Depends(get_db)):
    item = db.query(Item).filter(Item.id == item_id).first()
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")
    return item

代码解析

  • SQLALCHEMY_DATABASE_URL:数据库连接字符串
  • engine:数据库引擎
  • SessionLocal:会话工厂
  • Base:模型基类
  • class Item(Base):定义Item模型
  • Base.metadata.create_all(bind=engine):创建数据库表
  • def get_db():获取数据库会话的依赖函数
  • @app.post("/items"):创建Item的端点
  • @app.get("/items"):获取所有Item的端点
  • @app.get("/items/{item_id}"):获取单个Item的端点

使用方法

  1. 启动应用,会自动创建数据库表
  2. 发送POST请求到http://localhost:8000/items,创建Item
  3. 发送GET请求到http://localhost:8000/items,获取所有Item
  4. 发送GET请求到http://localhost:8000/items/{item_id},获取单个Item

9.2 MongoDB集成

功能说明:使用PyMongo集成MongoDB NoSQL数据库。

代码示例

python
from fastapi import FastAPI, HTTPException
from pymongo import MongoClient
from pydantic import BaseModel
from bson import ObjectId

# 连接MongoDB
client = MongoClient("mongodb://localhost:27017")
db = client["test_db"]
items_collection = db["items"]

app = FastAPI()

# 模型
class Item(BaseModel):
    name: str
    description: str = None
    price: float

# 处理ObjectId
class PyObjectId(ObjectId):
    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        if not ObjectId.is_valid(v):
            raise ValueError("Invalid objectid")
        return ObjectId(v)

# 路由
@app.post("/items")
def create_item(item: Item):
    item_dict = item.dict()
    result = items_collection.insert_one(item_dict)
    return {"id": str(result.inserted_id), **item_dict}

@app.get("/items")
def get_items():
    items = []
    for item in items_collection.find():
        item["_id"] = str(item["_id"])
        items.append(item)
    return items

@app.get("/items/{item_id}")
def get_item(item_id: str):
    item = items_collection.find_one({"_id": ObjectId(item_id)})
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")
    item["_id"] = str(item["_id"])
    return item

代码解析

  • client = MongoClient("mongodb://localhost:27017"):连接MongoDB
  • db = client["test_db"]:获取数据库
  • items_collection = db["items"]:获取集合
  • class Item(BaseModel):定义Item模型
  • class PyObjectId(ObjectId):处理ObjectId的验证
  • @app.post("/items"):创建Item的端点
  • @app.get("/items"):获取所有Item的端点
  • @app.get("/items/{item_id}"):获取单个Item的端点

使用方法

  1. 确保MongoDB服务正在运行
  2. 启动应用
  3. 发送POST请求到http://localhost:8000/items,创建Item
  4. 发送GET请求到http://localhost:8000/items,获取所有Item
  5. 发送GET请求到http://localhost:8000/items/{item_id},获取单个Item

10. 部署

10.1 使用Uvicorn

功能说明:使用Uvicorn ASGI服务器运行FastAPI应用。

运行命令

bash
# 开发模式
uvicorn main:app --reload

# 生产模式
uvicorn main:app --host 0.0.0.0 --port 8000 --workers 4

参数说明

  • main:Python文件名(不含.py扩展名)
  • app:FastAPI应用实例名
  • --reload:开发模式,代码修改后自动重启
  • --host 0.0.0.0:绑定到所有网络接口
  • --port 8000:监听端口
  • --workers 4:启动4个工作进程

使用方法

  1. 开发环境:运行uvicorn main:app --reload
  2. 生产环境:运行uvicorn main:app --host 0.0.0.0 --port 8000 --workers 4

10.2 使用Gunicorn

功能说明:使用Gunicorn作为进程管理器,运行Uvicorn工作进程。

安装和运行

bash
# 安装Gunicorn
pip install gunicorn

# 运行
Gunicorn -w 4 -k uvicorn.workers.UvicornWorker main:app

参数说明

  • -w 4:启动4个工作进程
  • -k uvicorn.workers.UvicornWorker:使用Uvicorn工作进程
  • main:app:Python文件名和FastAPI应用实例名

使用方法

  1. 安装Gunicorn:pip install gunicorn
  2. 运行:Gunicorn -w 4 -k uvicorn.workers.UvicornWorker main:app

10.3 Docker部署

功能说明:使用Docker容器部署FastAPI应用。

Dockerfile

dockerfile
FROM python:3.9

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

requirements.txt

fastapi
uvicorn

构建和运行

bash
# 构建镜像
docker build -t fastapi-app .

# 运行容器
docker run -d --name fastapi-app -p 8000:8000 fastapi-app

代码解析

  • FROM python:3.9:使用Python 3.9作为基础镜像
  • WORKDIR /app:设置工作目录
  • COPY requirements.txt .:复制requirements.txt文件
  • RUN pip install --no-cache-dir -r requirements.txt:安装依赖
  • COPY . .:复制应用代码
  • CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]:运行应用

使用方法

  1. 创建Dockerfile和requirements.txt文件
  2. 构建镜像:docker build -t fastapi-app .
  3. 运行容器:docker run -d --name fastapi-app -p 8000:8000 fastapi-app
  4. 访问http://localhost:8000/查看应用

11. 企业级开发

11.1 配置管理

功能说明:使用Pydantic V2的BaseSettings或pydantic-settings管理应用配置,支持环境变量、配置文件等多种配置来源。

代码示例

python
from pydantic_settings import BaseSettings
from typing import Optional

class Settings(BaseSettings):
    """应用配置"""
    # 基本配置
    app_name: str = "FastAPI App"
    debug: bool = False

    # 数据库配置
    db_url: str
    db_max_connections: int = 10
    db_min_connections: int = 2

    # 认证配置
    secret_key: str
    algorithm: str = "HS256"
    access_token_expire_minutes: int = 30

    # CORS配置
    backend_cors_origins: list[str] = ["*"]

    class Config:
        env_file = ".env"
        case_sensitive = False

# 创建配置实例
settings = Settings()

# 在应用中使用
from fastapi import FastAPI

app = FastAPI(title=settings.app_name, debug=settings.debug)

@app.get("/config")
def get_config():
    return {
        "app_name": settings.app_name,
        "debug": settings.debug,
        "db_url": settings.db_url
    }

代码解析

  • from pydantic_settings import BaseSettings:从pydantic-settings导入BaseSettings类
  • class Settings(BaseSettings):定义配置类,继承自BaseSettings
  • app_name: str = "FastAPI App":配置项,有默认值
  • db_url: str:必填配置项,无默认值
  • class Config:配置类的配置,指定环境变量文件和大小写敏感性
  • settings = Settings():创建配置实例,会自动从环境变量或.env文件加载配置

使用方法

  1. 安装依赖:pip install pydantic-settings
  2. 创建.env文件,添加配置:
    DB_URL="postgresql://user:password@localhost:5432/dbname"
    SECRET_KEY="your-secret-key"
  3. 运行应用,访问http://localhost:8000/config查看配置

11.2 日志系统

功能说明:配置结构化日志系统,支持日志级别控制、日志轮转、分布式追踪等。

代码示例

python
import logging
import structlog
from fastapi import FastAPI, Request
import time

# 配置结构化日志
structlog.configure(
    processors=[
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.JSONRenderer()
    ],
    context_class=dict,
    logger_factory=structlog.stdlib.LoggerFactory(),
    cache_logger_on_first_use=True,
)

logger = structlog.get_logger()

app = FastAPI()

# 日志中间件
@app.middleware("http")
async def log_requests(request: Request, call_next):
    start_time = time.time()

    # 记录请求开始
    logger.info("Request started",
               method=request.method,
               url=request.url.path,
               query=dict(request.query_params))

    response = await call_next(request)

    # 记录请求结束
    process_time = time.time() - start_time
    logger.info("Request completed",
               method=request.method,
               url=request.url.path,
               status_code=response.status_code,
               process_time=process_time)

    return response

@app.get("/")
def read_root():
    logger.debug("Debug message")
    logger.info("Info message")
    logger.warning("Warning message")
    return {"message": "Hello World"}

代码解析

  • import structlog:导入结构化日志库
  • structlog.configure():配置结构化日志处理器
  • logger = structlog.get_logger():获取日志记录器
  • @app.middleware("http"):定义日志中间件,记录请求和响应信息
  • logger.info():记录信息级别的日志
  • logger.debug()logger.warning():记录不同级别的日志

使用方法

  1. 安装依赖:pip install structlog
  2. 运行应用,查看控制台输出的结构化日志
  3. 可以根据需要配置日志文件、日志轮转等

11.3 API分页

功能说明:实现API分页功能,支持偏移分页和游标分页,提供统一的分页响应格式。

代码示例

python
from fastapi import FastAPI, Query
from pydantic import BaseModel
from typing import List, Optional

app = FastAPI()

# 模拟数据
items = [f"item_{i}" for i in range(100)]

class PaginationParams(BaseModel):
    """分页参数"""
    page: int = Query(1, ge=1, description="页码")
    page_size: int = Query(10, ge=1, le=100, description="每页大小")

class PaginatedResponse(BaseModel):
    """分页响应"""
    items: List[str]
    total: int
    page: int
    page_size: int
    total_pages: int

@app.get("/items", response_model=PaginatedResponse)
def get_items(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(10, ge=1, le=100, description="每页大小")
):
    """获取分页数据"""
    total = len(items)
    start = (page - 1) * page_size
    end = start + page_size

    # 计算总页数
    total_pages = (total + page_size - 1) // page_size

    # 获取当前页数据
    page_items = items[start:end]

    return PaginatedResponse(
        items=page_items,
        total=total,
        page=page,
        page_size=page_size,
        total_pages=total_pages
    )

代码解析

  • class PaginationParams:定义分页参数模型
  • class PaginatedResponse:定义分页响应模型
  • @app.get("/items"):定义获取分页数据的端点
  • start = (page - 1) * page_size:计算起始索引
  • end = start + page_size:计算结束索引
  • total_pages = (total + page_size - 1) // page_size:计算总页数
  • items[start:end]:获取当前页数据

使用方法

  1. 访问http://localhost:8000/items获取第一页数据
  2. 访问http://localhost:8000/items?page=2&page_size=20获取第二页,每页20条数据
  3. 响应会包含items、total、page、page_size、total_pages等字段

11.4 缓存策略

功能说明:使用Redis实现缓存策略,支持缓存穿透、缓存击穿、缓存雪崩防护。

代码示例

python
import redis
from fastapi import FastAPI, Depends
from typing import Optional
import json
import time

app = FastAPI()

# 连接Redis
redis_client = redis.Redis(host="localhost", port=6379, db=0)

# 缓存装饰器
def cache(key_pattern: str, expire: int = 3600):
    """缓存装饰器"""
    def decorator(func):
        async def wrapper(*args, **kwargs):
            # 生成缓存键
            key = key_pattern.format(**kwargs)

            # 尝试从缓存获取
            cached_data = redis_client.get(key)
            if cached_data:
                return json.loads(cached_data)

            # 缓存未命中,执行函数
            result = await func(*args, **kwargs)

            # 存入缓存
            redis_client.setex(key, expire, json.dumps(result))

            return result
        return wrapper
    return decorator

# 模拟数据库查询
def get_item_from_db(item_id: int):
    """从数据库获取数据"""
    print(f"Fetching item {item_id} from database")
    time.sleep(0.5)  # 模拟数据库查询延迟
    return {"id": item_id, "name": f"Item {item_id}"}

@app.get("/items/{item_id}")
async def get_item(item_id: int):
    """获取物品详情"""
    @cache(key_pattern="item:{item_id}", expire=60)
    async def fetch_item(item_id: int):
        return get_item_from_db(item_id)

    return await fetch_item(item_id)

代码解析

  • import redis:导入Redis客户端
  • redis_client = redis.Redis():创建Redis连接
  • def cache(key_pattern: str, expire: int = 3600):定义缓存装饰器
  • key = key_pattern.format(**kwargs):生成缓存键
  • cached_data = redis_client.get(key):从缓存获取数据
  • redis_client.setex(key, expire, json.dumps(result)):存入缓存并设置过期时间
  • @cache(key_pattern="item:{item_id}", expire=60):应用缓存装饰器

使用方法

  1. 安装依赖:pip install redis
  2. 启动Redis服务器
  3. 访问http://localhost:8000/items/1,第一次会从数据库获取数据
  4. 再次访问http://localhost:8000/items/1,会从缓存获取数据

11.5 异步任务队列

功能说明:使用Celery实现异步任务队列,处理耗时操作,如发送邮件、生成报表等。

代码示例

python
from fastapi import FastAPI, BackgroundTasks
from celery import Celery
import time

app = FastAPI()

# 配置Celery
celery_app = Celery(
    "tasks",
    broker="redis://localhost:6379/0",
    backend="redis://localhost:6379/0"
)

# 定义任务
@celery_app.task
def send_email(email: str, message: str):
    """发送邮件任务"""
    print(f"Sending email to {email}")
    time.sleep(5)  # 模拟发送邮件的延迟
    print(f"Email sent to {email} with message: {message}")
    return f"Email sent to {email}"

@celery_app.task
def generate_report(user_id: int, report_type: str):
    """生成报表任务"""
    print(f"Generating {report_type} report for user {user_id}")
    time.sleep(10)  # 模拟生成报表的延迟
    print(f"Report generated for user {user_id}")
    return f"Report generated: {report_type}"

@app.post("/send-email")
def send_email_endpoint(email: str, message: str):
    """发送邮件端点"""
    # 异步执行任务
    task = send_email.delay(email, message)
    return {"task_id": task.id, "status": "Task queued"}

@app.post("/generate-report")
def generate_report_endpoint(user_id: int, report_type: str):
    """生成报表端点"""
    # 异步执行任务
    task = generate_report.delay(user_id, report_type)
    return {"task_id": task.id, "status": "Task queued"}

@app.get("/task/{task_id}")
def get_task_status(task_id: str):
    """获取任务状态"""
    task = celery_app.AsyncResult(task_id)
    return {
        "task_id": task_id,
        "status": task.status,
        "result": task.result
    }

代码解析

  • from celery import Celery:导入Celery
  • celery_app = Celery():创建Celery实例
  • @celery_app.task:定义Celery任务
  • send_email.delay(email, message):异步执行任务
  • celery_app.AsyncResult(task_id):获取任务结果

使用方法

  1. 安装依赖:pip install celery redis
  2. 启动Redis服务器
  3. 启动Celery worker:celery -A main.celery_app worker --loglevel=info
  4. 发送POST请求到/send-email/generate-report
  5. 使用返回的task_id访问/task/{task_id}查看任务状态

11.6 数据库连接池

功能说明:配置数据库连接池,优化数据库连接管理,提高性能。

代码示例

python
from fastapi import FastAPI, Depends
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session

app = FastAPI()

# 配置数据库连接池
DATABASE_URL = "postgresql://user:password@localhost:5432/dbname"

engine = create_engine(
    DATABASE_URL,
    pool_size=10,           # 连接池大小
    max_overflow=20,        # 最大溢出连接数
    pool_pre_ping=True,     # 连接前检查
    pool_recycle=3600,      # 连接回收时间
    echo=False              # 不打印SQL语句
)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# 定义模型
class Item(Base):
    __tablename__ = "items"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, index=True)
    description = Column(String)

# 创建表
Base.metadata.create_all(bind=engine)

# 依赖项
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

@app.get("/items")
def get_items(db: Session = Depends(get_db)):
    """获取所有物品"""
    return db.query(Item).all()

@app.post("/items")
def create_item(name: str, description: str, db: Session = Depends(get_db)):
    """创建物品"""
    item = Item(name=name, description=description)
    db.add(item)
    db.commit()
    db.refresh(item)
    return item

代码解析

  • create_engine():创建数据库引擎,配置连接池参数
  • pool_size=10:连接池大小
  • max_overflow=20:最大溢出连接数
  • pool_pre_ping=True:连接前检查
  • pool_recycle=3600:连接回收时间
  • SessionLocal = sessionmaker():创建会话工厂
  • def get_db():依赖项,提供数据库会话

使用方法

  1. 安装依赖:pip install sqlalchemy psycopg2-binary
  2. 配置数据库连接字符串
  3. 运行应用,访问/items获取物品列表
  4. 发送POST请求到/items创建新物品

11.7 健康检查

功能说明:实现健康检查端点,用于监控应用状态,支持Kubernetes等容器编排系统的探针。

代码示例

python
from fastapi import FastAPI, Depends
from sqlalchemy.orm import Session
import redis
import time

app = FastAPI()

# 数据库连接
def get_db():
    try:
        # 模拟数据库连接
        time.sleep(0.1)
        return "Database connected"
    except Exception as e:
        raise Exception("Database connection failed")

# Redis连接
def get_redis():
    try:
        redis_client = redis.Redis(host="localhost", port=6379, db=0)
        redis_client.ping()
        return "Redis connected"
    except Exception as e:
        raise Exception("Redis connection failed")

@app.get("/health")
def health_check(
    db_status: str = Depends(get_db),
    redis_status: str = Depends(get_redis)
):
    """健康检查端点"""
    return {
        "status": "healthy",
        "components": {
            "database": db_status,
            "redis": redis_status
        },
        "timestamp": time.time()
    }

@app.get("/health/liveness")
def liveness_probe():
    """存活探针"""
    return {"status": "alive"}

@app.get("/health/readiness")
def readiness_probe(
    db_status: str = Depends(get_db),
    redis_status: str = Depends(get_redis)
):
    """就绪探针"""
    return {"status": "ready"}

代码解析

  • /health:完整健康检查端点,检查所有依赖服务
  • /health/liveness:存活探针,只检查应用是否运行
  • /health/readiness:就绪探针,检查应用是否准备好处理请求
  • get_db():检查数据库连接
  • get_redis():检查Redis连接

使用方法

  1. 访问http://localhost:8000/health查看完整健康状态
  2. 访问http://localhost:8000/health/liveness查看存活状态
  3. 访问http://localhost:8000/health/readiness查看就绪状态
  4. 在Kubernetes配置中使用这些端点作为探针

11.8 请求ID追踪

功能说明:为每个请求生成唯一的请求ID,用于全链路追踪和日志关联。

代码示例

python
import uuid
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
import structlog

# 配置结构化日志
structlog.configure(
    processors=[
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.JSONRenderer()
    ]
)

logger = structlog.get_logger()

app = FastAPI()

class RequestIDMiddleware(BaseHTTPMiddleware):
    """请求ID中间件"""
    async def dispatch(self, request: Request, call_next):
        # 生成或获取请求ID
        request_id = request.headers.get("X-Request-ID", str(uuid.uuid4()))

        # 添加到请求状态
        request.state.request_id = request_id

        # 添加上下文到日志
        logger = structlog.get_logger(request_id=request_id)

        # 记录请求开始
        logger.info("Request started",
                   method=request.method,
                   url=request.url.path)

        # 处理请求
        response = await call_next(request)

        # 添加请求ID到响应头
        response.headers["X-Request-ID"] = request_id

        # 记录请求结束
        logger.info("Request completed",
                   status_code=response.status_code)

        return response

# 添加中间件
app.add_middleware(RequestIDMiddleware)

@app.get("/")
def read_root(request: Request):
    """根路径"""
    # 从请求状态获取请求ID
    request_id = request.state.request_id
    logger.info("Processing request", request_id=request_id)
    return {"message": "Hello World", "request_id": request_id}

代码解析

  • class RequestIDMiddleware(BaseHTTPMiddleware):请求ID中间件
  • request_id = request.headers.get("X-Request-ID", str(uuid.uuid4())):生成或获取请求ID
  • request.state.request_id = request_id:添加到请求状态
  • logger = structlog.get_logger(request_id=request_id):添加上下文到日志
  • response.headers["X-Request-ID"] = request_id:添加请求ID到响应头
  • request.state.request_id:从请求状态获取请求ID

使用方法

  1. 访问http://localhost:8000/,响应会包含X-Request-ID头部
  2. 查看控制台输出的日志,会包含请求ID
  3. 可以在其他中间件或路由处理函数中使用request.state.request_id获取请求ID

11.9 数据序列化

功能说明:实现ORM模型与Pydantic模型之间的转换,支持复杂数据结构的序列化和反序列化。

代码示例

python
from fastapi import FastAPI
from sqlalchemy import create_engine, Column, Integer, String, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session, relationship
from pydantic import BaseModel
from typing import List, Optional

app = FastAPI()

# 数据库配置
DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# 定义ORM模型
class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, index=True)
    email = Column(String, unique=True, index=True)
    items = relationship("Item", back_populates="owner")

class Item(Base):
    __tablename__ = "items"
    id = Column(Integer, primary_key=True, index=True)
    title = Column(String, index=True)
    description = Column(String)
    owner_id = Column(Integer, ForeignKey("users.id"))
    owner = relationship("User", back_populates="items")

# 创建表
Base.metadata.create_all(bind=engine)

# 定义Pydantic模型
class ItemBase(BaseModel):
    title: str
    description: Optional[str] = None

class ItemCreate(ItemBase):
    pass

class Item(ItemBase):
    id: int
    owner_id: int

    class Config:
        from_attributes = True  # Pydantic V2 语法,替代旧版的 orm_mode = True

class UserBase(BaseModel):
    name: str
    email: str

class UserCreate(UserBase):
    pass

class User(UserBase):
    id: int
    items: List[Item] = []

    class Config:
        from_attributes = True

# 依赖项
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

# 数据访问函数
def create_user(db: Session, user: UserCreate):
    db_user = User(name=user.name, email=user.email)
    db.add(db_user)
    db.commit()
    db.refresh(db_user)
    return db_user

def create_item(db: Session, item: ItemCreate, user_id: int):
    db_item = Item(**item.model_dump(), owner_id=user_id)
    db.add(db_item)
    db.commit()
    db.refresh(db_item)
    return db_item

# 路由
@app.post("/users", response_model=User)
def create_user_endpoint(user: UserCreate, db: Session = Depends(get_db)):
    return create_user(db, user)

@app.post("/users/{user_id}/items", response_model=Item)
def create_item_endpoint(
    user_id: int,
    item: ItemCreate,
    db: Session = Depends(get_db)
):
    return create_item(db, item, user_id)

@app.get("/users", response_model=List[User])
def get_users(db: Session = Depends(get_db)):
    return db.query(User).all()

代码解析

  • class User(Base):定义用户ORM模型
  • class Item(Base):定义物品ORM模型
  • class ItemBase(BaseModel):定义物品基础Pydantic模型
  • class Item(ItemBase):定义物品响应Pydantic模型
  • class Config: from_attributes = True:启用从ORM模型创建Pydantic模型
  • db_item = Item(**item.model_dump(), owner_id=user_id):将Pydantic模型转换为字典,创建ORM模型

使用方法

  1. 运行应用
  2. 发送POST请求到/users创建用户
  3. 发送POST请求到/users/{user_id}/items为用户创建物品
  4. 发送GET请求到/users获取所有用户及其物品

11.10 批量操作

功能说明:实现批量创建、更新、删除操作,提高API性能。

代码示例

python
from fastapi import FastAPI, HTTPException
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from pydantic import BaseModel
from typing import List, Optional

app = FastAPI()

# 数据库配置
DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# 定义ORM模型
from sqlalchemy import Column, Integer, String

class Item(Base):
    __tablename__ = "items"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, index=True)
    description = Column(String)

# 创建表
Base.metadata.create_all(bind=engine)

# 定义Pydantic模型
class ItemBase(BaseModel):
    name: str
    description: Optional[str] = None

class ItemCreate(ItemBase):
    pass

class ItemUpdate(ItemBase):
    name: Optional[str] = None
    description: Optional[str] = None

class Item(ItemBase):
    id: int

    class Config:
        from_attributes = True

# 依赖项
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

# 批量创建
@app.post("/items/batch", response_model=List[Item])
def batch_create_items(items: List[ItemCreate], db: Session = Depends(get_db)):
    """批量创建物品"""
    db_items = [Item(**item.model_dump()) for item in items]
    db.add_all(db_items)
    db.commit()
    for item in db_items:
        db.refresh(item)
    return db_items

# 批量更新
@app.put("/items/batch")
def batch_update_items(
    updates: List[dict],  # 格式: [{"id": 1, "name": "New Name"}, ...]
    db: Session = Depends(get_db)
):
    """批量更新物品"""
    updated_count = 0
    for update in updates:
        item_id = update.pop("id", None)
        if not item_id:
            continue

        item = db.query(Item).filter(Item.id == item_id).first()
        if not item:
            continue

        for key, value in update.items():
            setattr(item, key, value)

        updated_count += 1

    db.commit()
    return {"updated_count": updated_count}

# 批量删除
@app.delete("/items/batch")
def batch_delete_items(item_ids: List[int], db: Session = Depends(get_db)):
    """批量删除物品"""
    deleted_count = db.query(Item).filter(Item.id.in_(item_ids)).delete(synchronize_session=False)
    db.commit()
    return {"deleted_count": deleted_count}

代码解析

  • @app.post("/items/batch"):批量创建物品端点
  • db_items = [Item(**item.model_dump()) for item in items]:创建多个ORM模型实例
  • db.add_all(db_items):批量添加到数据库
  • @app.put("/items/batch"):批量更新物品端点
  • db.query(Item).filter(Item.id.in_(item_ids)).delete():批量删除物品

使用方法

  1. 批量创建:发送POST请求到/items/batch,请求体为物品列表
  2. 批量更新:发送PUT请求到/items/batch,请求体为更新列表
  3. 批量删除:发送DELETE请求到/items/batch,请求体为物品ID列表

11.11 软删除

功能说明:实现软删除功能,标记数据为已删除而不是物理删除,支持数据恢复。

代码示例

python
from fastapi import FastAPI, Depends, Query
from sqlalchemy import create_engine, Column, Integer, String, Boolean, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.sql import func
from pydantic import BaseModel
from typing import List, Optional

app = FastAPI()

# 数据库配置
DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# 定义ORM模型
class Item(Base):
    __tablename__ = "items"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, index=True)
    description = Column(String)
    is_deleted = Column(Boolean, default=False, index=True)
    deleted_at = Column(DateTime, nullable=True)

# 创建表
Base.metadata.create_all(bind=engine)

# 定义Pydantic模型
class ItemBase(BaseModel):
    name: str
    description: Optional[str] = None

class ItemCreate(ItemBase):
    pass

class Item(ItemBase):
    id: int
    is_deleted: bool

    class Config:
        from_attributes = True

# 依赖项
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

# 软删除查询
def get_active_items(db: Session):
    return db.query(Item).filter(Item.is_deleted == False)

def get_deleted_items(db: Session):
    return db.query(Item).filter(Item.is_deleted == True)

# 路由
@app.post("/items", response_model=Item)
def create_item(item: ItemCreate, db: Session = Depends(get_db)):
    """创建物品"""
    db_item = Item(**item.model_dump())
    db.add(db_item)
    db.commit()
    db.refresh(db_item)
    return db_item

@app.get("/items", response_model=List[Item])
def get_items(
    show_deleted: bool = Query(False, description="是否显示已删除的物品"),
    db: Session = Depends(get_db)
):
    """获取物品列表"""
    if show_deleted:
        return get_deleted_items(db).all()
    return get_active_items(db).all()

@app.delete("/items/{item_id}")
def delete_item(item_id: int, db: Session = Depends(get_db)):
    """软删除物品"""
    item = db.query(Item).filter(Item.id == item_id).first()
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")

    item.is_deleted = True
    item.deleted_at = func.now()
    db.commit()

    return {"message": "Item soft deleted"}

@app.post("/items/{item_id}/restore")
def restore_item(item_id: int, db: Session = Depends(get_db)):
    """恢复已删除的物品"""
    item = db.query(Item).filter(Item.id == item_id, Item.is_deleted == True).first()
    if not item:
        raise HTTPException(status_code=404, detail="Deleted item not found")

    item.is_deleted = False
    item.deleted_at = None
    db.commit()

    return {"message": "Item restored"}

代码解析

  • is_deleted = Column(Boolean, default=False, index=True):软删除标记
  • deleted_at = Column(DateTime, nullable=True):删除时间
  • get_active_items(db):获取未删除的物品
  • get_deleted_items(db):获取已删除的物品
  • item.is_deleted = True:标记为已删除
  • item.is_deleted = False:恢复物品

使用方法

  1. 创建物品:POST /items
  2. 获取活跃物品:GET /items
  3. 获取已删除物品:GET /items?show_deleted=true
  4. 软删除物品:DELETE /items/{item_id}
  5. 恢复物品:POST /items/{item_id}/restore

11.12 审计日志

功能说明:实现操作审计日志,记录用户对数据的操作,包括创建、更新、删除等。

代码示例

python
from fastapi import FastAPI, Depends, Request
from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Text, Enum
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.sql import func
from pydantic import BaseModel
from typing import List, Optional
import enum
import json

app = FastAPI()

# 数据库配置
DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# 操作类型枚举
class OperationType(enum.Enum):
    CREATE = "create"
    UPDATE = "update"
    DELETE = "delete"

# 定义ORM模型
class Item(Base):
    __tablename__ = "items"
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, index=True)
    description = Column(String)

class AuditLog(Base):
    __tablename__ = "audit_logs"
    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, index=True)
    operation = Column(Enum(OperationType))
    table_name = Column(String)
    record_id = Column(Integer)
    old_data = Column(Text, nullable=True)
    new_data = Column(Text, nullable=True)
    ip_address = Column(String, nullable=True)
    user_agent = Column(String, nullable=True)
    created_at = Column(DateTime, default=func.now())

# 创建表
Base.metadata.create_all(bind=engine)

# 定义Pydantic模型
class ItemBase(BaseModel):
    name: str
    description: Optional[str] = None

class ItemCreate(ItemBase):
    pass

class Item(ItemBase):
    id: int

    class Config:
        from_attributes = True

# 依赖项
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

# 审计日志装饰器
def audit_log(user_id: int, operation: OperationType, table_name: str):
    """审计日志装饰器"""
    def decorator(func):
        async def wrapper(*args, **kwargs):
            db = kwargs.get("db")
            request = kwargs.get("request")

            # 提取参数
            item_id = kwargs.get("item_id")
            item = kwargs.get("item")

            old_data = None
            if item_id:
                # 获取旧数据
                old_item = db.query(Item).filter(Item.id == item_id).first()
                if old_item:
                    old_data = json.dumps({
                        "name": old_item.name,
                        "description": old_item.description
                    })

            # 执行函数
            result = await func(*args, **kwargs)

            # 记录审计日志
            new_data = None
            record_id = None

            if operation == OperationType.CREATE and result:
                record_id = result.id
                new_data = json.dumps({
                    "name": result.name,
                    "description": result.description
                })
            elif operation == OperationType.UPDATE and item_id and item:
                record_id = item_id
                new_data = json.dumps(item.model_dump())
            elif operation == OperationType.DELETE and item_id:
                record_id = item_id

            # 获取请求信息
            ip_address = None
            user_agent = None
            if request:
                ip_address = request.client.host
                user_agent = request.headers.get("User-Agent")

            # 创建审计日志
            audit_log_entry = AuditLog(
                user_id=user_id,
                operation=operation,
                table_name=table_name,
                record_id=record_id,
                old_data=old_data,
                new_data=new_data,
                ip_address=ip_address,
                user_agent=user_agent
            )
            db.add(audit_log_entry)
            db.commit()

            return result
        return wrapper
    return decorator

# 路由
@app.post("/items", response_model=Item)
async def create_item(
    item: ItemCreate,
    request: Request,
    db: Session = Depends(get_db)
):
    """创建物品"""
    db_item = Item(**item.model_dump())
    db.add(db_item)
    db.commit()
    db.refresh(db_item)

    # 记录审计日志
    ip_address = request.client.host
    user_agent = request.headers.get("User-Agent")

    audit_log_entry = AuditLog(
        user_id=1,  # 假设用户ID为1
        operation=OperationType.CREATE,
        table_name="items",
        record_id=db_item.id,
        new_data=json.dumps(item.model_dump()),
        ip_address=ip_address,
        user_agent=user_agent
    )
    db.add(audit_log_entry)
    db.commit()

    return db_item

@app.put("/items/{item_id}", response_model=Item)
async def update_item(
    item_id: int,
    item: ItemCreate,
    request: Request,
    db: Session = Depends(get_db)
):
    """更新物品"""
    db_item = db.query(Item).filter(Item.id == item_id).first()
    if not db_item:
        raise HTTPException(status_code=404, detail="Item not found")

    # 记录旧数据
    old_data = json.dumps({
        "name": db_item.name,
        "description": db_item.description
    })

    # 更新数据
    db_item.name = item.name
    db_item.description = item.description
    db.commit()
    db.refresh(db_item)

    # 记录审计日志
    ip_address = request.client.host
    user_agent = request.headers.get("User-Agent")

    audit_log_entry = AuditLog(
        user_id=1,  # 假设用户ID为1
        operation=OperationType.UPDATE,
        table_name="items",
        record_id=item_id,
        old_data=old_data,
        new_data=json.dumps(item.model_dump()),
        ip_address=ip_address,
        user_agent=user_agent
    )
    db.add(audit_log_entry)
    db.commit()

    return db_item

@app.delete("/items/{item_id}")
async def delete_item(
    item_id: int,
    request: Request,
    db: Session = Depends(get_db)
):
    """删除物品"""
    db_item = db.query(Item).filter(Item.id == item_id).first()
    if not db_item:
        raise HTTPException(status_code=404, detail="Item not found")

    # 记录旧数据
    old_data = json.dumps({
        "name": db_item.name,
        "description": db_item.description
    })

    # 删除数据
    db.delete(db_item)
    db.commit()

    # 记录审计日志
    ip_address = request.client.host
    user_agent = request.headers.get("User-Agent")

    audit_log_entry = AuditLog(
        user_id=1,  # 假设用户ID为1
        operation=OperationType.DELETE,
        table_name="items",
        record_id=item_id,
        old_data=old_data,
        ip_address=ip_address,
        user_agent=user_agent
    )
    db.add(audit_log_entry)
    db.commit()

    return {"message": "Item deleted"}

@app.get("/audit-logs")
def get_audit_logs(db: Session = Depends(get_db)):
    """获取审计日志"""
    logs = db.query(AuditLog).order_by(AuditLog.created_at.desc()).all()
    return [{
        "id": log.id,
        "user_id": log.user_id,
        "operation": log.operation.value,
        "table_name": log.table_name,
        "record_id": log.record_id,
        "old_data": json.loads(log.old_data) if log.old_data else None,
        "new_data": json.loads(log.new_data) if log.new_data else None,
        "ip_address": log.ip_address,
        "user_agent": log.user_agent,
        "created_at": log.created_at
    } for log in logs]

代码解析

  • class OperationType(enum.Enum):操作类型枚举
  • class AuditLog(Base):审计日志模型
  • audit_log 装饰器:用于记录审计日志
  • json.dumps():将数据转换为JSON字符串
  • json.loads():将JSON字符串转换为Python对象

使用方法

  1. 创建、更新、删除物品时会自动记录审计日志
  2. 访问/audit-logs查看审计日志
  3. 审计日志包含操作类型、表名、记录ID、旧数据、新数据、IP地址、用户代理等信息

11.13 异常监控

功能说明:集成Sentry等异常监控服务,实时监控和追踪应用异常。

代码示例

python
from fastapi import FastAPI, HTTPException
import sentry_sdk
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware

# 配置Sentry
sentry_sdk.init(
    dsn="your-sentry-dsn",  # 替换为你的Sentry DSN
    traces_sample_rate=1.0,  # 采样率
    profiles_sample_rate=1.0,  # 性能分析采样率
)

app = FastAPI()

# 添加Sentry中间件
app.add_middleware(SentryAsgiMiddleware)

@app.get("/")
def read_root():
    return {"message": "Hello World"}

@app.get("/error")
def trigger_error():
    """触发错误,用于测试Sentry"""
    raise HTTPException(status_code=500, detail="Internal Server Error")

@app.get("/divide-by-zero")
def divide_by_zero():
    """触发除零错误"""
    1 / 0

@app.get("/item/{item_id}")
def get_item(item_id: int):
    """获取物品,不存在时触发404错误"""
    if item_id == 42:
        raise HTTPException(status_code=404, detail="Item not found")
    return {"item_id": item_id}

代码解析

  • import sentry_sdk:导入Sentry SDK
  • sentry_sdk.init():初始化Sentry
  • app.add_middleware(SentryAsgiMiddleware):添加Sentry中间件
  • raise HTTPException():触发HTTP异常
  • 1 / 0:触发除零错误

使用方法

  1. 安装依赖:pip install sentry-sdk
  2. 在Sentry.io创建项目,获取DSN
  3. 替换代码中的your-sentry-dsn为实际的DSN
  4. 运行应用
  5. 访问/error/divide-by-zero等端点触发错误
  6. 在Sentry控制台查看错误详情

12. 最佳实践

功能说明:FastAPI开发的最佳实践,帮助你编写高质量的API。

详细说明

  1. 使用Pydantic模型

    • 定义请求体和响应模型,确保数据验证和类型安全
    • 使用Field类添加字段描述和验证规则
    • 利用嵌套模型处理复杂数据结构
  2. 使用依赖注入

    • 管理数据库连接、认证等共享资源
    • 使用依赖链处理复杂的依赖关系
    • 利用可调用类依赖提供更灵活的依赖管理
  3. 使用路径操作装饰器

    • 清晰定义API端点
    • 使用HTTP方法(GET、POST、PUT、DELETE等)表达操作意图
    • 合理组织路由结构,提高代码可读性
  4. 使用类型提示

    • 提高代码可读性和IDE支持
    • 利用Python类型系统发现潜在错误
    • 使代码更加自文档化
  5. 使用中间件

    • 处理跨域请求、日志记录等横切关注点
    • 实现请求预处理和响应后处理
    • 保持路由处理函数的专注性
  6. 使用异常处理

    • 统一处理错误,返回清晰的错误信息
    • 自定义异常类,提供更具体的错误类型
    • 使用异常处理中间件,统一处理全局异常
  7. 使用环境变量

    • 管理配置,避免硬编码敏感信息
    • 使用pydantic-settings库管理配置
    • 不同环境使用不同的配置文件
  8. 使用文档字符串

    • 为API端点添加文档
    • 使用Markdown格式编写详细的文档
    • 提供请求和响应的示例
  9. 使用测试

    • 编写单元测试和集成测试
    • 使用TestClient测试API端点
    • 确保代码的可靠性和稳定性
  10. 使用版本控制

    • 为API添加版本前缀,如/v1/items
    • 使用路由器组织不同版本的API
    • 平滑过渡到新的API版本

13. 常见问题

13.1 CORS错误

问题:前端无法访问API,出现CORS错误。

解决方案:配置CORS中间件:

python
from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

详细说明

  • allow_origins:允许的来源,生产环境应该设置具体的域名
  • allow_credentials:是否允许携带凭证
  • allow_methods:允许的HTTP方法
  • allow_headers:允许的HTTP头部

12.2 数据库连接问题

问题:数据库连接失败。

解决方案

  1. 检查数据库连接字符串是否正确
  2. 确保数据库服务正在运行
  3. 检查数据库用户权限
  4. 确认依赖项已正确安装

示例

python
# 检查数据库连接字符串
SQLALCHEMY_DATABASE_URL = "postgresql://user:password@localhost/dbname"

# 测试数据库连接
try:
    engine = create_engine(SQLALCHEMY_DATABASE_URL)
    conn = engine.connect()
    print("Database connection successful")
    conn.close()
except Exception as e:
    print(f"Database connection failed: {e}")

12.3 部署问题

问题:部署后API无法访问。

解决方案

  1. 检查端口映射是否正确
  2. 检查防火墙设置,确保端口已开放
  3. 查看应用日志,了解具体错误
  4. 确认环境变量和配置是否正确

示例

bash
# 查看应用日志
docker logs fastapi-app

# 检查端口映射
docker ps

# 测试API是否可访问
curl http://localhost:8000/

14. 高级特性

14.1 响应模型和字段

功能说明:使用Pydantic模型和Field类定义响应模型,添加字段描述和验证规则。

代码示例

python
from fastapi import FastAPI
from pydantic import BaseModel, Field

app = FastAPI()

# 响应模型
class Item(BaseModel):
    name: str
    description: str = Field(None, title="The description of the item", max_length=300)
    price: float = Field(..., gt=0, description="The price must be greater than zero")
    tax: float = None

@app.post("/items", response_model=Item)
def create_item(item: Item):
    return item

代码解析

  • Field(None, title="The description of the item", max_length=300):可选字段,添加标题和最大长度限制
  • Field(..., gt=0, description="The price must be greater than zero"):必填字段,添加大于0的验证和描述
  • response_model=Item:指定响应模型,FastAPI会自动过滤响应中不在模型中的字段

使用方法

  1. 发送POST请求到http://localhost:8000/items,请求体为:
    json
    {
      "name": "Item 1",
      "description": "This is item 1",
      "price": 10.0
    }
  2. 响应为:
    json
    {
      "name": "Item 1",
      "description": "This is item 1",
      "price": 10.0,
      "tax": null
    }

14.2 错误处理

功能说明:自定义异常和异常处理中间件,统一处理错误。

代码示例

python
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse

app = FastAPI()

# 自定义异常
class ItemNotFoundException(HTTPException):
    def __init__(self, item_id):
        super().__init__(
            status_code=404,
            detail=f"Item {item_id} not found",
            headers={"X-Error": "Item not found"},
        )

# 异常处理中间件
@app.exception_handler(ItemNotFoundException)
async def item_not_found_exception_handler(request: Request, exc: ItemNotFoundException):
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        headers=exc.headers,
    )

@app.get("/items/{item_id}")
def get_item(item_id: int):
    if item_id == 42:
        raise ItemNotFoundException(item_id)
    return {"item_id": item_id}

代码解析

  • class ItemNotFoundException(HTTPException):自定义异常类,继承自HTTPException
  • @app.exception_handler(ItemNotFoundException):异常处理中间件,处理ItemNotFoundException异常
  • raise ItemNotFoundException(item_id):抛出自定义异常

使用方法

  1. 访问http://localhost:8000/items/42,会返回404错误,响应为:
    json
    {
      "detail": "Item 42 not found"
    }
  2. 响应头部会包含X-Error: Item not found

14.3 后台任务

功能说明:使用FastAPI的BackgroundTasks功能在后台执行耗时操作,如发送邮件、处理文件等,不阻塞主请求处理。

代码示例

python
from fastapi import FastAPI, BackgroundTasks
import time

app = FastAPI()

def send_email(email: str, message: str):
    """发送邮件的后台任务"""
    print(f"Sending email to {email}")
    time.sleep(2)  # 模拟发送邮件的延迟
    print(f"Email sent to {email} with message: {message}")

@app.post("/send-email")
def send_email_endpoint(
    email: str,
    message: str,
    background_tasks: BackgroundTasks
):
    # 添加后台任务
    background_tasks.add_task(send_email, email, message)
    return {"message": "Email will be sent in the background"}

代码解析

  • from fastapi import BackgroundTasks:导入BackgroundTasks类
  • def send_email(email: str, message: str):定义后台任务函数,模拟发送邮件的耗时操作
  • background_tasks: BackgroundTasks:将BackgroundTasks作为依赖注入到路由函数中
  • background_tasks.add_task(send_email, email, message):添加后台任务,传入任务函数和参数
  • return {"message": "Email will be sent in the background"}:立即返回响应,后台任务在后台执行

使用方法

  1. 发送POST请求到http://localhost:8000/send-email,请求体为:
    json
    {
      "email": "user@example.com",
      "message": "Hello from FastAPI"
    }
  2. 响应为:
    json
    {
      "message": "Email will be sent in the background"
    }
  3. 查看控制台输出,可以看到邮件发送的日志
  4. 后台任务会在请求处理完成后继续执行,不会阻塞响应

14.4 WebSockets

功能说明:使用FastAPI的WebSocket支持实现实时双向通信,如聊天应用、实时数据更新等。

代码示例

python
from fastapi import FastAPI, WebSocket
from fastapi.websockets import WebSocketDisconnect

app = FastAPI()

class ConnectionManager:
    def __init__(self):
        self.active_connections = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)

    async def broadcast(self, message: str):
        for connection in self.active_connections:
            await connection.send_text(message)

manager = ConnectionManager()

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"You sent: {data}")
            await manager.broadcast(f"Client says: {data}")
    except WebSocketDisconnect:
        manager.disconnect(websocket)
        await manager.broadcast("A client disconnected")

代码解析

  • from fastapi import WebSocket:导入WebSocket类
  • from fastapi.websockets import WebSocketDisconnect:导入WebSocketDisconnect异常
  • class ConnectionManager:创建连接管理器类,用于管理WebSocket连接
  • async def connect(self, websocket: WebSocket):接受WebSocket连接并添加到活跃连接列表
  • def disconnect(self, websocket: WebSocket):从活跃连接列表中移除断开的连接
  • async def broadcast(self, message: str):向所有活跃连接广播消息
  • @app.websocket("/ws"):定义WebSocket端点
  • async def websocket_endpoint(websocket: WebSocket):WebSocket处理函数
  • await websocket.receive_text():接收客户端发送的文本消息
  • await websocket.send_text(f"You sent: {data}"):向客户端发送文本消息
  • await manager.broadcast(f"Client says: {data}"):向所有客户端广播消息

使用方法

  1. 启动应用
  2. 使用WebSocket客户端连接到ws://localhost:8000/ws
  3. 发送消息到服务器,服务器会回显消息并广播给所有连接的客户端
  4. 当客户端断开连接时,服务器会广播断开连接的消息
  5. 可以使用浏览器的开发者工具或WebSocket测试工具(如Postman)测试WebSocket连接

14.5 事件处理

功能说明:使用FastAPI的事件处理功能在应用启动和关闭时执行特定操作,如初始化数据库连接、关闭资源等。

代码示例

python
from fastapi import FastAPI

app = FastAPI()

# 启动事件
@app.on_event("startup")
async def startup_event():
    print("Application startup")
    # 初始化数据库连接等

# 关闭事件
@app.on_event("shutdown")
async def shutdown_event():
    print("Application shutdown")
    # 关闭数据库连接等

@app.get("/")
def read_root():
    return {"message": "Hello World"}

代码解析

  • @app.on_event("startup"):注册应用启动时执行的事件处理函数
  • async def startup_event():启动事件处理函数,用于初始化资源
  • @app.on_event("shutdown"):注册应用关闭时执行的事件处理函数
  • async def shutdown_event():关闭事件处理函数,用于清理资源

使用方法

  1. 启动应用时,会执行startup_event()函数,打印"Application startup"并执行初始化操作
  2. 关闭应用时,会执行shutdown_event()函数,打印"Application shutdown"并执行清理操作
  3. 可以在启动事件中初始化数据库连接、加载配置等
  4. 可以在关闭事件中关闭数据库连接、释放资源等

14.6 测试

功能说明:使用FastAPI的TestClient进行API测试,确保API端点正常工作。

代码示例

python
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()

@app.get("/")
def read_root():
    return {"message": "Hello World"}

@app.get("/items/{item_id}")
def read_item(item_id: int, q: str = None):
    return {"item_id": item_id, "q": q}

# 测试客户端
client = TestClient(app)

# 测试根路径
def test_read_root():
    response = client.get("/")
    assert response.status_code == 200
    assert response.json() == {"message": "Hello World"}

# 测试带路径参数的路径
def test_read_item():
    response = client.get("/items/42")
    assert response.status_code == 200
    assert response.json() == {"item_id": 42, "q": None}

# 测试带查询参数的路径
def test_read_item_with_query():
    response = client.get("/items/42?q=test")
    assert response.status_code == 200
    assert response.json() == {"item_id": 42, "q": "test"}

代码解析

  • from fastapi.testclient import TestClient:导入TestClient类
  • client = TestClient(app):创建测试客户端,传入FastAPI应用实例
  • def test_read_root():测试根路径的函数
  • response = client.get("/"):发送GET请求到根路径
  • assert response.status_code == 200:断言响应状态码为200
  • assert response.json() == {"message": "Hello World"}:断言响应内容正确
  • def test_read_item():测试带路径参数的路径
  • def test_read_item_with_query():测试带查询参数的路径

使用方法

  1. 安装测试依赖:pip install pytest
  2. 将测试代码保存为test_main.py文件
  3. 运行测试:pytest test_main.py -v
  4. 查看测试结果,确保所有测试通过
  5. 可以添加更多测试用例,测试不同的API端点和场景

14.7 性能优化

功能说明:使用FastAPI的性能优化特性,如Gzip压缩和缓存,提高API响应速度。

代码示例

python
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware
import uvicorn

app = FastAPI()

# 启用Gzip压缩
app.add_middleware(GZipMiddleware, minimum_size=1000)

# 缓存响应
from fastapi_cache import FastAPICache
from fastapi_cache.backends.redis import RedisBackend
from fastapi_cache.decorator import cache
import redis

# 初始化缓存
redis_client = redis.Redis(host="localhost", port=6379, db=0)
FastAPICache.init(RedisBackend(redis_client), prefix="fastapi-cache")

@app.get("/items/{item_id}")
@cache(expire=60)  # 缓存60秒
async def read_item(item_id: int):
    # 模拟耗时操作
    import time
    time.sleep(0.5)
    return {"item_id": item_id}

if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, workers=4)

代码解析

  • from fastapi.middleware.gzip import GZipMiddleware:导入GZipMiddleware中间件
  • app.add_middleware(GZipMiddleware, minimum_size=1000):启用Gzip压缩,最小压缩大小为1000字节
  • from fastapi_cache import FastAPICache:导入FastAPICache
  • from fastapi_cache.backends.redis import RedisBackend:导入RedisBackend
  • from fastapi_cache.decorator import cache:导入cache装饰器
  • redis_client = redis.Redis(host="localhost", port=6379, db=0):创建Redis客户端
  • FastAPICache.init(RedisBackend(redis_client), prefix="fastapi-cache"):初始化缓存
  • @cache(expire=60):缓存装饰器,设置缓存过期时间为60秒
  • uvicorn.run("main:app", host="0.0.0.0", port=8000, workers=4):启动4个工作进程

使用方法

  1. 安装依赖:pip install fastapi-cache redis
  2. 确保Redis服务正在运行
  3. 启动应用
  4. 访问http://localhost:8000/items/42,第一次访问会有0.5秒的延迟
  5. 再次访问http://localhost:8000/items/42,会从缓存中获取响应,响应速度更快
  6. 查看响应头部,会包含Content-Encoding: gzip字段,表示启用了Gzip压缩

14.8 高级依赖注入

功能说明:使用FastAPI的高级依赖注入特性,如依赖链和可调用类依赖,实现更复杂的依赖管理。

代码示例

python
from fastapi import FastAPI, Depends
from typing import Optional

app = FastAPI()

# 依赖链
def get_db():
    db = "Database connection"
    try:
        yield db
    finally:
        print("Closing database connection")

def get_user(db: str = Depends(get_db)):
    user = "Current user"
    return {"db": db, "user": user}

@app.get("/items")
def get_items(user: dict = Depends(get_user)):
    return {"user": user, "items": ["item1", "item2"]}

# 可调用类依赖
class Pagination:
    def __init__(self, skip: int = 0, limit: int = 100):
        self.skip = skip
        self.limit = limit

@app.get("/items")
def get_items(pagination: Pagination = Depends(Pagination)):
    return {"skip": pagination.skip, "limit": pagination.limit}

代码解析

  • def get_db():定义数据库依赖函数,返回数据库连接
  • def get_user(db: str = Depends(get_db)):定义用户依赖函数,依赖于get_db
  • user: dict = Depends(get_user):使用依赖链,先获取数据库连接,再获取用户信息
  • class Pagination:定义可调用类依赖,用于处理分页参数
  • def __init__(self, skip: int = 0, limit: int = 100):初始化方法,设置默认值
  • pagination: Pagination = Depends(Pagination):使用可调用类依赖,自动解析参数

使用方法

  1. 访问http://localhost:8000/items,会调用依赖链:get_db() → get_user() → get_items()
  2. 响应为:{"user": {"db": "Database connection", "user": "Current user"}, "items": ["item1", "item2"]}
  3. 访问完成后,会执行get_db()中的finally块,打印"Closing database connection"
  4. 访问http://localhost:8000/items?skip=10&limit=50,会使用Pagination依赖,返回:{"skip": 10, "limit": 50}

14.9 安全特性

功能说明:使用FastAPI的安全特性,如HTTP Basic认证,实现API的访问控制。

代码示例

python
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
import secrets

app = FastAPI()
security = HTTPBasic()

# 模拟用户数据库
users = {
    "alice": "secret",
    "bob": "password",
}

def authenticate_user(credentials: HTTPBasicCredentials = Depends(security)):
    password = users.get(credentials.username)
    if not password or not secrets.compare_digest(password, credentials.password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return credentials.username

@app.get("/users/me")
def read_current_user(username: str = Depends(authenticate_user)):
    return {"username": username}

代码解析

  • from fastapi.security import HTTPBasic, HTTPBasicCredentials:导入HTTPBasic和HTTPBasicCredentials
  • security = HTTPBasic():创建HTTPBasic安全实例
  • users:模拟用户数据库
  • def authenticate_user(credentials: HTTPBasicCredentials = Depends(security)):认证用户的依赖函数
  • secrets.compare_digest(password, credentials.password):安全地比较密码
  • raise HTTPException(...):认证失败时抛出401错误
  • @app.get("/users/me"):需要认证的端点
  • username: str = Depends(authenticate_user):使用认证依赖

使用方法

  1. 访问http://localhost:8000/users/me,会弹出认证对话框
  2. 输入用户名alice和密码secret,会返回:{"username": "alice"}
  3. 输入错误的用户名或密码,会返回401错误
  4. 可以在Swagger UI中测试,点击"Authorize"按钮输入凭据

14.10 API版本控制

功能说明:使用FastAPI的API版本控制特性,管理不同版本的API端点。

代码示例

python
from fastapi import FastAPI

app = FastAPI()

# V1版本
@app.get("/v1/items")
def get_items_v1():
    return {"version": "v1", "items": ["item1", "item2"]}

# V2版本
@app.get("/v2/items")
def get_items_v2():
    return {"version": "v2", "items": ["item1", "item2", "item3"]}

# 使用路由器进行版本控制
from fastapi import APIRouter

api_v1 = APIRouter(prefix="/v1")
api_v2 = APIRouter(prefix="/v2")

@api_v1.get("/items")
def get_items_v1():
    return {"version": "v1", "items": ["item1", "item2"]}

@api_v2.get("/items")
def get_items_v2():
    return {"version": "v2", "items": ["item1", "item2", "item3"]}

app.include_router(api_v1)
app.include_router(api_v2)

代码解析

  • @app.get("/v1/items"):直接在路径中添加版本前缀
  • @app.get("/v2/items"):定义V2版本的端点
  • from fastapi import APIRouter:导入APIRouter类
  • api_v1 = APIRouter(prefix="/v1"):创建V1版本的路由器,设置前缀
  • api_v2 = APIRouter(prefix="/v2"):创建V2版本的路由器,设置前缀
  • @api_v1.get("/items"):在V1路由器上定义端点
  • @api_v2.get("/items"):在V2路由器上定义端点
  • app.include_router(api_v1):将V1路由器包含到应用中
  • app.include_router(api_v2):将V2路由器包含到应用中

使用方法

  1. 访问http://localhost:8000/v1/items,会返回V1版本的响应:{"version": "v1", "items": ["item1", "item2"]}
  2. 访问http://localhost:8000/v2/items,会返回V2版本的响应:{"version": "v2", "items": ["item1", "item2", "item3"]}
  3. 使用路由器的方式可以更好地组织代码,将不同版本的API分离到不同的模块中

14.11 数据验证

功能说明:使用FastAPI的数据验证特性,对路径参数、查询参数和请求体进行验证。

代码示例

python
from fastapi import FastAPI, Query, Path, Body
from pydantic import BaseModel, Field, validator

app = FastAPI()

# 路径参数验证
@app.get("/items/{item_id}")
def read_item(
    item_id: int = Path(..., title="Item ID", ge=1, le=1000),
    q: str = Query(None, min_length=3, max_length=50)
):
    return {"item_id": item_id, "q": q}

# 请求体验证
class Item(BaseModel):
    name: str
    description: str = None
    price: float = Field(..., gt=0)
    tax: float = None

    @validator("name")
    def name_must_not_be_empty(cls, v):
        if not v:
            raise ValueError("Name must not be empty")
        return v

@app.post("/items")
def create_item(item: Item = Body(..., embed=True)):
    return item

代码解析

  • from fastapi import Query, Path, Body:导入查询参数、路径参数和请求体验证工具
  • from pydantic import BaseModel, Field, validator:导入Pydantic模型、字段和验证器
  • item_id: int = Path(..., title="Item ID", ge=1, le=1000):使用Path验证路径参数,设置标题和范围
  • q: str = Query(None, min_length=3, max_length=50):使用Query验证查询参数,设置长度限制
  • class Item(BaseModel):定义请求体模型
  • price: float = Field(..., gt=0):使用Field验证字段,设置大于0的限制
  • @validator("name"):使用装饰器定义字段验证器
  • def name_must_not_be_empty(cls, v):验证name字段不能为空
  • item: Item = Body(..., embed=True):使用Body验证请求体,设置embed=True使请求体嵌套

使用方法

  1. 访问http://localhost:8000/items/42,返回{"item_id": 42, "q": null}
  2. 访问http://localhost:8000/items/0,会返回422错误,因为item_id必须大于等于1
  3. 访问http://localhost:8000/items/1001,会返回422错误,因为item_id必须小于等于1000
  4. 发送POST请求到http://localhost:8000/items,请求体为:
    json
    {
      "item": {
        "name": "Item 1",
        "price": 10.0
      }
    }
  5. 如果发送的请求体中name为空或price小于等于0,会返回422错误

14.12 文件上传

功能说明:使用FastAPI的文件上传特性,支持单文件和多文件上传。

代码示例

python
from fastapi import FastAPI, UploadFile, File
from typing import List

app = FastAPI()

# 单个文件上传
@app.post("/upload-file")
async def upload_file(file: UploadFile = File(...)):
    content = await file.read()
    return {
        "filename": file.filename,
        "content_type": file.content_type,
        "size": len(content)
    }

# 多个文件上传
@app.post("/upload-files")
async def upload_files(files: List[UploadFile] = File(...)):
    results = []
    for file in files:
        content = await file.read()
        results.append({
            "filename": file.filename,
            "content_type": file.content_type,
            "size": len(content)
        })
    return results

代码解析

  • from fastapi import UploadFile, File:导入UploadFile和File类
  • from typing import List:导入List类型
  • file: UploadFile = File(...):使用File定义单文件上传参数
  • async def upload_file:定义异步函数处理文件上传
  • content = await file.read():异步读取文件内容
  • files: List[UploadFile] = File(...):使用File定义多文件上传参数,类型为List[UploadFile]
  • for file in files:遍历上传的多个文件

使用方法

  1. 使用curl命令上传单文件:
    bash
    curl -X POST http://localhost:8000/upload-file -F file=@/path/to/file.txt
  2. 使用curl命令上传多文件:
    bash
    curl -X POST http://localhost:8000/upload-files -F files=@/path/to/file1.txt -F files=@/path/to/file2.txt
  3. 在Swagger UI中测试,点击"Try it out"按钮,然后选择文件进行上传
  4. 上传成功后,会返回文件的信息,包括文件名、内容类型和大小

14.13 速率限制

功能说明:使用FastAPI的中间件特性,实现API的速率限制,防止恶意请求和DoS攻击。

代码示例

python
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
import time

app = FastAPI()

# 内存存储的速率限制
rate_limit_store = {}
RATE_LIMIT = 5  # 每分钟请求数
WINDOW_SIZE = 60  # 时间窗口(秒)

@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
    client_ip = request.client.host
    current_time = time.time()

    # 初始化客户端记录
    if client_ip not in rate_limit_store:
        rate_limit_store[client_ip] = []

    # 清理过期的请求记录
    rate_limit_store[client_ip] = [
        timestamp for timestamp in rate_limit_store[client_ip]
        if current_time - timestamp < WINDOW_SIZE
    ]

    # 检查速率限制
    if len(rate_limit_store[client_ip]) >= RATE_LIMIT:
        return JSONResponse(
            status_code=429,
            content={"detail": "Rate limit exceeded"}
        )

    # 记录当前请求
    rate_limit_store[client_ip].append(current_time)

    # 处理请求
    response = await call_next(request)
    return response

@app.get("/")
def read_root():
    return {"message": "Hello World"}

@app.get("/items/{item_id}")
def read_item(item_id: int):
    return {"item_id": item_id}

代码解析

  • rate_limit_store = {}:内存存储,用于记录客户端的请求时间戳
  • RATE_LIMIT = 5:设置每分钟的请求限制为5次
  • WINDOW_SIZE = 60:设置时间窗口为60秒
  • @app.middleware("http"):定义HTTP中间件
  • async def rate_limit_middleware(request: Request, call_next):中间件函数,处理请求
  • client_ip = request.client.host:获取客户端IP地址
  • current_time = time.time():获取当前时间戳
  • if client_ip not in rate_limit_store:初始化客户端记录
  • rate_limit_store[client_ip] = [...]:清理过期的请求记录
  • if len(rate_limit_store[client_ip]) >= RATE_LIMIT:检查是否超出速率限制
  • return JSONResponse(...):返回429错误,提示速率限制已超出
  • rate_limit_store[client_ip].append(current_time):记录本次请求的时间戳
  • response = await call_next(request):继续处理请求
  • @app.get("/"):定义根端点
  • @app.get("/items/{item_id}"):定义带路径参数的端点

使用方法

  1. 访问http://localhost:8000/,会返回{"message": "Hello World"}
  2. 访问http://localhost:8000/items/42,会返回{"item_id": 42}
  3. 在一分钟内连续访问任何端点超过5次,会返回429错误,提示"Rate limit exceeded"
  4. 等待一分钟后,又可以继续访问这些端点
  5. 可以根据需要调整RATE_LIMITWINDOW_SIZE的值,以适应不同的场景

注意:本示例使用内存存储实现速率限制,适用于开发和测试环境。在生产环境中,建议使用Redis等分布式存储来实现速率限制,以支持多实例部署。

15. 总结

FastAPI是一个强大、快速、现代化的Web框架,它结合了Python类型提示的优势,提供了自动API文档、请求验证、依赖注入等功能,使得构建API变得更加简单和高效。通过本教程,你已经了解了FastAPI的基本使用方法、核心功能和高级特性,可以开始构建自己的API应用了。

FastAPI的主要优势包括:

  • 高性能:基于Starlette和Pydantic,性能与NodeJS和Go相当
  • 自动API文档:自动生成Swagger UI和ReDoc文档
  • 类型安全:基于Python类型提示,提供编译时类型检查
  • 依赖注入:内置强大的依赖注入系统
  • 安全性:自动处理CORS、CSRF等安全特性
  • 标准化:基于OpenAPI标准
  • 易于使用:简洁的API设计,易于学习和使用
  • 扩展性:支持WebSocket、后台任务、事件处理等高级特性

FastAPI是构建现代API的理想选择,无论是构建小型API还是大型微服务架构,都能满足你的需求。