#!/usr/bin/env python3

# (C) Copyright 2025- ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#


import sys
from pathlib import Path
from struct import pack

import numpy as np
from scipy.sparse import csr_array
from scipy.sparse import load_npz
from scipy.sparse import save_npz

# Add the directory containing stream.py to the import path
sys.path.append(str(Path(__file__).parent))
from stream import TAG_LARGE_BLOB
from stream import Stream


def dtype_uint(little_endian: bool, size: int):
    """Return the unsigned-integer NumPy dtype for an item size and byte order.

    Args:
        little_endian: True selects little-endian ("<") byte order, False
            selects big-endian (">").
        size: Item size in bytes; 4 (uint32) and 8 (uint64) are supported.

    Returns:
        The matching ``np.dtype`` with the requested byte order applied.

    Raises:
        KeyError: If ``size`` is not a supported item size.
    """
    order = "<" if little_endian else ">"
    # 8-byte indices are accepted alongside 4-byte ones, mirroring the sizes
    # dtype_float supports; anything else still raises KeyError.
    unsigned_by_size = {4: np.uint32, 8: np.uint64}
    return np.dtype(unsigned_by_size[size]).newbyteorder(order)


def dtype_float(little_endian: bool, size: int):
    """Return the floating-point NumPy dtype for an item size and byte order.

    ``size`` is the item size in bytes: 4 maps to float32 and 8 to float64;
    any other value raises ``KeyError``. ``little_endian`` selects "<" when
    true and ">" otherwise.
    """
    base = np.dtype({4: np.float32, 8: np.float64}[size])
    if little_endian:
        return base.newbyteorder("<")
    return base.newbyteorder(">")


def convert_mat_to_npz(input_path: Path, output_path: Path):
    """Read a MIR .mat file and store its contents as a SciPy .npz file.

    The input is decoded through ``Stream``: a header (rows, columns,
    non-zero count, endianness flag and three item sizes) followed by three
    large blobs holding the outer, inner and data arrays. The arrays are
    assembled into a ``csr_array`` and written with ``save_npz``.
    """
    with open(input_path, "rb") as handle:
        stream = Stream(handle)

        # Matrix dimensions; the stored non-zero count is read and discarded.
        n_rows = stream.read_unsigned_long()
        n_cols = stream.read_unsigned_long()
        stream.read_unsigned_long()

        # Byte order and item sizes; sizeof(size) is read and discarded.
        is_little_endian = stream.read_int() != 0
        index_bytes = stream.read_unsigned_long()
        scalar_bytes = stream.read_unsigned_long()
        stream.read_unsigned_long()

        # Raw payloads, in file order: outer, inner, data.
        outer_blob = stream.read_large_blob()
        inner_blob = stream.read_large_blob()
        data_blob = stream.read_large_blob()

        # Both index arrays share the same element type.
        index_dtype = dtype_uint(is_little_endian, index_bytes)
        indptr = np.frombuffer(outer_blob, dtype=index_dtype)
        indices = np.frombuffer(inner_blob, dtype=index_dtype)

        value_dtype = dtype_float(is_little_endian, scalar_bytes)
        values = np.frombuffer(data_blob, dtype=value_dtype)

        matrix = csr_array((values, indices, indptr), shape=(n_rows, n_cols))
        save_npz(output_path, matrix)


def convert_npz_to_mat(input_path: Path, output_path: Path):
    """Convert .npz format to MIR .mat format.

    The matrix is loaded with ``load_npz``, converted to ``csr_array`` if it
    is not one already, and written as a header (rows, columns, non-zero
    count, endianness flag, item sizes) followed by three large blobs: outer
    (indptr), inner (indices) and data.

    Args:
        input_path: Path of the .npz file to read.
        output_path: Path of the .mat file to write.

    Raises:
        ValueError: If the non-zero count or the largest column index does
            not fit in a 32-bit unsigned index.
    """
    matrix = load_npz(input_path)

    # Convert to CSR format if not already
    if not isinstance(matrix, csr_array):
        matrix = csr_array(matrix)

    rows, cols = matrix.shape
    nnz = matrix.nnz

    # Use little-endian format (typical MIR use)
    little_endian = True

    # Item sizes (in bytes) recorded in the header
    index_item_size = 4  # uint32
    scalar_item_size = 8  # float64
    size_item_size = 8  # size_t (typically on 64-bit systems)

    # Casting to uint32 wraps silently, so refuse matrices whose index
    # pointers (at most nnz) or column indices (at most cols - 1) overflow.
    index_max = np.iinfo(np.uint32).max
    if max(nnz, cols - 1) > index_max:
        raise ValueError(
            f"Matrix too large for 32-bit indices: cols={cols}, nnz={nnz}"
        )

    # Apply the byte order announced in the header to every array, so the
    # payload always matches the flag (not only in the little-endian case).
    order = "<" if little_endian else ">"
    indptr = matrix.indptr.astype(f"{order}u{index_item_size}")
    indices = matrix.indices.astype(f"{order}u{index_item_size}")
    data = matrix.data.astype(f"{order}f{scalar_item_size}")

    def _write_large_blob(f, blob):
        """Write a tagged, length-prefixed blob directly to the file object."""
        f.write(pack("b", TAG_LARGE_BLOB))
        f.write(pack("!Q", len(blob)))  # length in network byte order
        f.write(blob)

    with open(output_path, "wb") as f:
        s = Stream(f)
        s.write_unsigned_long(rows)
        s.write_unsigned_long(cols)
        s.write_unsigned_long(nnz)

        s.write_int(1 if little_endian else 0)
        s.write_unsigned_long(index_item_size)
        s.write_unsigned_long(scalar_item_size)
        s.write_unsigned_long(size_item_size)

        # Write large blobs for outer (indptr), inner (indices), and data.
        # NOTE(review): these bypass Stream and write to the file directly;
        # this assumes Stream does not buffer its own writes - confirm.
        _write_large_blob(f, indptr.tobytes())
        _write_large_blob(f, indices.tobytes())
        _write_large_blob(f, data.tobytes())


def main():
    """Parse the command line and run the requested matrix conversion.

    Takes two positional paths, input and output. The conversion direction is
    chosen from the lower-cased file suffixes: .mat -> .npz or .npz -> .mat.
    Raises ``FileNotFoundError`` when the input does not exist and
    ``ValueError`` for any other suffix combination.
    """
    import argparse

    cli = argparse.ArgumentParser(
        description="Convert between MIR .mat and .npz matrix formats"
    )
    cli.add_argument("input", help="Path to input matrix file", type=Path)
    cli.add_argument("output", help="Path to output matrix file", type=Path)
    options = cli.parse_args()

    source, target = options.input, options.output

    # The input must exist before the suffixes are even considered.
    if not source.exists():
        raise FileNotFoundError(f"Input file not found: {source}")

    input_ext = source.suffix.lower()
    output_ext = target.suffix.lower()
    direction = (input_ext, output_ext)

    if direction == (".mat", ".npz"):
        convert_mat_to_npz(source, target)
        return

    if direction == (".npz", ".mat"):
        convert_npz_to_mat(source, target)
        return

    raise ValueError(
        f"Unsupported conversion: {input_ext} -> {output_ext}. "
        "Supported conversions: .mat -> .npz, .npz -> .mat"
    )


# Run the converter only when this file is executed as a script.
if __name__ == "__main__":
    main()
