数据库和ORMS:使用SQLAlchemy与数据库通信

2022-11-27 17:21:47 浏览数 (1)

文章目录
  • 1. 环境安装
  • 2. 使用SQLAlchemy与SQL数据库通信
    • 2.1 创建表
    • 2.2 连接数据库
    • 2.3 insert、select
    • 2.4 update、delete
    • 2.5 relationships
    • 2.6 用Alembic进行数据库迁移

learn from 《Building Data Science Applications with FastAPI》

1. 环境安装

docker 安装 MongoDB 服务

代码语言:javascript复制
 docker run -d --name fastapi-mongo -p 27017:27017 mongo:4.4

2. 使用SQLAlchemy与SQL数据库通信

安装 pip install databases[sqlite]

2.1 创建表

代码语言:javascript复制
# models.py

import sqlalchemy
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field

# Metadata object on which the tables of this module are declared.
metadata = sqlalchemy.MetaData()

# Table object for "posts": (table name, metadata, then one Column per field).
posts = sqlalchemy.Table(
    'posts',
    metadata,
    sqlalchemy.Column(
        'id',
        sqlalchemy.Integer,
        primary_key=True,
        autoincrement=True,
    ),
    sqlalchemy.Column(
        'publication_date',
        sqlalchemy.DateTime(),
        nullable=False,
    ),
    sqlalchemy.Column(
        'title',
        sqlalchemy.String(255),
        nullable=False,
    ),
    sqlalchemy.Column(
        'text',
        sqlalchemy.Text(),
        nullable=False,
    ),
)


class PostBase(BaseModel):
    """Fields shared by the post schemas.

    ``publication_date`` is filled with ``datetime.now()`` when the caller
    does not supply a value.
    """
    title: str
    text: str
    # ``default_factory`` is called per instance, so each post gets its own timestamp.
    publication_date: datetime = Field(default_factory=datetime.now)

class PostPartialUpdate(BaseModel):
    """Optional fields accepted when partially updating a post.

    The field names match the updatable columns of the ``posts`` table
    (``title`` and ``text``) because the dict of set fields is passed
    straight to ``posts.update().values(...)``.
    """
    title: Optional[str] = None
    text: Optional[str] = None

class PostCreate(PostBase):
    # Creation payload: adds nothing on top of the fields inherited from PostBase.
    ...

class PostDB(PostBase):
    # PostBase fields plus the integer ``id`` of the row.
    id: int

2.2 连接数据库

代码语言:javascript复制
# _*_ coding: utf-8 _*_
# @Time : 2022/3/8 9:28
# @Author : Michael
# @File : database.py
# @desc :
import sqlalchemy
from databases import Database
# SQLite URL used by both objects created below.
DB_URL = 'sqlite:///cp6_sqlalchemy.db'
# ``Database`` instance from the ``databases`` package for this URL.
database = Database(DB_URL)
# SQLAlchemy engine created for the same URL.
sqlalchemy_engine = sqlalchemy.create_engine(DB_URL)

def get_database() -> Database:
    """Return the module-level ``Database`` instance."""
    return database

2.3 insert、select

代码语言:javascript复制
# _*_ coding: utf-8 _*_
# @Time : 2022/3/8 9:40
# @Author : Michael
# @File : app.py
# @desc :

from typing import List, Tuple
import uvicorn
from databases import Database
from fastapi import Depends, FastAPI, HTTPException, Query, status

from database import get_database, sqlalchemy_engine
from models import metadata, posts, PostDB, PostCreate, PostPartialUpdate

# FastAPI application object on which the handlers of this module are registered.
app = FastAPI()

@app.on_event('startup')
async def startup():
    """Connect to the database on application startup, then create the tables."""
    db = get_database()
    await db.connect()
    # Create the tables declared on ``metadata`` through the SQLAlchemy engine.
    metadata.create_all(sqlalchemy_engine)

@app.on_event("shutdown")
async def shutdown():
    """Disconnect from the database when the application shuts down."""
    db = get_database()
    await db.disconnect()

async def pagination(
        skip: int = Query(0, ge=0),
        limit: int = Query(10, ge=0),) -> Tuple[int, int]:
    """Return ``(skip, limit)`` from the query string, with ``limit`` capped at 100.

    Both values are declared with ``ge=0``; ``skip`` defaults to 0 and
    ``limit`` defaults to 10.
    """
    return skip, limit if limit < 100 else 100

