import re

import markdownify as _markdownify
from common_logging import get_logger

logger = get_logger(__name__)


class TaxMarkdownConverter(_markdownify.MarkdownConverter):
    """markdownify converter that keeps tables as raw HTML and bold as ``<strong>``.

    Every ``convert_*`` hook takes a defaulted ``convert_as_inline`` plus
    ``**kwargs`` so it can be invoked either with a positional third argument
    or with keyword arguments only. NOTE(review): newer markdownify releases
    are believed to pass ``parent_tags=`` instead of ``convert_as_inline`` --
    confirm against the pinned version. Extra keywords are ignored.
    """

    def convert_table(self, el, text, convert_as_inline=False, **kwargs):
        # Re-serialize the original <table> element instead of using the
        # converted child text, so the table reaches the Markdown as raw HTML
        # (presumably to keep structure Markdown tables cannot express, e.g.
        # merged cells). Surrounding blank lines separate it from paragraphs.
        return '\n\n' + str(el) + '\n\n'

    # Row and cell output is discarded: convert_table emits str(el) for the
    # whole table, so any text produced for these children would be unused.
    def convert_tr(self, el, text, convert_as_inline=False, **kwargs):
        return ''

    def convert_td(self, el, text, convert_as_inline=False, **kwargs):
        return ''

    def convert_th(self, el, text, convert_as_inline=False, **kwargs):
        return ''

    def convert_strong(self, el, text, convert_as_inline=False, **kwargs):
        return self._wrap_strong(text)

    def convert_b(self, el, text, convert_as_inline=False, **kwargs):
        return self._wrap_strong(text)

    @staticmethod
    def _wrap_strong(text):
        """Wrap *text* in an HTML ``<strong>`` tag.

        Empty or whitespace-only text is returned as-is (``None`` becomes
        ``''``), so empty bold elements do not leave ``<strong></strong>``
        noise in the output.
        """
        if not text or not text.strip():
            return text or ''
        return f'<strong>{text}</strong>'

