import secrets
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING

import redis
from jose import JWTError, jwt
from passlib.context import CryptContext

from app.config import settings
from common_logging import get_logger

logger = get_logger(__name__)

# The User model is imported only for static type checking (it is referenced as
# the string annotation "User" below), so app.models is not imported at runtime.
# NOTE(review): presumably this avoids an import cycle with app.models — confirm.
if TYPE_CHECKING:
    from app.models import User

# passlib hashing context used by verify_password / get_password_hash;
# bcrypt is the only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Cached Redis client; None until get_redis_client() creates it on first use.
_redis_client: redis.Redis | None = None


def get_redis_client() -> redis.Redis:
    """Return the module-wide Redis client, creating it on first use.

    ``settings.REDIS_URL`` takes precedence when it is set; otherwise the
    client is built from ``REDIS_HOST``/``REDIS_PORT``/``REDIS_DB``/
    ``REDIS_PASSWORD``. In both cases responses are decoded to ``str``
    (``decode_responses=True``).

    NOTE(review): the lazy initialisation is not guarded by a lock, so two
    threads hitting the first call concurrently could each build a client.
    """
    global _redis_client
    if _redis_client is not None:
        return _redis_client

    redis_url = settings.REDIS_URL
    if redis_url:
        new_client = redis.from_url(redis_url, decode_responses=True)
    else:
        new_client = redis.Redis(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
            db=settings.REDIS_DB,
            password=settings.REDIS_PASSWORD,
            decode_responses=True,
        )
    _redis_client = new_client
    return new_client


def verify_password(plain_password: str, hashed_password: str) -> bool:
    if isinstance(plain_password, str):
        plain_password = plain_password.encode("utf-8")[:72].decode("utf-8", errors="ignore")
    result = pwd_context.verify(plain_password, hashed_password)
    if not result:
        logger.warning("Password verification failed")
    return result


def get_password_hash(password: str) -> str:
    hashed = pwd_context.hash(password)
    logger.info("Password hashed successfully")
    return hashed


def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Encode *data* as a signed JWT carrying an ``exp`` claim.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is
            not mutated; any ``exp`` key in it is overwritten.
        expires_delta: Token lifetime. ``None`` (the default) uses
            ``settings.ACCESS_TOKEN_EXPIRE_MINUTES``. Any explicit value is
            honoured, including ``timedelta(0)`` (which is falsy and was
            previously ignored).

    Returns:
        The token string produced by ``jwt.encode`` using
        ``settings.SECRET_KEY`` and ``settings.ALGORITHM``.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    # Aware UTC "now": datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = data.copy()
    to_encode["exp"] = expire
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)


def decode_access_token(token: str) -> dict | None:
    """Verify *token* and return its claims dict.

    Only ``settings.ALGORITHM`` is accepted. When ``jwt.decode`` raises
    ``JWTError``, a warning is logged and ``None`` is returned. python-jose
    uses that exception family for bad signatures, malformed tokens and
    expired ``exp`` claims.
    """
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError as err:
        logger.warning(f"Token decode failed: {err}")
        return None
    return claims


# Alternate name for decode_access_token: both names refer to the same function
# object (returns the claims dict, or None when the token fails to decode).
verify_token = decode_access_token


def generate_verification_token() -> str:
    """Return a fresh URL-safe random token built from 32 bytes of entropy."""
    entropy_bytes = 32
    token = secrets.token_urlsafe(entropy_bytes)
    return token


def validate_password_strength(
    password: str,
    *,
    min_length: int = 8,
    special_chars: str = "!@#$%^&*()_+-=[]{}|;:,.<>?",
) -> tuple[bool, str]:
    """Check *password* against the complexity policy.

    Rules are evaluated in this order, and only the first failure is reported:
    minimum length, an uppercase letter, a lowercase letter, a digit, and at
    least one character from *special_chars*.

    Args:
        password: Candidate password.
        min_length: Minimum length in characters (code points). Defaults to 8.
        special_chars: Characters that satisfy the "special character" rule.
            The default set omits, for example, space, ``~``, quotes and
            slashes.

    Returns:
        ``(True, "")`` when every rule passes. Otherwise ``(False, message)``,
        where *message* is a user-facing (Chinese) description of the first
        failed rule.
    """
    if len(password) < min_length:
        return (False, f"密码至少需要{min_length}个字符")
    if not any(c.isupper() for c in password):
        return (False, "密码必须包含大写字母")
    if not any(c.islower() for c in password):
        return (False, "密码必须包含小写字母")
    if not any(c.isdigit() for c in password):
        return (False, "密码必须包含数字")
    if not any(c in special_chars for c in password):
        return (False, "密码必须包含特殊字符")
    return (True, "")