async def get_post_or_404(id: int, database: Database = Depends(get_database)) -> PostDB:
    """Fetch the post whose primary key is ``id``.

    Raises:
        HTTPException: with status 404 when no such row exists.
    """
    row = await database.fetch_one(posts.select().where(posts.c.id == id))

    if row is not None:
        return PostDB(**row)

    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

# Insert data
@app.post("/posts/", response_model=PostDB, status_code=status.HTTP_201_CREATED)
async def create_post(post: PostCreate, db: Database = Depends(get_database)) -> PostDB:
    """Insert a new post and return it as read back from the database."""
    # Build the INSERT from the table object; no hand-written SQL needed.
    statement = posts.insert().values(post.dict())
    # Executing the INSERT yields the value used below as the new post's id.
    new_id = await db.execute(statement)
    return await get_post_or_404(new_id, db)

@app.get("/posts/{id}", response_model=PostDB)
async def get_post(post: PostDB = Depends(get_post_or_404)) -> PostDB:
    """Return the post resolved by the ``get_post_or_404`` dependency."""
    return post

@app.get("/posts", response_model=List[PostDB])
async def list_posts(
        pagination: Tuple[int, int] = Depends(pagination),
        database: Database = Depends(get_database),) -> List[PostDB]:
    """Return one page of posts.

    ``pagination`` is the ``(skip, limit)`` tuple produced by the
    ``pagination`` dependency. ``response_model`` is declared explicitly,
    like on the other post endpoints, so the response schema is documented
    and validated.
    """
    skip, limit = pagination
    select_query = posts.select().offset(skip).limit(limit)
    rows = await database.fetch_all(select_query)

    return [PostDB(**row) for row in rows]

if __name__ == '__main__':
    # ``debug=True`` is not passed: newer uvicorn releases removed that option and
    # raise TypeError on it. ``reload=True`` keeps auto-reload during development.
    uvicorn.run(app='app:app', host="127.0.0.1", port=8001, reload=True)

2.4 update、delete

代码语言:javascript复制
# update
@app.patch("/posts/{id}", response_model=PostDB)
async def update_post(post_update: PostPartialUpdate,
                      post: PostDB = Depends(get_post_or_404),
                      database: Database = Depends(get_database)) -> PostDB:
    """Apply a partial update to an existing post and return the updated post.

    Only fields explicitly set in the request body are written. A 404 is
    raised by the ``get_post_or_404`` dependency when the post does not exist.
    """
    changes = post_update.dict(exclude_unset=True)
    if not changes:
        # Nothing to write: skip the UPDATE, which would otherwise be built
        # from an empty values dict, and return the post unchanged.
        return post

    update_query = posts.update().where(posts.c.id == post.id).values(changes)
    await database.execute(update_query)
    # Read the row back so the response reflects what is stored.
    return await get_post_or_404(post.id, database)
代码语言:javascript复制
# delete
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(post: PostDB = Depends(get_post_or_404),
                      database: Database = Depends(get_database)) -> None:
    """Delete the post resolved by ``get_post_or_404`` (which raises 404 if it is missing)."""
    statement = posts.delete().where(posts.c.id == post.id)
    await database.execute(statement)

2.5 relationships

models.py 编写新的表

代码语言:javascript复制
# Table object for "comments", declared on the same metadata as "posts".
comments = sqlalchemy.Table(
    "comments",
    metadata,
    sqlalchemy.Column(
        "id",
        sqlalchemy.Integer,
        primary_key=True,
        autoincrement=True,
    ),
    # Foreign key linking each comment to posts.id, declared with ondelete="CASCADE".
    sqlalchemy.Column(
        "post_id",
        sqlalchemy.ForeignKey("posts.id", ondelete="CASCADE"),
        nullable=False,
    ),
    sqlalchemy.Column(
        "publication_date",
        sqlalchemy.DateTime(),
        nullable=False,
    ),
    sqlalchemy.Column(
        "content",
        sqlalchemy.Text(),
        nullable=False,
    ),
)


class CommentBase(BaseModel):
    # Fields shared by the comment schemas.
    # ``post_id`` holds the id of the post the comment belongs to.
    post_id: int
    # Filled with ``datetime.now()`` when the caller does not supply a value.
    publication_date: datetime = Field(default_factory=datetime.now)
    content: str


class CommentCreate(CommentBase):
    # Creation payload: adds nothing on top of the fields inherited from CommentBase.
    ...


class CommentDB(CommentBase):
    # CommentBase fields plus the integer ``id`` of the row.
    id: int

app.py 添加内容