class DocumentCleaner:
    """Convert tax-document HTML to cleaned Markdown and derive plain/RAG text.

    Instances carry no state; all configuration lives in class-level constants.
    """

    # Site boilerplate removed from converted Markdown, applied in this order.
    # Compiled once here rather than on every clean_markdown() call.
    _NOISE_PATTERNS = tuple(re.compile(p) for p in (
        # image whose URL contains "huibiao", plus the following text up to a
        # newline (the leading \s* may itself cross a line break first)
        r'!\[.*?\]\(.*?huibiao.*?\)\s*\S+.*?\n',
        r'\[下载文字版\]\(.*?\)\n?',           # "download text version" link
        r'\[下载图片版\]\(.*?\)\n?',           # "download image version" link
        r'字体：\s*【大】\s*【中】\s*【小】',  # font-size switcher (large/medium/small)
        r'分享到：.*',                         # "share to:" widget, rest of the line
        r'全文有效',                           # "full text in effect" status label
        r'成文日期：.*',                       # "date written:" line remainder
        r'【打印】',                           # [print]
        r'【下载】',                           # [download]
        r'纠错或建议',                         # "corrections or suggestions"
        # single-line CSS rule that mentions "font"
        r'[^\n]*\{[^\n]*font[^\n]*\}[^\n]*',
    ))
    # CSS block mentioning "font" that may span several lines ([^}] crosses
    # newlines). NOTE(review): any "{" in ordinary text followed by "font"
    # before the next "}" would also be removed.
    _MULTILINE_FONT_CSS = re.compile(r'[^\n]*\{[^}]*font[^}]*\}', flags=re.DOTALL)

    def __init__(self):
        # No per-instance state; kept so ``DocumentCleaner()`` keeps working.
        pass

    def html_to_markdown(self, html_content: str) -> str:
        """Convert an HTML page to cleaned Markdown.

        Empty input (``''`` or ``None``) yields ``''``. If conversion or
        cleaning raises, the error is logged and the page's visible text
        (script/style removed, text fragments joined by newlines) is returned
        instead.
        """
        if not html_content:
            return ''
        try:
            # NOTE(review): markdownify's `strip` option is documented as
            # dropping the tags' markup while keeping their text, so
            # script/style contents may still reach the output; the CSS "font"
            # patterns in clean_markdown look like a workaround -- confirm.
            converter = TaxMarkdownConverter(heading_style='ATX', bullets='-', strip=['script', 'style'])
            return self.clean_markdown(converter.convert(html_content))
        except Exception as e:
            logger.error(f'HTML 转 Markdown 失败: {e}')
            return self._visible_soup(html_content).get_text(separator='\n', strip=True)

    def clean_markdown(self, markdown_content: str) -> str:
        """Normalize converted Markdown and strip known site boilerplate.

        Collapses runs of blank lines, strips every line, removes the
        ``_NOISE_PATTERNS`` and multi-line font CSS, puts blank lines around
        headings, normalizes list markers, then collapses blank lines again.
        """
        markdown_content = re.sub(r'\n{3,}', '\n\n', markdown_content)
        markdown_content = '\n'.join(line.strip() for line in markdown_content.split('\n'))
        for pattern in self._NOISE_PATTERNS:
            markdown_content = pattern.sub('', markdown_content)
        markdown_content = self._MULTILINE_FONT_CSS.sub('', markdown_content)
        markdown_content = self._normalize_headings(markdown_content)
        markdown_content = self._normalize_lists(markdown_content)
        # the steps above can leave 3+ consecutive newlines behind
        markdown_content = re.sub(r'\n{3,}', '\n\n', markdown_content)
        return markdown_content.strip()

    def _normalize_headings(self, content: str) -> str:
        """Surround every line starting with ``#`` with blank lines.

        A blank line is inserted before a heading only when the previous line
        is not already blank; one is always appended after it. Any resulting
        runs of blank lines are collapsed later by clean_markdown.
        """
        lines = content.split('\n')
        normalized_lines = []
        for line in lines:
            if line.startswith('#'):
                if normalized_lines and normalized_lines[-1] != '':
                    normalized_lines.append('')
                normalized_lines.append(line)
                normalized_lines.append('')
            else:
                normalized_lines.append(line)
        return '\n'.join(normalized_lines)

    def _normalize_lists(self, content: str) -> str:
        """Normalize list markers on a per-line basis.

        ``*`` and ``+`` bullets become ``- ``; ordered markers ``N.`` get
        exactly one following space. Only spaces and tabs are matched around
        the marker: the regex whitespace class also matches newlines, which
        under MULTILINE would swallow the blank line preceding a list item and
        join a bare ``1.`` line onto the next line.
        """
        content = re.sub(r'^[ \t]*[*+][ \t]+', '- ', content, flags=re.MULTILINE)
        content = re.sub(r'^[ \t]*(\d+)\.[ \t]+', r'\1. ', content, flags=re.MULTILINE)
        return content

    def clean_text(self, text: str) -> str:
        """Collapse all whitespace to single spaces and drop control characters."""
        text = re.sub(r'\s+', ' ', text)
        text = re.sub(r'[\x00-\x08\x0b-\x0c\x0e-\x1f\x7f-\x9f]', '', text)
        # removing a control char that sat between two spaces leaves a double space
        text = re.sub(r' {2,}', ' ', text)
        return text.strip()

    @staticmethod
    def strip_ocr_content(markdown: str) -> str:
        """Remove OCR additions from Markdown.

        Drops everything from a ``---`` rule followed by a ``## 正文图片内容``
        ("body image content") heading to the end of the text, plus any
        ``<!-- 图片N OCR -->`` marker comments, then strips the result.
        """
        markdown = re.sub(r'\n\n---\n\n## 正文图片内容\b.*', '', markdown, flags=re.DOTALL)
        markdown = re.sub(r'<!-- 图片\d* OCR -->\n?', '', markdown)
        return markdown.strip()

    def markdown_to_rag_text(self, markdown: str) -> str:
        """Strip Markdown/inline-HTML formatting for retrieval text.

        HTML ``<table>`` blocks are kept verbatim; images are removed, links
        are reduced to their label, and heading markers, emphasis, ``<strong>``/
        ``<b>``/``<em>`` tags, horizontal rules and inline-code backticks are
        dropped. Empty input yields ``''``.
        """
        if not markdown:
            return ''
        tables = []

        def _save_table(m):
            # swap each table for an indexed placeholder so the regexes below
            # leave its contents untouched
            tables.append(m.group(0))
            return f'\n\n__TABLE_{len(tables) - 1}__\n\n'

        # NOTE(review): the non-greedy match ends at the first </table>, so a
        # nested table splits its outer table.
        text = re.sub(r'<table[\s\S]*?</table>', _save_table, markdown, flags=re.IGNORECASE)
        text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
        text = re.sub(r'\[([^\]]*)\]\([^)]*\)', r'\1', text)
        text = re.sub(r'^#{1,6}\s+', '', text, flags=re.MULTILINE)
        text = re.sub(r'\*{1,2}([^*]+)\*{1,2}', r'\1', text)
        text = re.sub(r'<strong>(.*?)</strong>', r'\1', text, flags=re.IGNORECASE | re.DOTALL)
        text = re.sub(r'<b>(.*?)</b>', r'\1', text, flags=re.IGNORECASE | re.DOTALL)
        text = re.sub(r'<em>(.*?)</em>', r'\1', text, flags=re.IGNORECASE | re.DOTALL)
        text = re.sub(r'^\s*---+\s*$', '', text, flags=re.MULTILINE)
        text = re.sub(r'`([^`]*)`', r'\1', text)
        for i, table in enumerate(tables):
            text = text.replace(f'__TABLE_{i}__', table)
        text = re.sub(r'\n{3,}', '\n\n', text)
        return text.strip()

    @staticmethod
    def _visible_soup(html_content: str):
        """Parse *html_content* with lxml and remove <script>/<style> elements."""
        from bs4 import BeautifulSoup
        soup = BeautifulSoup(html_content, 'lxml')
        for tag in soup(['script', 'style']):
            tag.decompose()
        return soup

    def extract_plain_text(self, html_content: str) -> str:
        """Return the page's visible text as one whitespace-collapsed line."""
        text = self._visible_soup(html_content).get_text(separator='\n', strip=True)
        return self.clean_text(text)
