56 lines
1.4 KiB
Python
56 lines
1.4 KiB
Python
from fastapi.encoders import jsonable_encoder
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
class CRUDBase:
|
|
|
|
def __init__(self, model):
|
|
self.model = model
|
|
|
|
async def get(
|
|
self,
|
|
obj_id: int,
|
|
session: AsyncSession
|
|
):
|
|
db_obj = await session.execute(
|
|
select(self.model).where(
|
|
self.model.db == obj_id
|
|
)
|
|
)
|
|
return db_obj.scalars().first()
|
|
|
|
async def get_multi(self, session: AsyncSession):
|
|
db_objs = await session.execute(select(self.model))
|
|
return db_objs.scalars().all()
|
|
|
|
async def create(self,
|
|
obj_in,
|
|
session: AsyncSession):
|
|
obj_in_data = obj_in
|
|
db_obj = self.model(**obj_in_data)
|
|
session.add(db_obj)
|
|
await session.commit()
|
|
await session.refresh(db_obj)
|
|
return db_obj
|
|
|
|
|
|
async def update(
|
|
self,
|
|
db_obj,
|
|
obj_in,
|
|
session: AsyncSession
|
|
):
|
|
obj_data = jsonable_encoder(db_obj)
|
|
update_data = obj_in.dict(exclude_unsets=True)
|
|
for field in obj_data:
|
|
if field in update_data:
|
|
setattr(db_obj, field, update_data[field])
|
|
session.add(db_obj)
|
|
await session.commit()
|
|
await session.refresh(db_obj)
|
|
return db_obj
|
|
|
|
|
|
|