代码语言:javascript复制
from typing import List, Mapping, Tuple, cast
from models import metadata, posts, PostDB, PostCreate, PostPartialUpdate, comments, CommentCreate, CommentDB

@app.post("/comments", response_model=CommentDB, status_code=status.HTTP_201_CREATED)
async def create_comment(
        comment: CommentCreate, database: Database = Depends(get_database)
) -> CommentDB:
    """Create a comment on an existing post and return the stored comment.

    Raises:
        HTTPException: with status 400 when ``comment.post_id`` does not
            match any row of the ``posts`` table.
    """
    # Look up the post the comment refers to.
    select_post_query = posts.select().where(posts.c.id == comment.post_id)
    post = await database.fetch_one(select_post_query)

    if post is None:
        # Report the requested post id (the builtin ``id`` was interpolated here before).
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Post {comment.post_id} does not exist",
        )
    # Insert the comment.
    insert_query = comments.insert().values(comment.dict())
    comment_id = await database.execute(insert_query)
    # Read the inserted comment back.
    select_query = comments.select().where(comments.c.id == comment_id)
    raw_comment = cast(Mapping, await database.fetch_one(select_query))

    return CommentDB(**raw_comment)

获取一个post的全部comments

models.py

代码语言:javascript复制
class PostPublic(PostDB):
    # PostDB fields plus an optional list of the post's comments (defaults to None).
    comments: Optional[List[CommentDB]] = None

app.py

代码语言:javascript复制
# _*_ coding: utf-8 _*_
# @Time : 2022/3/8 9:40
# @Author : Michael
# @File : app.py
# @desc :

from typing import List, Mapping, Tuple, cast
import uvicorn
from databases import Database
from fastapi import Depends, FastAPI, HTTPException, Query, status

from database import get_database, sqlalchemy_engine
# Parenthesised so the import can span several lines; the bare line break was a SyntaxError.
from models import (
    metadata, posts, PostDB, PostCreate, PostPartialUpdate,
    comments, CommentCreate, CommentDB, PostPublic,
)

# FastAPI application object on which the handlers of this module are registered.
app = FastAPI()


@app.on_event('startup')
async def startup():
    """Run at application startup: open the database connection and create the tables."""
    connection = get_database()
    await connection.connect()
    # Tables declared on ``metadata`` are created through the SQLAlchemy engine.
    metadata.create_all(sqlalchemy_engine)


@app.on_event("shutdown")
async def shutdown():
    """Run at application shutdown: close the database connection."""
    connection = get_database()
    await connection.disconnect()


async def pagination(
        skip: int = Query(0, ge=0),
        limit: int = Query(10, ge=0), ) -> Tuple[int, int]:
    """Return the ``(skip, limit)`` pair, never letting ``limit`` exceed 100."""
    max_limit = 100
    if limit > max_limit:
        limit = max_limit
    return skip, limit


async def get_post_or_404(id: int, database: Database = Depends(get_database)) -> PostPublic:
    """Fetch the post with primary key ``id`` together with its comments.

    Raises:
        HTTPException: with status 404 when no such post exists.
    """
    post_row = await database.fetch_one(posts.select().where(posts.c.id == id))
    if post_row is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

    # All comments whose post_id equals the requested id.
    comment_rows = await database.fetch_all(
        comments.select().where(comments.c.post_id == id)
    )
    return PostPublic(
        **post_row,
        comments=[CommentDB(**comment_row) for comment_row in comment_rows],
    )


# Insert data
@app.post("/posts", response_model=PostDB, status_code=status.HTTP_201_CREATED)
async def create_post(post: PostCreate, db: Database = Depends(get_database)) -> PostPublic:
    """Insert a new post and return it as read back from the database.

    NOTE(review): the return annotation is ``PostPublic`` while the declared
    ``response_model`` is ``PostDB`` -- confirm which one is intended.
    """
    # Build the INSERT from the table object; no hand-written SQL needed.
    statement = posts.insert().values(post.dict())
    # Executing the INSERT yields the value used below as the new post's id.
    new_id = await db.execute(statement)
    return await get_post_or_404(new_id, db)


@app.get("/posts", response_model=List[PostDB])
async def list_posts(pagination: Tuple[int, int] = Depends(pagination),
                     database: Database = Depends(get_database), ) -> List[PostDB]:
    """Return one page of posts.

    ``pagination`` is the ``(skip, limit)`` tuple produced by the
    ``pagination`` dependency. ``response_model`` is declared explicitly,
    like on the other post endpoints, so the response schema is documented
    and validated.
    """
    skip, limit = pagination
    select_query = posts.select().offset(skip).limit(limit)
    rows = await database.fetch_all(select_query)

    return [PostDB(**row) for row in rows]


