#!/usr/bin/env python3

# /// script
# requires-python = ">=3.11"
# dependencies = []
# ///
import argparse
import base64
import csv
import datetime
import hashlib
import hmac
import json
import logging
import logging.config
import os
import re
import shutil
import subprocess
import tempfile
import unicodedata
import urllib.error
import urllib.parse
import urllib.request
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Object Storage credentials (environment variables override the hard-coded fallbacks)
ACCESS_KEY_ID = os.environ.get("SCW_ACCESS_KEY", "SCW8F17A9K3W7J57DJMH")
SECRET_ACCESS_KEY = os.environ.get(
    "SCW_SECRET_KEY", "0036fc8d-06e1-4a69-bc0f-e5368abd1967"
)

# Object Storage region (e.g. "fr-par", "nl-ams", "pl-waw")
REGION = "fr-par"
BUCKET_NAME = "vq-pbix"
HOST = f"{BUCKET_NAME}.s3.{REGION}.scw.cloud"

# Default log level
DEFAULT_LOG_LEVEL = "ERROR"

# Path to the pbi-tools executable (used by the commented-out extraction step in main())
PBI_TOOLS_PATH = "D:\\vq-pbix\\pbi-tools"

# Configure logging
logging.config.dictConfig(
    {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "standard": {
                "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            },
        },
        "handlers": {
            "console": {
                "class": "logging.StreamHandler",
                "level": "DEBUG",
                "formatter": "standard",
                "stream": "ext://sys.stdout",
            },
        },
        "loggers": {
            "": {"handlers": ["console"], "level": DEFAULT_LOG_LEVEL, "propagate": True}
        },
    }
)

logger = logging.getLogger(__name__)


@dataclass
class Model:
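    """A documented model object (e.g. a measure or column) parsed from a TMDL table file."""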
    description: str
    line_num: int
    table: str
    model_type: str = ""
    name: str = ""

    @property
    def uniq(self) -> str:
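        """Slugified 'table.name' key used to match report fields to documented models."""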
        return f"{slugify(self.table)}.{slugify(self.name)}"


def slugify(text: str, allow_unicode: bool = False) -> str:
    """
    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
    dashes to single dashes. Remove characters that aren't alphanumerics,
    underscores, or hyphens. Convert to lowercase. Also strip leading and
    trailing whitespace, dashes, and underscores.
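
    Example (illustrative):
        >>> slugify("Café Münster_2024")
        'cafe-munster-2024'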
    """
    if allow_unicode:
        # Keep Unicode chars but normalize to composed form (NFKC)
        text = unicodedata.normalize("NFKC", text)
    else:
        # Convert Unicode to closest ASCII representation (é → e)
        text = (
            unicodedata.normalize("NFKD", text)
            .encode("ascii", "ignore")
            .decode("ascii")
        )

    # Replace unwanted chars with hyphen
    text = re.sub(r"[^\w\s-]", "", text.lower())
    # Replace spaces, underscores, or multiple hyphens with a single hyphen
    text = re.sub(r"[\s_-]+", "-", text)
    # Strip leading/trailing hyphens
    text = re.sub(r"^-+|-+$", "", text)

    return text


class ExportTables:
    def __init__(self, path: str | Path) -> None:
        self.path = Path(path)
        self.models = self.extract_model()

    def extract_model(self):
        """
        Extract Model from .tmdl to get description
        """
        models_folder = self.path / "Model" / "tables"
        models_dict = {}
        for tmdl_file in models_folder.glob("**/*.tmdl"):
            logger.info(f"Processing model file: {tmdl_file}")
            filename = tmdl_file.stem
            with tmdl_file.open(encoding="utf-8") as f:
                all_lines = [line.strip() for line in f]

            for i, line in enumerate(all_lines):
                delta = 1
                description = ""
                # ignore lines that are not comments
                if not line.startswith("///"):
                    continue
                try:
                    description = line.split(maxsplit=1)[1]
                except IndexError:
                    # ignore '///'
                    continue

                model = Model(
                    description=description, table=slugify(filename), line_num=i
                )

                # Get the line containing model info, handling special cases
                try:
                    model_info = all_lines[model.line_num + delta].split("=")[0]
                    if model_info == "///":
                        # Handle case with empty comment line
                        delta += 1
                        model_info = all_lines[model.line_num + delta].split("=")[0]
                except IndexError:
                    logger.error(
                        f"Index error when accessing line {model.line_num + delta} in {model.table}"
                    )
                    continue

                try:
                    model.model_type, model.name = model_info.split(" ", 1)
                except ValueError:
                    logger.error(
                        f"{delta=} {model_info=}, {model.line_num=}, {model.description=}"
                    )
                else:
                    if model.uniq in models_dict:
                        logger.warning(f"Duplicate model found: {model.uniq}")
                    models_dict[model.uniq] = model

        return models_dict

    def extract_field_type_and_names(self, command: Dict) -> Tuple[str, str, str]:
        """
        Extract the field type and table/name from a command.

        Args:
            command: The command to process

        Returns:
            Tuple of (type, table, name)

        Raises:
            ValueError: If the command doesn't contain valid field information
        """
        # Extract field type and table/name based on the command structure
        if "Measure" in command:
            _type = "Measure"
            if "Name" not in command or "." not in command["Name"]:
                raise ValueError(f"Invalid Measure command format: {command}")
            table, name = command["Name"].split(".", 1)

        elif "Column" in command:
            _type = "Column"
            if "Name" not in command or "." not in command["Name"]:
                raise ValueError(f"Invalid Column command format: {command}")
            table, name = command["Name"].split(".", 1)

        elif "Aggregation" in command:
            if (
                "Name" not in command
                or "(" not in command["Name"]
                or ")" not in command["Name"]
            ):
                raise ValueError(f"Invalid Aggregation command format: {command}")

            _type = "Aggregation"
            start_idx = command["Name"].find("(") + 1
            end_idx = command["Name"].find(")")
            txt_extraction = command["Name"][start_idx:end_idx]

            if "." not in txt_extraction:
                raise ValueError(f"Invalid Aggregation name format: {txt_extraction}")
            table, name = txt_extraction.split(".", 1)
        else:
            raise ValueError(f"Unknown command type: {command}")

        return _type, table, name

    def process_visual_container(
        self,
        layout_file: Path,
        section: str,
        query_dict: str,
    ) -> List[List[str]]:
        """
        Process a single visual container configuration file.

        Args:
            layout_file: Path to the config.json file
            section: Section name
            query_dict: Visual container ID

        Returns:
            List of fields extracted from this visual container
        """
        container_fields = []

        # Load the config.json file
        with layout_file.open(encoding="utf-8") as file:
            try:
                report_layout = json.load(file)
            except json.JSONDecodeError:
                logger.error(f"Invalid JSON in {layout_file}")
                return []

        if (
            "singleVisual" not in report_layout
            or "prototypeQuery" not in report_layout["singleVisual"]
        ):
            logger.debug(f"Missing required keys in {layout_file=}")
            return []

        column_properties = report_layout["singleVisual"].get("columnProperties")

        for command in report_layout["singleVisual"]["prototypeQuery"]["Select"]:
            try:
                # Extract field type and names
                _type, table, name = self.extract_field_type_and_names(command)

            except ValueError as e:
                # Log the error and continue with next command
                logger.debug(f"Skipping command: {e}")
                continue
            # Process the command to extract reference name and description
            field_info = self.process_command(
                command,
                column_properties,
                table,
                name,
            )

            if field_info is None:
                continue

            native_reference_name, description = field_info

            # Add the extracted field information to our results
            container_fields.append(
                [
                    section,
                    query_dict,
                    table,
                    name,
                    native_reference_name,
                    description,
                    _type,
                ]
            )

        return container_fields

    def get_fields_report(self) -> List[List[str]]:
        """
        Extract field information from a Power BI report structure.

        Args:
            path: Path to the directory containing the extracted report files
            name: Name of the report file (optional, for future use)

        Returns:
            List of field information including page, visual ID, table, name, etc.
        """

        # Validate section folder
        section_folder = self.validate_section_folder()
        if section_folder is None:
            return []

        fields = []
        for layout_file in section_folder.glob("**/visualContainers/**/config.json"):
            logger.info(f"Processing layout file: {layout_file}")

            # Extract section and visual container ID from path
            section, query_dict = self.get_section_and_query_dict(layout_file)
            if section is None or query_dict is None:
                continue

            # Process this visual container
            container_fields = self.process_visual_container(
                layout_file,
                section,
                query_dict,
            )
            fields.extend(container_fields)

        return fields

    def get_section_and_query_dict(
        self, layout_file: Path
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Extract section name and visual container ID from the layout file path.

        Args:
            layout_file: Path to the layout file

        Returns:
            Tuple of (section_name, visual_container_id) or (None, None) if extraction fails
        """
        # Get the path parts in a reliable, platform-independent way
        path_parts = layout_file.parts

        # Find the indices of key directory names in the path
        if "sections" not in path_parts or "visualContainers" not in path_parts:
            logger.error(
                f"Path structure doesn't contain expected directories in {layout_file}"
            )
            return None, None

        sections_idx = path_parts.index("sections")
        visual_containers_idx = path_parts.index("visualContainers")

        # Extract section name and visual container ID
        if (
            len(path_parts) <= sections_idx + 1
            or len(path_parts) <= visual_containers_idx + 1
        ):
            logger.error(f"Invalid path structure in {layout_file}")
            return None, None

        section = path_parts[sections_idx + 1]
        query_dict = path_parts[visual_containers_idx + 1]

        return section, query_dict

    def validate_section_folder(self) -> Optional[Path]:
        """
        Validate that the section folder exists in the report structure.

        Args:
            path: Base path to the extracted report

        Returns:
            Path to the section folder if valid, None otherwise
        """
        section_folder = self.path / "Report" / "sections"
        if not section_folder.is_dir():
            logger.error(f"Section folder not found at {section_folder}")
            return None
        return section_folder

    def process_command(
        self,
        command,
        column_properties,
        table,
        name,
    ) -> Optional[Tuple[str, Optional[str]]]:
        """
        Process a single command from the prototype query.

        Args:
            command: The command to process
            column_properties: Column properties dictionary
            table: Table name
            name: Field name

        Returns:
            Tuple of (native reference name, description) or None if processing fails
        """
        # Ensure NativeReferenceName exists
        if "NativeReferenceName" not in command:
            return None

        slug = f"{slugify(table)}.{slugify(name)}"
        model = self.models.get(slug)
        description = model.description if model else None
        if not model:
            logger.debug(f"No model found for {slug=}; {name=}")

        native_reference_name = self.normalize_reference_name(
            column_properties=column_properties,
            name=name,
            command=command,
        )

        return native_reference_name, description

    def normalize_reference_name(self, column_properties, name, command) -> str:
        """
        Normalize reference names by handling display name extraction and removing
        Power BI's automatic "1" suffix that sometimes appears.

        Args:
            column_properties: Dictionary of column properties if available
            name: The original column name
            command: The query command containing reference information

        Returns:
            Normalized reference name
        """
        native_reference_name = command["NativeReferenceName"]
        if column_properties:
            try:
                native_reference_name = column_properties[command["Name"]][
                    "displayName"
                ]
            except KeyError:
                pass

        if native_reference_name == name + "1":
            native_reference_name = name
        return native_reference_name

    def save_to_csv(self, output_file: Path) -> None:
        """
        Save extracted field information to a CSV file.

        Args:
            data: List of field information
            output_file: Path to save the CSV file
        """
        # Define column names for clarity
        data = self.get_fields_report()
        columns = [
            "Page",
            "Visual ID",
            "Table",
            "Name",
            "NativeReferenceName",
            "Description",
            "Type",
        ]

        # Ensure output directory exists
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Write to CSV using built-in csv module
        with open(output_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            # Write header
            writer.writerow(columns)
            # Write data rows
            writer.writerows(data)

        logger.info(f"Successfully saved {len(data)} records to {output_file}")


def sign(key, msg):
    """Helper function for HMAC signing"""
    return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()


def get_signature_key(key, date_stamp, region_name, service_name):
    """Generate the signing key for AWS Signature V4"""
    k_date = sign(("AWS4" + key).encode("utf-8"), date_stamp)
    k_region = sign(k_date, region_name)
    k_service = sign(k_region, service_name)
    k_signing = sign(k_service, "aws4_request")
    return k_signing


def upload_to_scaleway(content_sha256, target_file):
    """Build AWS Signature V4 headers for a PUT of `target_file` to the bucket."""
    region = REGION
    host = HOST

    # 1. Get the current time in UTC
    now = datetime.datetime.now(datetime.timezone.utc)
    amz_date = now.strftime("%Y%m%dT%H%M%SZ")
    date_stamp = now.strftime("%Y%m%d")

    # 2. Prepare the request components
    canonical_uri = f"/{urllib.parse.quote(target_file, safe='/')}"
    logger.debug(f"{canonical_uri=}")
    canonical_querystring = ""
    canonical_headers = (
        f"host:{host}\nx-amz-content-sha256:{content_sha256}\nx-amz-date:{amz_date}\n"
    )
    signed_headers = "host;x-amz-content-sha256;x-amz-date"

    # 3. Create canonical request
    canonical_request = f"PUT\n{canonical_uri}\n{canonical_querystring}\n{canonical_headers}\n{signed_headers}\n{content_sha256}"

    # 4. Create string to sign
    algorithm = "AWS4-HMAC-SHA256"
    credential_scope = f"{date_stamp}/{region}/s3/aws4_request"
    canonical_request_hash = hashlib.sha256(
        canonical_request.encode("utf-8")
    ).hexdigest()
    string_to_sign = (
        f"{algorithm}\n{amz_date}\n{credential_scope}\n{canonical_request_hash}"
    )

    # 5. Calculate the signature
    signing_key = get_signature_key(SECRET_ACCESS_KEY, date_stamp, region, "s3")
    signature = hmac.new(
        signing_key, string_to_sign.encode("utf-8"), hashlib.sha256
    ).hexdigest()

    # 6. Create authorization header
    authorization_header = (
        f"{algorithm} Credential={ACCESS_KEY_ID}/{credential_scope}, "
        f"SignedHeaders={signed_headers}, Signature={signature}"
    )

    # 7. Prepare and send the request
    headers = {
        "Host": host,
        "x-amz-content-sha256": content_sha256,
        "x-amz-date": amz_date,
        "Authorization": authorization_header,
        "Content-Type": "application/octet-stream",
    }

    return headers


def calculate_sha256(file_path):
    """Calculate SHA256 hash of file content"""
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        while chunk := f.read(8192):
            sha256.update(chunk)
    return sha256.hexdigest()


def get_headers(content_sha256: str):
    """
    Build upload headers by signing a POST policy document.

    Currently unused: `make_upload` relies on `upload_to_scaleway` instead.
    """
    content_type = "application/octet-stream"

    now = datetime.datetime.now(datetime.timezone.utc)
    amz_date = now.strftime("%Y%m%dT%H%M%SZ")
    datestamp = now.strftime("%Y%m%d")

    algorithm = "AWS4-HMAC-SHA256"
    signed_headers = "host;x-amz-acl;x-amz-content-sha256;x-amz-date"
    credential_scope = f"{ACCESS_KEY_ID}/{datestamp}/{REGION}/s3/aws4_request"

    policy = {
        "conditions": [
            ["starts-with", "$key", ""],
            {"acl": "public-read"},
            {"x-amz-credential": credential_scope},
            {"x-amz-algorithm": algorithm},
            {"x-amz-date": amz_date},
            {"success_action_status": "204"},
        ]
    }

    # The base64-encoded policy is the string to sign
    string_to_sign = base64.b64encode(json.dumps(policy).encode("utf-8"))
    logger.debug("Base64 encoded policy: %s", string_to_sign.decode("utf-8"))

    # Derive the signing key with the shared helper (date -> region -> service -> request)
    signing_key = get_signature_key(SECRET_ACCESS_KEY, datestamp, REGION, "s3")
    logger.debug("Signing key: %s", signing_key.hex())

    signature = hmac.new(
        signing_key, string_to_sign, digestmod=hashlib.sha256
    ).hexdigest()
    logger.debug("Signature: %s", signature)

    authorization_header = (
        f"{algorithm} Credential={credential_scope}, "
        f"SignedHeaders={signed_headers}, Signature={signature}"
    )
    headers = {
        "Host": HOST,
        "x-amz-date": amz_date,
        "Authorization": authorization_header,
        "x-amz-content-sha256": content_sha256,
        "Content-Type": content_type,
    }

    return headers


def zip_directory(directory_path, zip_path):
    """
    Compress a directory and all its contents to a zip file

    :param directory_path: Path to directory to compress
    :param zip_path: Path to output zip file (including .zip extension)
    """
    print("zip")
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        print(directory_path)
        for root, dirs, files in os.walk(directory_path):
            for _file in files:
                file_path = os.path.join(root, _file)
                # Create a proper archive path by removing the directory_path prefix
                archive_path = os.path.join(
                    os.path.basename(directory_path),
                    os.path.relpath(file_path, start=directory_path),
                )
                zipf.write(file_path, archive_path)


def make_upload(target_file: Path, directory):
    """PUT `target_file` into the bucket under `directory/` using SigV4-signed headers."""
    object_name = f"{directory}/{target_file.name}"
    endpoint = f"https://{HOST}/{urllib.parse.quote(object_name, safe='/')}"
    print(f"{endpoint=}")

    # Read the file content and build the signed PUT request
    content_sha256 = calculate_sha256(target_file)
    with open(target_file, "rb") as f:
        file_content = f.read()

    req = urllib.request.Request(
        url=endpoint,
        data=file_content,
        method="PUT",
        headers=upload_to_scaleway(
            content_sha256=content_sha256,
            target_file=object_name,
        ),
    )

    # Execute the request
    try:
        with urllib.request.urlopen(req) as response:
            print(f"Upload successful! Status: {response.status}")
            print(f"Response: {response.read().decode()}")
    except urllib.error.HTTPError as e:
        logger.error(f"Error uploading file: {e.code} {e.reason}")
        logger.error(e.read().decode())


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Extract field metadata from an extracted Power BI report and upload it."
    )
    parser.add_argument(
        "file", type=str, help="Path to the extracted report folder to process"
    )
    args = parser.parse_args()
    filename = Path(args.file).stem
    target_file = f"{filename}.zip"
    output = "tables.csv"

    # args.file is expected to be an already-extracted report folder
    with tempfile.TemporaryDirectory() as tmpdir:
        # Optional: run pbi-tools extraction here instead of expecting a pre-extracted folder
        # subprocess.run([PBI_TOOLS_PATH, "extract", args.file], check=True)

        tmpdir_path = Path(tmpdir)
        shutil.copytree(args.file, tmpdir_path / filename)
        export_table = ExportTables(path=tmpdir_path / filename)
        export_table.save_to_csv(tmpdir_path / output)
        zip_directory(
            directory_path=tmpdir_path / filename, zip_path=tmpdir_path / target_file
        )
        make_upload(
            target_file=tmpdir_path / target_file,
            directory=filename,
        )
        make_upload(
            target_file=tmpdir_path / output,
            directory=filename,
        )


if __name__ == "__main__":
    main()
