"""Consolidated baseline migration.

Revision ID: consolidated_001
Revises: None

Creates every ORM-declared table and idempotently applies the newer
columns, indexes, seed rows, and cleanups to pre-existing databases.
Run with `alembic upgrade head`.
"""
from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

revision: str = 'consolidated_001'
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

def upgrade() -> None:
    # Importing the models registers them on Base.metadata; the optional
    # imports are wrapped in try/except so the migration still runs on
    # deployments where those modules do not exist.
    from app.db.base import Base
    try:
        from app.models.role_permission import RolePermission  # noqa: F401
    except ImportError:
        pass
    try:
        from app.models.data_model import (  # noqa: F401
            DocumentMetadataValue,
            KnowledgeMetadataField,
            TagAutoRule,
            TagCategory,
        )
    except ImportError:
        pass
    # create_all skips tables that already exist, so this is safe to run
    # against both empty and populated databases.
    Base.metadata.create_all(bind=op.get_bind())
    bind = op.get_bind()
    inspector = sa.inspect(bind)
    existing_tables = inspector.get_table_names()
    if 'local_models' not in existing_tables:
        op.create_table(
            'local_models',
            sa.Column('id', sa.Integer(), nullable=False),
            sa.Column('name', sa.String(length=255), nullable=False),
            sa.Column('identifier', sa.String(length=255), nullable=False),
            sa.Column('model_type', sa.String(length=50), nullable=True),
            sa.Column('base_model', sa.String(length=255), nullable=True),
            sa.Column('model_path', sa.Text(), nullable=False),
            sa.Column('status', sa.String(length=50), nullable=True),
            sa.Column('config', postgresql.JSON(astext_type=sa.Text()), nullable=True),
            sa.Column('tenant_id', sa.Integer(), nullable=True),
            sa.Column('created_at', sa.DateTime(), nullable=True),
            sa.Column('updated_at', sa.DateTime(), nullable=True),
            sa.Column('is_deleted', sa.Boolean(), nullable=True, default=False),
            sa.Column('deleted_at', sa.DateTime(), nullable=True),
            sa.Column('deleted_by', sa.Integer(), nullable=True),
            sa.Column('asset_kind', sa.String(length=50), nullable=True),
            sa.Column('runtime_kind', sa.String(length=50), nullable=True),
            sa.Column('source_kind', sa.String(length=50), nullable=True),
            sa.Column('published_model_id', sa.Integer(), nullable=True),
            sa.PrimaryKeyConstraint('id'),
            sa.UniqueConstraint('identifier'),
            sa.ForeignKeyConstraint(
                ['published_model_id'], ['models.id'],
                name='fk_local_models_published_model_id_models',
            ),
        )
        op.create_index('ix_local_models_tenant_id', 'local_models', ['tenant_id'])
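
    # --- model_providers / models: add provider-routing columns if missing ---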
    mp_columns = {c['name'] for c in inspector.get_columns('model_providers')}
    for col, coldef in [
        ('provider_kind', sa.Column('provider_kind', sa.String(50), nullable=True)),
        ('protocol', sa.Column('protocol', sa.String(50), nullable=True)),
        ('auth_type', sa.Column('auth_type', sa.String(50), nullable=True)),
        ('capabilities', sa.Column('capabilities', postgresql.JSON(astext_type=sa.Text()), nullable=True)),
        ('is_local', sa.Column('is_local', sa.Boolean(), nullable=True)),
        ('healthcheck_path', sa.Column('healthcheck_path', sa.String(255), nullable=True)),
        ('extra_config', sa.Column('extra_config', postgresql.JSON(astext_type=sa.Text()), nullable=True)),
    ]:
        if col not in mp_columns:
            op.add_column('model_providers', coldef, schema='public')
    m_columns = {c['name'] for c in inspector.get_columns('models')}
    for col, coldef in [
        ('remote_model_id', sa.Column('remote_model_id', sa.String(255), nullable=True)),
        ('context_length', sa.Column('context_length', sa.Integer(), nullable=True)),
        ('max_output_tokens', sa.Column('max_output_tokens', sa.Integer(), nullable=True)),
        ('supports_stream', sa.Column('supports_stream', sa.Boolean(), nullable=True)),
        ('supports_tools', sa.Column('supports_tools', sa.Boolean(), nullable=True)),
        ('priority', sa.Column('priority', sa.Integer(), nullable=True)),
        ('extra_config', sa.Column('extra_config', postgresql.JSON(astext_type=sa.Text()), nullable=True)),
    ]:
        if col not in m_columns:
            op.add_column('models', coldef, schema='public')
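
    # Backfill sensible defaults for rows that predate the new columns.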
    op.execute("UPDATE public.model_providers SET provider_kind = 'public_api' WHERE provider_kind IS NULL")
    op.execute("UPDATE public.model_providers SET protocol = 'openai_compatible' WHERE protocol IS NULL")
    op.execute("UPDATE public.model_providers SET auth_type = CASE WHEN api_key IS NULL OR api_key = '' THEN 'none' ELSE 'bearer' END WHERE auth_type IS NULL")
    op.execute("UPDATE public.model_providers SET capabilities = '[]'::json WHERE capabilities IS NULL")
    op.execute('UPDATE public.model_providers SET is_local = false WHERE is_local IS NULL')
    op.execute('UPDATE public.models SET remote_model_id = code WHERE remote_model_id IS NULL')
    op.execute('UPDATE public.models SET supports_stream = true WHERE supports_stream IS NULL')
    op.execute('UPDATE public.models SET supports_tools = false WHERE supports_tools IS NULL')
    op.execute('UPDATE public.models SET priority = 100 WHERE priority IS NULL')
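
    # Classify pre-existing local_models rows under the new taxonomy.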
    op.execute("UPDATE local_models SET asset_kind = CASE WHEN model_type = 'lora' THEN 'adapter' ELSE 'full_model' END WHERE asset_kind IS NULL")
    op.execute("UPDATE local_models SET runtime_kind = 'generic' WHERE runtime_kind IS NULL")
    op.execute("UPDATE local_models SET source_kind = 'imported' WHERE source_kind IS NULL")
    op.execute("""
        INSERT INTO public.model_providers (name, description, configured, default_base_url, api_key, base_url, priority, enabled,
            provider_kind, protocol, auth_type, capabilities, is_local, healthcheck_path, extra_config, created_at, updated_at, is_deleted)
        SELECT '本地 MLX', 'Local MLX provider', true, 'http://127.0.0.1:8010', NULL, 'http://127.0.0.1:8010', 50, true,
               'local_mlx', 'openai_compatible', 'none', '["chat", "embedding", "rerank"]'::json, true, '/health', NULL, NOW(), NOW(), false
        WHERE NOT EXISTS (SELECT 1 FROM public.model_providers WHERE name = '本地 MLX')
    """)
    op.execute("""
        INSERT INTO public.model_providers (name, description, configured, default_base_url, api_key, base_url, priority, enabled,
            provider_kind, protocol, auth_type, capabilities, is_local, healthcheck_path, extra_config, created_at, updated_at, is_deleted)
        SELECT 'ECS A100', 'Self-hosted ECS model provider', false, NULL, NULL, NULL, 60, true,
               'self_hosted', 'openai_compatible', 'bearer', '["chat", "embedding", "rerank"]'::json, false, '/health', NULL, NOW(), NOW(), false
        WHERE NOT EXISTS (SELECT 1 FROM public.model_providers WHERE name = 'ECS A100')
    """)
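
    # --- agents: reranker toggle plus a model_id FK replacing the free-text orchestrator_model ---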
    a_columns = {c['name'] for c in inspector.get_columns('agents')}
    if 'use_reranker' not in a_columns:
        op.add_column('agents', sa.Column('use_reranker', sa.Boolean(), nullable=False, server_default='true'))
    if 'model_id' not in a_columns:
        op.add_column('agents', sa.Column('model_id', sa.Integer(), nullable=True))
        op.create_foreign_key('fk_agents_model_id_models', 'agents', 'models',
                              ['model_id'], ['id'], referent_schema='public')
        op.create_index('ix_agents_model_id', 'agents', ['model_id'])
        # Map the legacy free-text orchestrator_model onto the new FK by
        # matching either the model name or its code.
        if 'orchestrator_model' in a_columns:
            op.execute("""
                UPDATE agents SET model_id = models.id
                FROM public.models
                WHERE agents.model_id IS NULL
                  AND (agents.orchestrator_model = models.name OR agents.orchestrator_model = models.code)
            """)
    if 'orchestrator_model' in a_columns:
        op.drop_column('agents', 'orchestrator_model')
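
    # Deduplicate model_providers sharing a name: keep the lowest id,
    # repoint models at it, then delete the extras.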
    result = bind.execute(sa.text("""
        SELECT name, MIN(id) AS keep_id, ARRAY_AGG(id ORDER BY id) AS all_ids
        FROM public.model_providers GROUP BY name HAVING COUNT(*) > 1
    """))
    for _name, keep_id, all_ids in result.fetchall():
        remove_ids = [i for i in all_ids if i != keep_id]
        if remove_ids:
            # ANY(:r) works because the driver adapts the Python list to a
            # PostgreSQL array (psycopg2 does this out of the box).
            bind.execute(sa.text('UPDATE public.models SET provider_id = :k WHERE provider_id = ANY(:r)'),
                         {'k': keep_id, 'r': remove_ids})
            bind.execute(sa.text('DELETE FROM public.model_providers WHERE id = ANY(:r)'),
                         {'r': remove_ids})
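
    # --- document_vectors: parent/child chunk links and document metadata ---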
    dv_columns = {c['name'] for c in inspector.get_columns('document_vectors')}
    dv_new_cols = [
        ('parent_chunk_id', sa.Column('parent_chunk_id', sa.String(100), nullable=True)),
        ('is_parent', sa.Column('is_parent', sa.Boolean(), server_default='false', nullable=False)),
        ('chunk_level', sa.Column('chunk_level', sa.String(20), nullable=True)),
        ('prev_chunk_id', sa.Column('prev_chunk_id', sa.String(100), nullable=True)),
        ('next_chunk_id', sa.Column('next_chunk_id', sa.String(100), nullable=True)),
        ('doc_type', sa.Column('doc_type', sa.String(50), nullable=True)),
        ('doc_number', sa.Column('doc_number', sa.String(100), nullable=True)),
        ('issuing_authority', sa.Column('issuing_authority', sa.String(100), nullable=True)),
        # 'references' is a reserved SQL word; SQLAlchemy quotes it automatically.
        ('references', sa.Column('references', sa.Text(), nullable=True)),
        ('chunk_id', sa.Column('chunk_id', sa.String(100), nullable=True)),
        ('doc_status', sa.Column('doc_status', sa.String(20), nullable=True)),
        ('issue_date_int', sa.Column('issue_date_int', sa.BigInteger(), nullable=True)),
    ]
    for col, coldef in dv_new_cols:
        if col not in dv_columns:
            op.add_column('document_vectors', coldef)
    dv_indexes = {i['name'] for i in inspector.get_indexes('document_vectors')}
    for idx, cols in [
        ('ix_document_vectors_parent_chunk_id', ['parent_chunk_id']),
        ('ix_document_vectors_is_parent', ['is_parent']),
        ('ix_document_vectors_doc_type', ['doc_type']),
        ('ix_document_vectors_chunk_id', ['chunk_id']),
    ]:
        if idx not in dv_indexes:
            op.create_index(idx, 'document_vectors', cols)
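
    # --- knowledge_documents: lifecycle, numbering, and versioning metadata ---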
    kd_columns = {c['name'] for c in inspector.get_columns('knowledge_documents')}
    kd_new_cols = [
        ('doc_type', sa.Column('doc_type', sa.String(50), nullable=True)),
        ('doc_number', sa.Column('doc_number', sa.String(100), nullable=True)),
        ('issuing_authority', sa.Column('issuing_authority', sa.String(200), nullable=True)),
        ('enable_parent_child', sa.Column('enable_parent_child', sa.Boolean(), server_default='false', nullable=False)),
        ('enable_reference_extraction', sa.Column('enable_reference_extraction', sa.Boolean(), server_default='false', nullable=False)),
        ('window_size', sa.Column('window_size', sa.Integer(), nullable=True)),
        ('doc_number_year', sa.Column('doc_number_year', sa.Integer(), nullable=True)),
        ('doc_number_serial', sa.Column('doc_number_serial', sa.Integer(), nullable=True)),
        ('issue_date', sa.Column('issue_date', sa.Date(), nullable=True)),
        ('effective_date', sa.Column('effective_date', sa.Date(), nullable=True)),
        ('expire_date', sa.Column('expire_date', sa.Date(), nullable=True)),
        ('doc_status', sa.Column('doc_status', sa.String(20), nullable=True)),
        ('supersedes_doc_ids', sa.Column('supersedes_doc_ids', sa.JSON(), nullable=True)),
        ('superseded_by_doc_id', sa.Column('superseded_by_doc_id', sa.Integer(),
                                           sa.ForeignKey('knowledge_documents.id', ondelete='SET NULL'), nullable=True)),
        ('tax_type_tags', sa.Column('tax_type_tags', sa.JSON(), nullable=True)),
        ('has_attachment', sa.Column('has_attachment', sa.Boolean(), nullable=True, server_default='false')),
        ('attachment_types', sa.Column('attachment_types', sa.JSON(), nullable=True)),
        ('parse_quality_score', sa.Column('parse_quality_score', sa.Float(), nullable=True)),
        ('content_hash', sa.Column('content_hash', sa.String(64), nullable=True)),
        ('version_number', sa.Column('version_number', sa.Integer(), nullable=True, server_default='1')),
    ]
    for col, coldef in kd_new_cols:
        if col not in kd_columns:
            op.add_column('knowledge_documents', coldef)
    kd_indexes = {i['name'] for i in inspector.get_indexes('knowledge_documents')}
    for idx, cols in [
        ('ix_knowledge_documents_doc_type', ['doc_type']),
        ('ix_knowledge_documents_doc_number', ['doc_number']),
        ('ix_knowledge_documents_enable_parent_child', ['enable_parent_child']),
        ('ix_knowledge_documents_enable_reference_extraction', ['enable_reference_extraction']),
        ('ix_knowledge_documents_content_hash', ['content_hash']),
    ]:
        if idx not in kd_indexes:
            op.create_index(idx, 'knowledge_documents', cols)

def downgrade() -> None:
    # Destructive: this is a baseline migration, so downgrading drops every
    # ORM-declared table together with its data. Tables created outside
    # Base.metadata are left untouched.
    from app.db.base import Base
    Base.metadata.drop_all(bind=op.get_bind())