@app.get("/posts/{id}", response_model=PostPublic)
async def get_post(post: PostPublic = Depends(get_post_or_404)) -> PostPublic:
    """Return the post (including its comments) resolved by ``get_post_or_404``."""
    return post


# NOTE(review): this repeats the GET "/posts" handler named ``list_posts`` that is
# already defined earlier in this file; one of the two definitions can presumably
# be removed -- confirm before deleting.
@app.get("/posts")
async def list_posts(
        pagination: Tuple[int, int] = Depends(pagination),
        database: Database = Depends(get_database), ) -> List[PostDB]:
    """Return one page of posts, using the ``(skip, limit)`` tuple from ``pagination``."""
    skip, limit = pagination
    select_query = posts.select().offset(skip).limit(limit)
    rows = await database.fetch_all(select_query)

    results = [PostDB(**row) for row in rows]

    return results


# update
@app.patch("/posts/{id}", response_model=PostPublic)
async def update_post(post_update: PostPartialUpdate,
                      post: PostPublic = Depends(get_post_or_404),
                      database: Database = Depends(get_database)) -> PostPublic:
    """Apply a partial update to an existing post and return the updated post.

    Only fields explicitly set in the request body are written. A 404 is
    raised by the ``get_post_or_404`` dependency when the post does not exist.
    """
    changes = post_update.dict(exclude_unset=True)
    if not changes:
        # Nothing to write: skip the UPDATE, which would otherwise be built
        # from an empty values dict, and return the post unchanged.
        return post

    update_query = posts.update().where(posts.c.id == post.id).values(changes)
    await database.execute(update_query)
    # Read the row back so the response reflects what is stored.
    return await get_post_or_404(post.id, database)


# delete
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_post(post: PostPublic = Depends(get_post_or_404),
                      database: Database = Depends(get_database)) -> None:
    """Delete the post resolved by ``get_post_or_404`` (which raises 404 if it is missing)."""
    statement = posts.delete().where(posts.c.id == post.id)
    await database.execute(statement)


@app.post("/comments", response_model=CommentDB, status_code=status.HTTP_201_CREATED)
async def create_comment(
        comment: CommentCreate, database: Database = Depends(get_database)
) -> CommentDB:
    """Create a comment on an existing post and return the stored comment.

    Raises:
        HTTPException: with status 400 when ``comment.post_id`` does not
            match any row of the ``posts`` table.
    """
    # The referenced post must exist before the comment is inserted.
    parent = await database.fetch_one(
        posts.select().where(posts.c.id == comment.post_id)
    )
    if parent is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=f"Post {comment.post_id} does not exist"
        )

    new_id = await database.execute(comments.insert().values(comment.dict()))

    # Read the inserted row back to build the response.
    stored = await database.fetch_one(
        comments.select().where(comments.c.id == new_id)
    )
    return CommentDB(**cast(Mapping, stored))


if __name__ == '__main__':
    # ``debug=True`` is not passed: newer uvicorn releases removed that option and
    # raise TypeError on it. ``reload=True`` keeps auto-reload during development.
    uvicorn.run(app='app:app', host="127.0.0.1", port=8001, reload=True)

2.6 用Alembic进行数据库迁移

代码语言:javascript复制
pip install alembic

终端输入:

代码语言:javascript复制
alembic init alembic

初始化迁移环境,其中包括一组文件和目录,Alembic将在其中存储其配置和迁移文件,需要一起提交 git

在 env.py 中导入元数据

代码语言:javascript复制
from web_python_dev.sqlalchemy1.models import metadata

# Point Alembic's ``target_metadata`` at the metadata imported from models.
target_metadata = metadata

编辑 alembic.ini 配置:将其中的 sqlalchemy.url 设置为数据库地址(例如上文使用的 sqlite:///cp6_sqlalchemy.db)

开始迁移

代码语言:javascript复制
alembic revision --autogenerate -m "Initial migration"

之后会生成一个py文件

该文件内有两个函数:upgrade 和 downgrade,分别用于数据迁移和回滚

代码语言:javascript复制
# 升级
alembic upgrade head

数据的迁移和升级之前请做好备份和测试,防止丢失损坏 https://alembic.sqlalchemy.org/en/latest/index.html

0 人点赞