文章目录
- 1. 环境安装
- 2. 使用SQLAlchemy与SQL数据库通信
- 2.1 创建表
- 2.2 连接数据库
- 2.3 insert、select
- 2.4 update、delete
- 2.5 relationships
- 2.6 用Alembic进行数据库迁移
learn from 《Building Data Science Applications with FastAPI》
1. 环境安装
docker 安装 MongoDB 服务
docker run -d --name fastapi-mongo -p 27017:27017 mongo:4.4
2. 使用SQLAlchemy与SQL数据库通信
安装 pip install databases[sqlite]
2.1 创建表
# models.py
import sqlalchemy
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field

# Shared MetaData registry; every Table declared against it is created by
# ``metadata.create_all(...)`` when the application starts.
metadata = sqlalchemy.MetaData()

# Columns of the ``posts`` table, in declaration order: (name, type, options).
_post_columns = (
    sqlalchemy.Column('id', sqlalchemy.Integer, primary_key=True, autoincrement=True),
    sqlalchemy.Column('publication_date', sqlalchemy.DateTime(), nullable=False),
    sqlalchemy.Column('title', sqlalchemy.String(255), nullable=False),
    sqlalchemy.Column('text', sqlalchemy.Text(), nullable=False),
)

# Table object bound to the shared metadata (table name: "posts").
posts = sqlalchemy.Table('posts', metadata, *_post_columns)
class PostBase(BaseModel):
    """Fields shared by every representation of a blog post."""

    title: str
    text: str
    # ``default_factory`` is evaluated per instance, so a post the client sends
    # without a date is stamped with the time it was built, not import time.
    publication_date: datetime = Field(default_factory=datetime.now)
class PostPartialUpdate(BaseModel):
    """Body of a PATCH request on a post; every field is optional.

    Field names must match columns of the ``posts`` table (``title`` and
    ``text``): the update handler passes ``dict(exclude_unset=True)`` directly
    to ``posts.update().values(...)``, so an unknown name would make SQLAlchemy
    reject the statement.
    """

    title: Optional[str] = None
    text: Optional[str] = None
class PostCreate(PostBase):
    """Payload accepted when creating a post; adds nothing to PostBase."""


class PostDB(PostBase):
    """A post as stored in the database, including its primary key."""

    id: int
2.2 连接数据库
# _*_ coding: utf-8 _*_
# @Time : 2022/3/8 9:28
# @Author : Michael
# @File : database.py
# @desc : shared async database handle and synchronous SQLAlchemy engine
import sqlalchemy
from databases import Database

# One URL feeds both handles so they always point at the same SQLite file.
DB_URL = 'sqlite:///cp6_sqlalchemy.db'

# Async handle used by the request handlers; connected/disconnected by the app.
database = Database(DB_URL)
# Synchronous engine; the app uses it for ``metadata.create_all`` at startup.
sqlalchemy_engine = sqlalchemy.create_engine(DB_URL)


def get_database() -> Database:
    """FastAPI dependency: hand out the module-level ``Database`` instance."""
    return database
2.3 insert、select
# _*_ coding: utf-8 _*_
# @Time : 2022/3/8 9:40
# @Author : Michael
# @File : app.py
# @desc :
from typing import List, Tuple
import uvicorn
from databases import Database
from fastapi import Depends, FastAPI, HTTPException, Query, status
from database import get_database, sqlalchemy_engine
from models import metadata, posts, PostDB, PostCreate, PostPartialUpdate

app = FastAPI()


@app.on_event('startup')
async def startup():
    """Open the async database connection, then create any missing tables."""
    db = get_database()
    await db.connect()
    metadata.create_all(sqlalchemy_engine)


@app.on_event("shutdown")
async def shutdown():
    """Close the async database connection when the server stops."""
    db = get_database()
    await db.disconnect()
async def pagination(
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=0),
) -> Tuple[int, int]:
    """Dependency that parses the ``skip``/``limit`` query parameters.

    Returns ``(skip, limit)`` with ``limit`` capped at 100 so a client cannot
    request an arbitrarily large page.
    """
    if limit > 100:
        limit = 100
    return skip, limit
