#!/usr/bin/python
# -*- coding: utf-8 -*-
# @version        : 1.0
# @Create Time    : 2022/2/24 10:21 
# @File           : crud.py
# @IDE            : PyCharm
# @desc           : 增删改查

from typing import Any

from requests import Session
from sqlalchemy.orm.strategy_options import _AbstractLoad, contains_eager
from dbgpt.app.apps.core.exception import CustomException
from sqlalchemy import select, false, and_, func
from dbgpt.app.apps.core.crud import DalBase
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from . import models, schemas

from dbgpt.app.apps.utils import status
from .params.media_list import MediaEditParams
from .schemas import QuestionEditParams

class MediaDal(DalBase):
    """Data-access class bound to ``models.VadminMedia`` / ``schemas.MediaOut``."""

    def __init__(self, db: AsyncSession):
        """
        Attach the session, model and output schema to this instance.

        :param db: async SQLAlchemy session used to execute every statement below.
        """
        super().__init__()
        # Assigned after the base initializer runs, in the same order as before.
        self.db = db
        self.model = models.VadminMedia
        self.schema = schemas.MediaOut

    async def update_media_datas(self, params: MediaEditParams) -> None:
        """
        Edit the name or the group of the media rows listed in ``params.ids``.

        A non-None ``params.file_name`` takes precedence and is treated as a
        rename (``key_word`` is written together with it); otherwise a non-None
        ``params.group_id`` is treated as a move to another group. When both
        are None no statement is issued.

        The method only executes the UPDATE on ``self.db``; it does not commit.

        :param params: edit request carrying ``ids``, ``file_name``,
            ``key_word`` and ``group_id``.
        """
        if params.file_name is not None:
            # File name supplied: treat the request as a rename.
            changes = {"file_name": params.file_name, "key_word": params.key_word}
        elif params.group_id is not None:
            # Group id supplied: treat the request as a move between groups.
            changes = {"group_id": params.group_id}
        else:
            # Neither field supplied: nothing to update.
            return

        targeted_rows = self.model.id.in_(params.ids)
        stmt = update(self.model).where(targeted_rows).values(**changes)
        await self.db.execute(stmt)

    async def get_count_by_groupIds(self, ids: list) -> int:
        """
        Count media rows whose ``group_id`` is in ``ids`` and whose
        ``is_delete`` equals 0.

        :param ids: group ids to match against ``group_id``.
        :return: scalar result of the ``COUNT(id)`` query.
        """
        in_requested_groups = self.model.group_id.in_(ids)
        not_deleted = self.model.is_delete == 0
        count_stmt = select(func.count(self.model.id)).where(
            and_(in_requested_groups, not_deleted)
        )
        outcome = await self.db.execute(count_stmt)
        return outcome.scalar()


class GroupDal(DalBase):
    """Data-access class bound to ``models.VadminGroup`` / ``schemas.GroupOut``."""

    def __init__(self, db: AsyncSession):
        """
        Attach the session, model and output schema to this instance.

        :param db: async SQLAlchemy session stored on ``self.db``.
        """
        super().__init__()
        # Assigned after the base initializer runs, in the same order as before.
        self.db = db
        self.model = models.VadminGroup
        self.schema = schemas.GroupOut


class QuestionDal(DalBase):
    """Data-access class bound to ``models.VadminQuestion`` / ``schemas.QuestionOut``."""

    def __init__(self, db: AsyncSession):
        """
        Attach the session, model and output schema to this instance.

        :param db: async SQLAlchemy session used to execute every statement below.
        """
        super().__init__()
        # Assigned after the base initializer runs, in the same order as before.
        self.db = db
        self.model = models.VadminQuestion
        self.schema = schemas.QuestionOut

    async def update_question_datas(self, params: QuestionEditParams) -> None:
        """
        Write ``group_id``, ``title``, ``key_word`` and ``answer`` to every
        question row whose id is in ``params.ids``.

        The method only executes the UPDATE on ``self.db``; it does not commit.

        NOTE(review): all four columns are written unconditionally, so any
        field left as None on ``params`` would be stored as-is — confirm that
        callers always populate every field.

        :param params: edit request carrying ``ids`` plus the four column values.
        """
        changes = {
            "group_id": params.group_id,
            "title": params.title,
            "key_word": params.key_word,
            "answer": params.answer,
        }
        targeted_rows = self.model.id.in_(params.ids)
        stmt = update(self.model).where(targeted_rows).values(**changes)
        await self.db.execute(stmt)



class CorrelationDal(DalBase):
    """Data-access class bound to ``models.VadminCorrelation`` / ``schemas.CorrelationOut``."""

    def __init__(self, db: AsyncSession):
        """
        Attach the session, model and output schema to this instance.

        :param db: async SQLAlchemy session stored on ``self.db``.
        """
        super().__init__()
        # Assigned after the base initializer runs, in the same order as before.
        self.db = db
        self.model = models.VadminCorrelation
        self.schema = schemas.CorrelationOut
