from collections import defaultdict
import re

from fastapi import Request, Depends
from dbgpt.app.apps.core.database import db_getter
from sqlalchemy.ext.asyncio import AsyncSession

from dbgpt.app.apps.vadmin.word.crud import SensitiveDal
from dbgpt.app.apps.vadmin.media.crud import MediaDal,QuestionDal

from dbgpt.app.apps.vadmin.auth.utils.current import AllUserAuth, FullAdminAuth, OpenAuth
from dbgpt.app.apps.vadmin.auth.utils.validation.auth import Auth

class DFAFilter:
    """Keyword matcher/masker backed by a character trie (DFA).

    Keywords are stored lower-cased in ``keyword_chains``, a nested dict keyed
    by single characters. A node that ends a keyword carries the reserved
    ``delimit`` key. Matching is therefore case-insensitive and costs time
    proportional to the text length times the longest keyword.

    Example::

        f = DFAFilter()
        f.add("sexy")
        f.filter("hello sexy baby")
        # -> ('hello **** baby', True, ['sexy'])
    """

    def __init__(self):
        # Root of the trie: {char: {char: {... {delimit: 0}}}}
        self.keyword_chains = {}
        # Reserved end-of-keyword marker; never allowed inside a keyword.
        self.delimit = '\x00'

    def _normalize(self, keyword):
        """Return ``keyword`` lower-cased and stripped, or ``''`` if unusable.

        ``None`` and keywords containing the reserved delimiter are rejected:
        the former shows up for missing DB values, the latter would be
        indistinguishable from an end-of-keyword marker inside the trie.
        """
        if keyword is None:
            return ''
        if not isinstance(keyword, str):
            keyword = keyword.decode('utf-8')
        chars = keyword.lower().strip()
        if self.delimit in chars:
            return ''
        return chars

    def add(self, keyword):
        """Insert ``keyword`` (``str`` or UTF-8 ``bytes``) into the trie.

        Empty, whitespace-only, ``None`` and delimiter-containing keywords
        are ignored.
        """
        chars = self._normalize(keyword)
        if not chars:
            return
        level = self.keyword_chains
        for char in chars:
            # Reuse the existing branch or create a new one.
            level = level.setdefault(char, {})
        level[self.delimit] = 0

    def remove(self, keyword):
        """Remove ``keyword`` from the trie, pruning branches left empty.

        Keywords that were never added are ignored; other keywords sharing
        a prefix with ``keyword`` are left intact.
        """
        chars = self._normalize(keyword)
        if not chars:
            return

        # Walk down, remembering (parent, char) pairs so we can prune upward.
        path = []
        level = self.keyword_chains
        for char in chars:
            child = level.get(char)
            if child is None:
                return  # keyword is not in the trie
            path.append((level, char))
            level = child

        # Drop the end marker (if any), then delete nodes that became empty.
        level.pop(self.delimit, None)
        for parent, char in reversed(path):
            if parent[char]:
                break  # still used by another keyword
            del parent[char]

    def parse(self, path):
        """Load keywords from a UTF-8 text file, one keyword per line."""
        with open(path, encoding='UTF-8') as f:
            for keyword in f:
                self.add(keyword.strip())

    async def parse_picture_from_db(self, db: "AsyncSession"):
        """Load the ``key_word`` of every picture row (media ``type`` 1).

        NOTE(review): ``limit: 0`` presumably means "no limit" in
        ``MediaDal.get_datas`` - confirm against the DAL.
        """
        print('---------parse_picture_from_db--load-------------')
        images_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 1, 'group_id': None}
        # v_return_count=True makes get_datas return (rows, count); count is unused.
        images_datas, _ = await MediaDal(db).get_datas(**images_dic, v_return_count=True)
        print(f"-----图片库列表为:---->:{images_datas}")
        for imagedata in images_datas:
            self.add(imagedata.get('key_word'))

    async def parse_question_from_db(self, db: "AsyncSession"):
        """Load the ``key_word`` of every question/answer row.

        NOTE(review): ``limit: 0`` presumably means "no limit" in
        ``QuestionDal.get_datas`` - confirm against the DAL.
        """
        print('---------parse_question_from_db--load-------------')
        question_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None}
        # v_return_count=True makes get_datas return (rows, count); count is unused.
        question_datas, _ = await QuestionDal(db).get_datas(**question_dic, v_return_count=True)
        print(f"-----问答对库列表为:---->:{question_datas}")
        for questiondata in question_datas:
            self.add(questiondata.get('key_word'))

    async def parse_video_from_db(self, db: "AsyncSession"):
        """Load the ``key_word`` of every video row (media ``type`` 2).

        NOTE(review): ``limit: 0`` presumably means "no limit" in
        ``MediaDal.get_datas`` - confirm against the DAL.
        """
        print('---------parse_video_from_db--load-------------')
        video_dic = {'page': 1, 'limit': 0, 'v_order': None, 'v_order_field': None, 'type': 2, 'group_id': None}
        # v_return_count=True makes get_datas return (rows, count); count is unused.
        video_datas, _ = await MediaDal(db).get_datas(**video_dic, v_return_count=True)
        print(f"-----视频库列表为:---->:{video_datas}")
        for videodata in video_datas:
            self.add(videodata.get('key_word'))

    async def parse_from_db(self, db: "AsyncSession"):
        """Load sensitive words (``word_name``) via ``SensitiveDal.get_sensitives``."""
        print('---------parse_from_db-load-------------')
        sdl = SensitiveDal(db)
        datas = await sdl.get_sensitives()
        for keyword in datas:
            self.add(keyword.word_name)

    def filter(self, message, repl="*"):
        """Mask every keyword occurrence in ``message`` using longest match.

        Args:
            message: ``str`` or UTF-8 ``bytes`` to scan.
            repl: Replacement emitted once per masked character.

        Returns:
            A tuple ``(text, is_sensitive, matched)``: ``text`` is the
            lower-cased message with matches replaced, ``is_sensitive``
            tells whether anything matched, and ``matched`` lists the
            (lower-cased) keywords found, in order of appearance.
        """
        if not isinstance(message, str):
            message = message.decode('utf-8')
        # Keywords are stored lower-cased, so matching runs on a lowered copy.
        message = message.lower()

        root = self.keyword_chains
        delimit = self.delimit
        total = len(message)
        ret = []
        matched_sensitives = []  # keywords found, in order of appearance
        start = 0

        while start < total:
            level = root
            longest_match_len = 0
            # Walk by index instead of re-slicing message[start:] each time.
            for pos in range(start, total):
                char = message[pos]
                # The delimiter is a trie-internal marker, not a real edge:
                # following it would land on the int 0 and raise TypeError.
                if char == delimit or char not in level:
                    break
                level = level[char]
                if delimit in level:
                    # A complete keyword ends here; keep going for a longer one.
                    longest_match_len = pos - start + 1

            if longest_match_len:
                matched_sensitives.append(message[start:start + longest_match_len])
                ret.append(repl * longest_match_len)
                start += longest_match_len
            else:
                # No keyword starts here: keep the character unchanged.
                ret.append(message[start])
                start += 1

        return ''.join(ret), bool(matched_sensitives), matched_sensitives

# Module-level filter instances, created empty at import time.
# Sensitive-word filter
mydfafiter = DFAFilter()

# Picture keyword filter
mydfafiter_picture = DFAFilter()

# Video keyword filter
mydfafiter_video = DFAFilter()

# Question/answer keyword filter
mydfafiter_question = DFAFilter()

"""
if __name__ == "__main__":
    gfw = DFAFilter()
    gfw.parse("keywords")
    import time

    t = time.process_time()
    print(gfw.filter("法轮功 我操操操", "*"))
    print(gfw.filter("针孔摄像机 我操操操", "*"))
    print(gfw.filter("售假人民币 我操操操", "*"))
    print(gfw.filter("传世私服 我操操操", "*"))
    print('Cost is %6.6f' % (time.process_time() - t))
"""