async def get_post_or_404(id: int, database: Database = Depends(get_database)) -> PostDB:
    """Load the post with primary key ``id`` or raise HTTP 404.

    Used as a dependency for ``/posts/{id}`` routes and also called directly.
    """
    row = await database.fetch_one(posts.select().where(posts.c.id == id))
    if row is not None:
        return PostDB(**row)
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
# Insert data
@app.post("/posts/", response_model=PostDB, status_code=status.HTTP_201_CREATED)
async def create_post(post: PostCreate, db: Database = Depends(get_database)) -> PostDB:
    """Insert a new post and return the row as stored, including its id."""
    # SQLAlchemy Core builds the INSERT statement; no hand-written SQL needed.
    new_id = await db.execute(posts.insert().values(post.dict()))
    # Re-read the row so the response reflects exactly what the database holds.
    return await get_post_or_404(new_id, db)
@app.get("/posts/{id}", response_model=PostDB)
async def get_post(post: PostDB = Depends(get_post_or_404)) -> PostDB:
    """Return one post; the lookup and 404 are handled by the dependency."""
    return post
@app.get("/posts", response_model=List[PostDB])
async def list_posts(
    pagination: Tuple[int, int] = Depends(pagination),
    database: Database = Depends(get_database),
) -> List[PostDB]:
    """Return one page of posts, ordered by id.

    ``response_model`` is declared explicitly so the response is validated and
    documented like every other route. Without ORDER BY, SQL gives no
    guarantee about row order, and LIMIT/OFFSET pages could overlap or skip
    rows.
    """
    skip, limit = pagination
    select_query = posts.select().order_by(posts.c.id).offset(skip).limit(limit)
    rows = await database.fetch_all(select_query)
    return [PostDB(**row) for row in rows]
if __name__ == '__main__':
    # ``reload=True`` requires the app as an import string ("module:attr").
    # No ``debug=`` keyword: recent uvicorn releases removed it and reject it.
    uvicorn.run(app='app:app', host="127.0.0.1", port=8001, reload=True)
2.4 update、delete
# update
@app.patch("/posts/{id}", response_model=PostDB)
async def update_post(post_update: PostPartialUpdate,
                      post: PostDB = Depends(get_post_or_404),
                      database: Database = Depends(get_database)) -> PostDB:
    """Apply a partial update to post ``id`` and return the refreshed row.

    Only the fields the client actually sent are written (``exclude_unset``).
    An empty body changes nothing and returns the post as it was.
    """
    changes = post_update.dict(exclude_unset=True)
    if not changes:
        # Nothing to write: skip the query instead of issuing an UPDATE
        # with no values.
        return post
    update_query = posts.update().where(posts.c.id == post.id).values(changes)
    await database.execute(update_query)
    return await get_post_or_404(post.id, database)
# delete
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(post: PostDB = Depends(get_post_or_404),
                      database: Database = Depends(get_database)) -> None:
    """Delete post ``id``; a missing post yields 404 via the dependency."""
    await database.execute(posts.delete().where(posts.c.id == post.id))
2.5 relationships
在 models.py 中编写新的表 comments:
# Columns of the ``comments`` table, in declaration order.
_comment_columns = (
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True),
    # Foreign key to posts.id; no explicit type is given, so SQLAlchemy takes it
    # from the referenced column. NOTE(review): SQLite only enforces the
    # ON DELETE CASCADE when PRAGMA foreign_keys=ON is set on the connection.
    sqlalchemy.Column(
        "post_id", sqlalchemy.ForeignKey("posts.id", ondelete="CASCADE"), nullable=False
    ),
    sqlalchemy.Column("publication_date", sqlalchemy.DateTime(), nullable=False),
    sqlalchemy.Column("content", sqlalchemy.Text(), nullable=False),
)

# Table object bound to the shared metadata (table name: "comments").
comments = sqlalchemy.Table("comments", metadata, *_comment_columns)
class CommentBase(BaseModel):
    """Fields shared by every representation of a comment."""

    post_id: int
    # Evaluated per instance, so each comment is stamped when it is built.
    publication_date: datetime = Field(default_factory=datetime.now)
    content: str


