import sys
import os
from pathlib import Path
# Locate the `base_platform` project: `parents[2]` is the directory two levels
# above the one containing this file; `base_platform` is a child of it.
base_platform_dir = Path(__file__).resolve().parents[2] / 'base_platform'
# Run from the project root — presumably so relative paths used by `app`
# (config files, .env, ...) resolve; TODO confirm what depends on the cwd.
os.chdir(base_platform_dir)
# Put the project first on sys.path so the `app.*` imports that follow resolve
# to this checkout; this must execute before those imports.
sys.path.insert(0, str(base_platform_dir))
from app.models import role, role_permission, user, tenant, knowledge_base
from app.db.session import SessionLocal
from sqlalchemy import text
from app.services.knowledge.vectorization_service import DocumentVectorizationService

def _set_search_path(db, tenant_id: int) -> None:
    """Resolve unqualified table names in schema ``tenant_<id>`` first, then ``public``."""
    # SET cannot take a bound parameter, so force the id through int() to ensure
    # only digits are ever interpolated into the statement.
    db.execute(text(f'SET search_path TO tenant_{int(tenant_id)}, public'))


def _fetch_pending_batch(db, after_id, limit: int):
    """Return up to ``limit`` pending ``(id, title)`` rows with ``id > after_id``, ordered by id.

    ``after_id=None`` starts from the lowest id. Paging by id (rather than always
    re-reading the first ``limit`` pending rows) guarantees forward progress even
    when a failed document is left in the 'pending' state.
    """
    if after_id is None:
        query = text("""
            SELECT id, title
            FROM knowledge_documents
            WHERE vectorization_status = 'pending'
            ORDER BY id
            LIMIT :limit
        """)
        params = {'limit': limit}
    else:
        query = text("""
            SELECT id, title
            FROM knowledge_documents
            WHERE vectorization_status = 'pending'
              AND id > :after_id
            ORDER BY id
            LIMIT :limit
        """)
        params = {'after_id': after_id, 'limit': limit}
    return db.execute(query, params).fetchall()


def vectorize_batch(tenant_id: int, batch_size: int=500, user_id: int=1, model_id: int=69):
    """Vectorize every 'pending' document in a tenant's ``knowledge_documents`` table.

    Documents are fetched in id order, ``batch_size`` at a time, and passed one by
    one to ``DocumentVectorizationService.vectorize_document``. Each successful
    document is committed immediately. A failing document is rolled back, counted,
    and skipped, so one failure neither aborts the run nor discards earlier work.
    Progress and a final summary are printed to stdout.

    Args:
        tenant_id: Selects schema ``tenant_<tenant_id>`` (with ``public`` as fallback).
        batch_size: Maximum number of documents fetched per query; must be >= 1.
        user_id: Passed through to ``vectorize_document``.
        model_id: Passed through to ``vectorize_document``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.

    Any other error that escapes the per-document handling is printed with its
    traceback rather than propagated.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    db = SessionLocal()
    try:
        _set_search_path(db, tenant_id)
        total_result = db.execute(text("""
            SELECT COUNT(*)
            FROM knowledge_documents
            WHERE vectorization_status = 'pending'
        """)).fetchone()
        total_pending = total_result[0] if total_result else 0
        if total_pending == 0:
            print('✓ No pending documents to vectorize')
            return
        print(f'→ Found {total_pending} pending documents')
        print(f'→ Batch size: {batch_size}')
        print(f'→ Using model_id: {model_id}')
        print(f'→ Total batches: {(total_pending + batch_size - 1) // batch_size}')
        vectorization_service = DocumentVectorizationService(db)
        total_success = 0
        total_failed = 0
        batch_num = 1
        last_id = None  # highest id already handed out; drives keyset pagination
        while True:
            # Re-assert the schema each batch: the session may use a fresh pooled
            # connection after commit/rollback.
            _set_search_path(db, tenant_id)
            pending_docs = _fetch_pending_batch(db, last_id, batch_size)
            if not pending_docs:
                break
            last_id = pending_docs[-1][0]
            print(f"\n{'=' * 60}")
            print(f'Batch {batch_num}: Processing {len(pending_docs)} documents')
            print(f"{'=' * 60}")
            batch_success = 0
            batch_failed = 0
            for i, (doc_id, title) in enumerate(pending_docs, 1):
                # Guard against NULL titles so they don't abort the document.
                display_title = str(title or '')[:50]
                print(f'  [{i}/{len(pending_docs)}] Vectorizing: {display_title}...')
                try:
                    result = vectorization_service.vectorize_document(document_id=doc_id, tenant_id=tenant_id, user_id=user_id, model_id=model_id)
                    db.commit()
                except Exception as e:
                    # Clear any failed transaction so later documents can proceed.
                    # PostgreSQL discards a SET made inside a rolled-back
                    # transaction, so restore the tenant schema afterwards.
                    db.rollback()
                    _set_search_path(db, tenant_id)
                    batch_failed += 1
                    total_failed += 1
                    print(f'    ✗ Failed: {e}')
                    continue
                batch_success += 1
                total_success += 1
                vector_count = result.get('vector_count') if isinstance(result, dict) else None
                print(f'    ✓ Success ({vector_count} vectors)')
            print(f'\nBatch {batch_num} completed: {batch_success} success, {batch_failed} failed')
            print(f'Overall progress: {total_success + total_failed}/{total_pending} ({(total_success + total_failed) * 100 / total_pending:.1f}%)')
            batch_num += 1
        print(f"\n{'=' * 60}")
        print('✓ All batches completed!')
        print(f'  Total success: {total_success}')
        print(f'  Total failed: {total_failed}')
        print(f"{'=' * 60}")
    except Exception as e:
        print(f'✗ Error: {e}')
        import traceback
        traceback.print_exc()
    finally:
        db.close()
if __name__ == '__main__':
    # Positional CLI: <tenant_id> [batch_size] [user_id] [model_id]; any extra
    # arguments are ignored.
    usage = '\n'.join([
        'Usage: python vectorize_batch.py <tenant_id> [batch_size] [user_id] [model_id]',
        '',
        'Examples:',
        '  python vectorize_batch.py 1              # Default batch size 500',
        '  python vectorize_batch.py 1 300          # Batch size 300',
        '  python vectorize_batch.py 1 500 1 69     # Full parameters',
    ])
    if len(sys.argv) < 2:
        print(usage)
        sys.exit(1)
    try:
        tenant_id = int(sys.argv[1])
        batch_size = int(sys.argv[2]) if len(sys.argv) > 2 else 500
        user_id = int(sys.argv[3]) if len(sys.argv) > 3 else 1
        model_id = int(sys.argv[4]) if len(sys.argv) > 4 else 69
    except ValueError as exc:
        # Non-integer argument: show usage instead of an uncaught traceback.
        print(f'✗ Invalid argument: {exc}')
        print(usage)
        sys.exit(1)
    if batch_size < 1:
        print(f'✗ batch_size must be a positive integer, got {batch_size}')
        sys.exit(1)
    vectorize_batch(tenant_id, batch_size, user_id, model_id)