FastAPI 数据库访问(一)使用SQLAlchemy访问关系数据库

时间:2024-03-17 15:08:00

作者:麦克煎蛋   出处:https://www.cnblogs.com/mazhiyong/ 转载请保留这段声明,谢谢!

 

SQLAlchemy是一个基于Python实现的ORM框架。它提供了一种方法,用于将用户定义的Python类与数据库表相关联,并将这些类(对象)的实例与其对应表中的行相关联。它包括一个透明地同步对象及其相关行之间状态的所有变化的系统,以及根据用户定义的类及其定义的彼此之间的关系表达数据库查询的系统。

关于SQLAlchemy的具体使用细节这里不再赘述,重点讲述数据库模型与Pydantic模型使用、以及数据库Session有关的内容。

这里我们以MySQL为例。SQLAlchemy本身无法操作数据库,其必须借助pymysql等第三方插件。

pip install pymysql
pip install sqlalchemy

一、 首先实现对数据库的操作

这里以联系人为例,实现了对联系人数据的新增、读取以及更新操作:

注意,这里的数据模型DBUser指的是与数据库相关的数据模型。

from sqlalchemy import Column, DateTime, String, text, create_engine
from sqlalchemy.dialects.mysql import INTEGER, VARCHAR
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base

from sqlalchemy.orm import Session

# db connect config(略,可自行填写)
MYSQL_USER = \'\'
MYSQL_PASS = \'\'
MYSQL_HOST = \'\'
MYSQL_PORT = \'3306\'
MYSQL_DB = \'\'

SQLALCHEMY_DATABASE_URI = \'mysql+pymysql://%s:%s@%s:%s/%s\' % (MYSQL_USER, MYSQL_PASS, MYSQL_HOST, MYSQL_PORT, MYSQL_DB)

# 创建对象的基类:
Base = declarative_base()

# 初始化数据库连接:
engine = create_engine(SQLALCHEMY_DATABASE_URI)

# SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
SessionLocal = sessionmaker(bind=engine)


class DBUser(Base):
    __tablename__ = \'test_user\'

    id = Column(INTEGER(64), primary_key=True, comment=\'编号\')
    username = Column(String(100))
    password = Column(String(100))
    sex = Column(VARCHAR(10), server_default=text("\'\'"), comment=\'性别\')
    login_time = Column(INTEGER(11), server_default=text("\'0\'"), comment=\'登陆时间,主要为了登陆JWT校验使用\')
    create_date = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP"))
    update_date = Column(DateTime, nullable=False, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"))

    @classmethod
    def add(cls, db: Session, data):
        db.add(data)
        db.commit()
        # db.refresh(data)

    @classmethod
    def get_by_username(cls, db: Session, username):
        data = db.query(cls).filter_by(username=username).first()

        return data

    @classmethod
    def update(cls, db: Session, username, sex):
        db.query(cls).filter_by(username=username).update({cls.sex: sex})

        db.commit()

这里的db:Session从调用者中传入,每次请求只会用一个数据库Session,请求结束后关闭。

二、实现业务逻辑

这里以联系人注册、登陆、数据读取等常用流程为例。

注意以下基础数据模型,指的是Pydantic数据模型,用于返回给终端。

同时要注意到,SQLAlchemy模型用 "="来定义属性,而Pydantic模型用":"来声明类型,不要弄混了。

class User(BaseModel):
    id: Optional[int] = None
    username: str
    sex: Optional[str] = None
    login_time: Optional[int] = None

    class Config:
        orm_mode = True

 

注意,我们给Pydantic模型添加了一个 Config类。Config用来给Pydantic提供配置信息,这里我们添加了配置信息"orm_mode = True"。

配置项"orm_mode"除了可以让Pydantic读取字典类型的数据,还支持Pydantic读取属性数据,比如SQLAlchemy模型的数据。

这样Pydantic数据模型就可以兼容SQLAlchemy数据模型,我们可以在路径操作函数中直接返回SQLAlchemy数据模型(没有这个配置项的支持是不行的)。

1、用户注册

@app.post("/register", response_model=User)
async def register(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    # 密码加密
    password = get_password_hash(form_data.password)

    db_user = DBUser.get_by_username(db, form_data.username)
    if db_user:
        return db_user

    db_user = DBUser(username=form_data.username, password=password)
    DBUser.add(db, db_user)

    return db_user

通过get_db来获取数据库Session,请求结束后关闭。

def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

2、用户登陆

@app.post("/login", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    # 首先校验用户信息
    user = authenticate_user(db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # 生成并返回token信息
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )

    return {"access_token": access_token, "token_type": "bearer"}

在登陆的时候要对用户名和密码进行校验:

# 用户信息校验:username和password分别校验
def authenticate_user(db: Session, username: str, password: str):
    user = DBUser.get_by_username(db, username)
    if not user:
        return False
    if not verify_password(password, user.password):
        return False
    return user

如果登陆成功则返回token信息:

# 生成token,带有过期时间
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

3、接口访问示例

async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        print(username)

        if username is None:
            raise credentials_exception
        token_data = TokenData(username=username)
    except PyJWTError:
        raise credentials_exception
    user = DBUser.get_by_username(db, token_data.username)
    if user is None:
        raise credentials_exception
    return user


@app.get("/users/me/", response_model=User)
async def read_users_me(current_user: User = Depends(get_current_user)):
    return current_user

这里以读取用户信息为例,请求端要在头信息中携带token信息。

 

后端收到请求后,要对token进行解析,如果合法则继续访问,如果非法则返回401错误信息。

其他接口的校验过程与此类似。