class CommentCreate(CommentBase):
    """Payload accepted when creating a comment; adds nothing to CommentBase."""


class CommentDB(CommentBase):
    """A comment as stored in the database, including its primary key."""

    id: int
在 app.py 中添加以下内容:
from typing import List, Mapping, Tuple, cast
from models import (
    metadata, posts, PostDB, PostCreate, PostPartialUpdate,
    comments, CommentCreate, CommentDB,
)


@app.post("/comments", response_model=CommentDB, status_code=status.HTTP_201_CREATED)
async def create_comment(
    comment: CommentCreate, database: Database = Depends(get_database)
) -> CommentDB:
    """Insert a comment on an existing post and return the stored row.

    Raises HTTP 400 when ``comment.post_id`` does not reference an existing post.
    """
    # Look up the parent post first so a dangling post_id is rejected.
    select_post_query = posts.select().where(posts.c.id == comment.post_id)
    post = await database.fetch_one(select_post_query)
    if post is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            # Report the id the client sent (not the builtin ``id``).
            detail=f"Post {comment.post_id} does not exist",
        )
    # Insert the comment.
    insert_query = comments.insert().values(comment.dict())
    comment_id = await database.execute(insert_query)
    # Re-read the inserted comment; cast() only narrows the type for checkers.
    select_query = comments.select().where(comments.c.id == comment_id)
    raw_comment = cast(Mapping, await database.fetch_one(select_query))
    return CommentDB(**raw_comment)
获取一个 post 的全部 comments
models.py
# models.py only imports Optional from typing, so List must be imported here.
from typing import List


class PostPublic(PostDB):
    """A post as returned to clients, with its comments attached (if loaded)."""

    comments: Optional[List[CommentDB]] = None
app.py
# _*_ coding: utf-8 _*_
# @Time : 2022/3/8 9:40
# @Author : Michael
# @File : app.py
# @desc :
from typing import List, Mapping, Tuple, cast
import uvicorn
from databases import Database
from fastapi import Depends, FastAPI, HTTPException, Query, status
from database import get_database, sqlalchemy_engine
# Parentheses let the import list span several lines; a bare trailing comma
# at the end of a line is a SyntaxError.
from models import (
    metadata, posts, PostDB, PostCreate, PostPartialUpdate,
    comments, CommentCreate, CommentDB, PostPublic,
)

app = FastAPI()


@app.on_event('startup')
async def startup():
    """Open the async database connection, then create any missing tables."""
    await get_database().connect()
    metadata.create_all(sqlalchemy_engine)


@app.on_event("shutdown")
async def shutdown():
    """Close the async database connection when the server stops."""
    await get_database().disconnect()
async def pagination(
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=0),
) -> Tuple[int, int]:
    """Dependency returning ``(skip, limit)`` with ``limit`` capped at 100."""
    page_size = 100 if limit > 100 else limit
    return skip, page_size
async def get_post_or_404(id: int, database: Database = Depends(get_database)) -> PostPublic:
    """Load post ``id`` together with all of its comments, or raise HTTP 404."""
    raw_post = await database.fetch_one(posts.select().where(posts.c.id == id))
    if raw_post is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
    # Second query: every comment whose post_id points at this post.
    comment_rows = await database.fetch_all(
        comments.select().where(comments.c.post_id == id)
    )
    attached = [CommentDB(**entry) for entry in comment_rows]
    return PostPublic(**raw_post, comments=attached)
# Insert data
@app.post("/posts", response_model=PostDB, status_code=status.HTTP_201_CREATED)
async def create_post(post: PostCreate, db: Database = Depends(get_database)) -> PostPublic:
    """Insert a new post and return it re-read from the database.

    ``response_model=PostDB`` means FastAPI leaves the ``comments`` field of
    the returned ``PostPublic`` out of the HTTP response.
    """
    # SQLAlchemy Core builds the INSERT statement; no hand-written SQL needed.
    statement = posts.insert().values(post.dict())
    new_id = await db.execute(statement)
    return await get_post_or_404(new_id, db)