def revoke_token(token: str, exp_timestamp: int) -> None:
    """Mark *token* as revoked in Redis until it would have expired anyway.

    Args:
        token: Token string, used verbatim in the key ``revoked_token:<token>``.
        exp_timestamp: Expiry as Unix epoch seconds (presumably the JWT ``exp``
            claim — confirm at call sites).

    Nothing is written if the token has already expired. Redis errors are not
    caught here and propagate to the caller.
    """
    client = get_redis_client()
    # datetime.utcnow().timestamp() treats the naive UTC value as *local* time,
    # skewing the TTL by the host's UTC offset. On hosts west of UTC, that
    # skew silently dropped revocations. An aware "now" gives true epoch
    # seconds.
    ttl = exp_timestamp - int(datetime.now(timezone.utc).timestamp())
    if ttl > 0:
        client.setex(f"revoked_token:{token}", ttl, "1")


def is_token_revoked(token: str) -> bool:
    """Return True if a revocation marker for *token* exists in Redis.

    This fails closed: any exception raised while obtaining the client or
    querying Redis is logged, and the token is then reported as revoked.
    """
    try:
        marker_count = get_redis_client().exists(f"revoked_token:{token}")
        return marker_count > 0
    except Exception as err:
        logger.error(f"Redis unavailable for token revocation check: {err}")
        return True


def get_primary_role(user) -> str:
    """Return the role code that the permission helpers treat as the user's role.

    The first entry of ``user.user_roles`` is used when neither the link nor
    its role is flagged ``is_deleted``. If there is no such entry, the result
    falls back to ``user.role``, and then to ``"customer_user"`` when that is
    empty.

    NOTE(review): "first" depends on the iteration order of ``user.user_roles``,
    which is not established here.
    """
    for link in getattr(user, "user_roles", None) or ():
        if link.is_deleted or link.role.is_deleted:
            continue
        return link.role.code
    return user.role or "customer_user"


def is_platform_admin(user) -> bool:
    """Return True when the user's primary role is ``platform_admin``."""
    primary = get_primary_role(user)
    return primary == "platform_admin"


def is_platform_user(user) -> bool:
    """Return True when the user's primary role is ``platform_user``."""
    primary = get_primary_role(user)
    return primary == "platform_user"


def is_customer_admin(user) -> bool:
    """Return True when the user's primary role is ``customer_admin``."""
    primary = get_primary_role(user)
    return primary == "customer_admin"


def is_customer_user(user) -> bool:
    """Return True when the user's primary role is ``customer_user``."""
    primary = get_primary_role(user)
    return primary == "customer_user"


def has_platform_access(user) -> bool:
    """Return True for the platform-side roles (``platform_admin``, ``platform_user``)."""
    platform_roles = ("platform_admin", "platform_user")
    return get_primary_role(user) in platform_roles


def has_admin_access(user) -> bool:
    """Return True for either admin role (``platform_admin``, ``customer_admin``)."""
    admin_roles = ("platform_admin", "customer_admin")
    return get_primary_role(user) in admin_roles


def can_manage_users(user) -> bool:
    """Return True if the user may manage users (``platform_admin`` or ``customer_admin``)."""
    allowed_roles = ("platform_admin", "customer_admin")
    return get_primary_role(user) in allowed_roles


def can_delete_resource(user) -> bool:
    """Return True if the user may delete resources (``platform_admin`` or ``customer_admin``)."""
    allowed_roles = ("platform_admin", "customer_admin")
    return get_primary_role(user) in allowed_roles


def can_view_all_resources(user) -> bool:
    """Return True if the user may view every resource (``platform_admin`` or ``platform_user``)."""
    allowed_roles = ("platform_admin", "platform_user")
    return get_primary_role(user) in allowed_roles