@app.get("/posts", response_model=List[PostDB])
async def list_posts(pagination: Tuple[int, int] = Depends(pagination),
                     database: Database = Depends(get_database)) -> List[PostDB]:
    """Return one page of posts (without their comments), ordered by id.

    Registered once: defining the same ``GET /posts`` handler twice only adds a
    dead duplicate route. ORDER BY keeps LIMIT/OFFSET pages stable.
    """
    skip, limit = pagination
    select_query = posts.select().order_by(posts.c.id).offset(skip).limit(limit)
    rows = await database.fetch_all(select_query)
    return [PostDB(**row) for row in rows]


@app.get("/posts/{id}", response_model=PostPublic)
async def get_post(post: PostPublic = Depends(get_post_or_404)) -> PostPublic:
    """Return one post with its comments; the 404 is handled by the dependency."""
    return post
# update
@app.patch("/posts/{id}", response_model=PostPublic)
async def update_post(post_update: PostPartialUpdate,
                      post: PostPublic = Depends(get_post_or_404),
                      database: Database = Depends(get_database)) -> PostPublic:
    """Apply a partial update to post ``id`` and return it with its comments.

    Only the fields the client actually sent are written (``exclude_unset``).
    An empty body changes nothing and returns the post as it was.
    """
    changes = post_update.dict(exclude_unset=True)
    if not changes:
        # Nothing to write: skip the query instead of issuing an UPDATE
        # with no values.
        return post
    update_query = posts.update().where(posts.c.id == post.id).values(changes)
    await database.execute(update_query)
    return await get_post_or_404(post.id, database)
# delete
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(post: PostPublic = Depends(get_post_or_404),
                      database: Database = Depends(get_database)) -> None:
    """Delete post ``id`` together with its comments.

    SQLite ignores the schema's ON DELETE CASCADE unless PRAGMA foreign_keys=ON
    is set on the connection, so the comments are removed explicitly; on a
    database that does cascade, the first DELETE is simply redundant. Both
    statements run in one transaction so a failure cannot leave half a delete.
    """
    async with database.transaction():
        await database.execute(comments.delete().where(comments.c.post_id == post.id))
        await database.execute(posts.delete().where(posts.c.id == post.id))
@app.post("/comments", response_model=CommentDB, status_code=status.HTTP_201_CREATED)
async def create_comment(
    comment: CommentCreate, database: Database = Depends(get_database)
) -> CommentDB:
    """Insert a comment on an existing post and return the stored row.

    Raises HTTP 400 when ``comment.post_id`` does not reference an existing post.
    """
    parent = await database.fetch_one(posts.select().where(posts.c.id == comment.post_id))
    if parent is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"Post {comment.post_id} does not exist"
        )
    new_id = await database.execute(comments.insert().values(comment.dict()))
    stored = await database.fetch_one(comments.select().where(comments.c.id == new_id))
    # cast() only tells the type checker the row is present; it does nothing at runtime.
    return CommentDB(**cast(Mapping, stored))
if __name__ == '__main__':
    # ``reload=True`` requires the app as an import string ("module:attr").
    # No ``debug=`` keyword: recent uvicorn releases removed it and reject it.
    uvicorn.run(app='app:app', host="127.0.0.1", port=8001, reload=True)
2.6 用Alembic进行数据库迁移
pip install alembic
终端输入:
alembic init alembic
该命令会初始化迁移环境,生成一组文件和目录,Alembic 将在其中存储配置和迁移文件,这些文件需要一起提交到 git。
在 env.py 中导入元数据
# alembic/env.py: point autogenerate at the application's MetaData so it can
# diff the declared tables against the database.
from web_python_dev.sqlalchemy1.models import metadata as app_metadata

target_metadata = app_metadata
编辑 alembic.ini 配置,将 sqlalchemy.url 设置为数据库地址,例如 sqlalchemy.url = sqlite:///cp6_sqlalchemy.db
开始迁移
alembic revision --autogenerate -m "Initial migration"
之后会在 alembic/versions 目录下生成一个 py 文件,其中包含 upgrade 和 downgrade 两个函数,分别用于数据迁移和回滚。
# 升级
alembic upgrade head
数据的迁移和升级之前请做好备份和测试,防止丢失损坏 https://alembic.sqlalchemy.org/en/latest/index.html