def can_access_company(user, target_tenant_id: int | None) -> bool:
    """Return True if *user* may access data of tenant *target_tenant_id*.

    Platform roles may access any tenant. All other users may access only
    their own tenant. A user with no tenant (``tenant_id is None``) is always
    denied.
    """
    if has_platform_access(user):
        return True
    own_tenant = user.tenant_id
    return own_tenant is not None and own_tenant == target_tenant_id


def can_manage_resource(
    user, resource_owner_id: int, resource_tenant_id: int | None = None
) -> bool:
    """Return True if *user* may modify the given resource.

    Access is granted when any of the following holds:
    - the user is a platform admin;
    - the user owns the resource (``resource_owner_id == user.id``);
    - the user is a customer admin, *resource_tenant_id* is given, and it
      equals ``user.tenant_id``.
    """
    if is_platform_admin(user):
        return True
    if resource_owner_id == user.id:
        return True
    # Compare with None explicitly: 0 is a valid integer tenant id and must
    # not be mistaken for "no tenant given".
    if is_customer_admin(user) and resource_tenant_id is not None:
        return user.tenant_id == resource_tenant_id
    return False


def get_role_display(user, company_name: str | None = None) -> str:
    """Return a display label for *user*'s role.

    The role code comes from get_primary_role, the same resolution the
    permission helpers in this module use. Reading ``user.role`` directly
    could disagree with those checks. It would also return ``None`` when
    ``user.role`` is empty and the role exists only in ``user_roles``.

    Known codes map to Chinese labels, and unknown codes are returned
    unchanged. For customer roles, a non-empty *company_name* is prefixed as
    ``"<company_name>/<label>"``.
    """
    role_names = {
        "platform_admin": "平台管理员",
        "platform_user": "平台用户",
        "customer_admin": "企业管理员",
        "customer_user": "企业用户",
    }
    role_code = get_primary_role(user)
    role_display = role_names.get(role_code, role_code)
    if company_name and role_code in ("customer_admin", "customer_user"):
        return f"{company_name}/{role_display}"
    return role_display


def require_admin(user):
    """Raise ``HTTPException(403)`` unless *user* passes has_admin_access; otherwise return None."""
    from fastapi import HTTPException, status

    if has_admin_access(user):
        return
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required")


def require_platform_admin(current_user: "User" = None):
    """Enforce platform-admin access, either directly or via a FastAPI dependency.

    When called with a user, the check runs immediately: the user is returned
    if they are a platform admin, otherwise ``HTTPException(403)`` is raised.
    When called with no argument, it returns a callable whose ``user``
    parameter defaults to ``Depends(get_current_user)`` and which performs the
    same check. That callable is suitable for use as a FastAPI dependency.
    The fastapi and app.api.deps imports are done at call time.
    """
    from fastapi import Depends, HTTPException, status

    from app.api.deps import get_current_user

    def _check(candidate):
        # Shared guard for both calling modes.
        if is_platform_admin(candidate):
            return candidate
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Platform admin access required"
        )

    if current_user is not None:
        return _check(current_user)

    # Name and signature are kept as-is: FastAPI introspects this callable.
    def _require_platform_admin(user: "User" = Depends(get_current_user)):
        return _check(user)

    return _require_platform_admin


def get_user_roles(user) -> list:
    """Return the codes of all active role assignments of *user*.

    An assignment counts only when neither the link nor its role is flagged
    ``is_deleted``. Users without a ``user_roles`` attribute, or whose
    ``user_roles`` is empty or ``None``, yield ``[]``. This matches how
    get_primary_role tolerates a falsy ``user_roles``; previously ``None``
    raised TypeError here.
    """
    assignments = getattr(user, "user_roles", None) or ()
    return [
        link.role.code
        for link in assignments
        if not link.is_deleted and not link.role.is_deleted
    ]


def has_any_role(user, role_codes: list) -> bool:
    """Return True if *user* holds at least one of *role_codes*.

    If the user has active role assignments (see get_user_roles), only those
    are considered. Otherwise the check falls back to ``user.role``.
    """
    assigned = get_user_roles(user)
    if assigned:
        for code in assigned:
            if code in role_codes:
                return True
        return False
    return user.role in role_codes


def has_role(user, role_code: str) -> bool:
    """Return True if *user* holds *role_code* (same fallback rules as has_any_role)."""
    wanted = [role_code]
    return has_any_role(user, wanted)
