/*******************************************************************************
* Copyright (C) 2024 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of oneAPI Math Kernel Library (oneMKL)
*       DPCPP API, oneapi::mkl::sparse::omatconvert, to perform an out-of-place
*       conversion operation between two sparse matrices, on a SYCL device (CPU,
*       GPU). This example illustrates the conversion from a matrix in COO
*       to a matrix in CSR format.
*
*           A(COO) -> B(CSR)
*
*       The supported floating point data types for omatconvert matrix data are:
*           float
*           double
*           std::complex<float>
*           std::complex<double>
*
*       The supported conversions for omatconvert are:
*           COO -> CSR
*           CSR -> COO
*
*******************************************************************************/

// stl includes
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <vector>

#include "mkl.h"
#include "oneapi/mkl.hpp"
#include <sycl/sycl.hpp>

// local includes
#include "common_for_examples.hpp"
#include "./include/common_for_sparse_examples.hpp"

//
// Main example for Sparse Matrix conversion: A(COO) -> B(CSR)
//
// Runs the A(COO) -> B(CSR) omatconvert example on queue `q` for one
// combination of value type (dataType) and index type (intType).
//
// A test matrix is generated in CSR form by generate_sparse_matrix, its row
// pointers are expanded into COO row indices to form A, and the omatconvert
// stages (buffer_size -> analyze -> get_nnz -> omatconvert) produce B in CSR
// form. The first two rows of B are printed from a host_task.
//
// Returns 0 on success and 1 if a synchronous exception was caught. The matrix
// handles, the omatconvert descriptor and every registered USM allocation are
// released before returning, on both the success and the exception path.
template <typename dataType, typename intType>
int run_sparse_blas_example(sycl::queue &q)
{
    bool good = true;

    //
    // Array memory management tools: USM allocations recorded here are freed
    // by cleanup_arrays at the end of the function.
    //
    std::vector<intType *> int_ptr_vec;
    std::vector<dataType *> data_ptr_vec;
    std::vector<std::int64_t *> i64_ptr_vec;
    std::vector<void *> void_ptr_vec;

    //
    // handles for sparse matrices
    //
    oneapi::mkl::sparse::matrix_handle_t cooA = nullptr;
    oneapi::mkl::sparse::matrix_handle_t csrB = nullptr;

    //
    // omatconvert descriptor for conversion to CSR
    //
    oneapi::mkl::sparse::omatconvert_descr_t descr = nullptr;

    try {

        // Index bases: A is zero-based, B is one-based.
        oneapi::mkl::index_base a_index = oneapi::mkl::index_base::zero;
        oneapi::mkl::index_base b_index = oneapi::mkl::index_base::one;
        intType a_ind = a_index == oneapi::mkl::index_base::zero ? 0 : 1;
        intType b_ind = b_index == oneapi::mkl::index_base::zero ? 0 : 1;

        //
        // Set dimensions of matrices
        //
        intType size = 4;

        const std::int64_t a_nrows = size * size * size;
        const std::int64_t a_ncols = a_nrows;
        std::int64_t a_nnz = 27 * a_nrows; // upper bound until the matrix is generated
        const std::int64_t b_nrows = a_nrows;
        const std::int64_t b_ncols = a_ncols;
        std::int64_t b_nnz = 0; // unknown until omatconvert_get_nnz

        //
        // Set A data locally: generate it in CSR form, then expand to COO
        //
        std::vector<intType, mkl_allocator<intType, 64>> csr_ia;
        std::vector<intType, mkl_allocator<intType, 64>> ia;
        std::vector<intType, mkl_allocator<intType, 64>> ja;
        std::vector<dataType, mkl_allocator<dataType, 64>> a;

        csr_ia.resize(a_nrows + 1);
        ia.resize(a_nnz);
        ja.resize(a_nnz);
        a.resize(a_nnz);

        generate_sparse_matrix<dataType, intType>(size, csr_ia, ja, a, a_ind);
        a_nnz = csr_ia[a_nrows] - a_ind;

        // Expand the CSR row pointers into one COO row index per nonzero, since
        // A must be in COO format to exercise the COO->CSR conversion.
        // csr_ia stores a_ind-based offsets: subtract a_ind to obtain array
        // positions, and add it back so the row indices use A's index base.
        for (std::int64_t row = 0; row < a_nrows; ++row) {
            const std::int64_t row_begin = csr_ia[row] - a_ind;
            const std::int64_t row_end   = csr_ia[row + 1] - a_ind;
            for (std::int64_t k = row_begin; k < row_end; ++k) {
                ia[k] = static_cast<intType>(row + a_ind);
            }
        }

        // USM arrays for A. Every pointer that was actually allocated is
        // registered for cleanup before the failure check, so a partial
        // allocation failure does not leak the arrays that did succeed.
        intType *a_rowind = sycl::malloc_shared<intType>(a_nnz, q);
        intType *a_colind = sycl::malloc_shared<intType>(a_nnz, q);
        dataType *a_vals  = sycl::malloc_shared<dataType>(a_nnz, q);
        if (a_rowind) int_ptr_vec.push_back(a_rowind);
        if (a_colind) int_ptr_vec.push_back(a_colind);
        if (a_vals)   data_ptr_vec.push_back(a_vals);

        if (!a_rowind || !a_colind || !a_vals) {
            std::string errorMessage =
                "Failed to allocate USM shared memory arrays \n"
                " for COO A matrix: a_rowind(" + std::to_string((a_nnz)*sizeof(intType)) + " bytes)\n"
                "                   a_colind(" + std::to_string((a_nnz)*sizeof(intType)) + " bytes)\n"
                "                   a_vals(" + std::to_string((a_nnz)*sizeof(dataType)) + " bytes)";
            throw std::runtime_error(errorMessage);
        }

        // Copy data to USM arrays
        std::copy_n(ia.begin(), a_nnz, a_rowind);
        std::copy_n(ja.begin(), a_nnz, a_colind);
        std::copy_n(a.begin(), a_nnz, a_vals);

        //
        // Declare B matrix arrays (allocated once b_nnz is known)
        //
        intType *b_rowptr = nullptr, *b_colind = nullptr;
        dataType *b_vals = nullptr;

        //
        // Execute Matrix Conversion
        //

        std::cout << "\n\t\tsparse::omatconvert parameters:\n";

        std::cout << "\t\t\tA format = COO" << std::endl;
        std::cout << "\t\t\tA_nrows  = " << a_nrows << std::endl;
        std::cout << "\t\t\tA_ncols  = " << a_ncols << std::endl;
        std::cout << "\t\t\tA_nnz    = " << a_nnz << std::endl;
        std::cout << "\t\t\tA_index  = " << a_index << std::endl;

        std::int64_t size_temp_workspace;
        void *temp_workspace = nullptr;
        oneapi::mkl::sparse::omatconvert_alg alg =
            oneapi::mkl::sparse::omatconvert_alg::default_alg;

        oneapi::mkl::sparse::init_matrix_handle(&cooA);
        oneapi::mkl::sparse::init_matrix_handle(&csrB);

        oneapi::mkl::sparse::init_omatconvert_descr(q, &descr);

        // Set A matrix
        auto ev_setA = oneapi::mkl::sparse::set_coo_data(q, cooA, a_nrows,
            a_ncols, a_nnz, a_index, a_rowind, a_colind, a_vals);

        // Set B matrix with null arrays and b_nnz=0 for now
        auto ev_setB_dummy = oneapi::mkl::sparse::set_csr_data(q, csrB,
            b_nrows, b_ncols, b_nnz, b_index, b_rowptr, b_colind, b_vals);

        // Get the size of the temporary workspace for the subsequent stages
        oneapi::mkl::sparse::omatconvert_buffer_size(
            q, cooA, csrB, alg, descr, size_temp_workspace);
        temp_workspace = sycl::malloc_shared(size_temp_workspace * sizeof(std::uint8_t), q);
        if ( (size_temp_workspace > 0) && !temp_workspace) {
            std::string errorMessage =
                "Failed to allocate USM shared memory arrays \n"
                " for temp_workspace(" + std::to_string((size_temp_workspace)*sizeof(std::uint8_t)) + " bytes)";
            throw std::runtime_error(errorMessage);
        }
        void_ptr_vec.push_back(temp_workspace);

        // Analyze stage
        auto ev_analyze = oneapi::mkl::sparse::omatconvert_analyze(
            q, cooA, csrB, alg, descr, temp_workspace, {ev_setA, ev_setB_dummy});

        // Get NNZ for the B matrix (written into the host variable b_nnz)
        oneapi::mkl::sparse::omatconvert_get_nnz(
            q, cooA, csrB, alg, descr, b_nnz, {ev_analyze});

        std::cout << "\t\t\tB format = CSR" << std::endl;
        std::cout << "\t\t\tB_nrows  = " << b_nrows << std::endl;
        std::cout << "\t\t\tB_ncols  = " << b_ncols << std::endl;
        std::cout << "\t\t\tB_nnz    = " << b_nnz << std::endl;
        std::cout << "\t\t\tB_index  = " << b_index << std::endl;

        // Allocate B arrays now that b_nnz is known; register successful
        // allocations before the failure check so nothing leaks.
        b_rowptr = sycl::malloc_shared<intType>(b_nrows + 1, q);
        b_colind = sycl::malloc_shared<intType>(b_nnz, q);
        b_vals   = sycl::malloc_shared<dataType>(b_nnz, q);
        if (b_rowptr) int_ptr_vec.push_back(b_rowptr);
        if (b_colind) int_ptr_vec.push_back(b_colind);
        if (b_vals)   data_ptr_vec.push_back(b_vals);

        if (!b_rowptr || !b_colind || !b_vals) {
            std::string errorMessage =
                "Failed to allocate USM shared memory arrays \n"
                " for CSR B matrix: b_rowptr(" + std::to_string((b_nrows+1)*sizeof(intType)) + " bytes)\n"
                "                   b_colind(" + std::to_string((b_nnz)*sizeof(intType)) + " bytes)\n"
                "                   b_vals(" + std::to_string((b_nnz)*sizeof(dataType)) + " bytes)";
            throw std::runtime_error(errorMessage);
        }

        // Set B matrix with properly allocated data
        auto ev_setB = oneapi::mkl::sparse::set_csr_data(q, csrB,
            b_nrows, b_ncols, b_nnz, b_index, b_rowptr, b_colind, b_vals);

        // Convert A(COO) -> B(CSR)
        auto ev_convert = oneapi::mkl::sparse::omatconvert(
                q, cooA, csrB, alg, descr, {ev_setB});

        // Release the descriptor and handles once the conversion is done.
        // descr is passed by value, so it is reset here to keep the backup
        // cleanup below from releasing it a second time.
        auto ev_rel_descr = oneapi::mkl::sparse::release_omatconvert_descr(q, descr, {ev_convert});
        descr = nullptr;
        auto ev_relA = oneapi::mkl::sparse::release_matrix_handle(q, &cooA, {ev_convert});
        auto ev_relB = oneapi::mkl::sparse::release_matrix_handle(q, &csrB, {ev_convert});

        // Print portion of B solution. b_rowptr holds b_ind-based offsets, so
        // b_ind is subtracted to index b_colind/b_vals; the printed row and
        // column indices are b_ind-based.
        sycl::event ev_print = q.submit([&](sycl::handler &cgh) {
            cgh.depends_on({ev_convert});
            auto kernel = [=]() {
                std::cout << "B matrix [first two rows]:" << std::endl;
                intType printed_rows = std::min<std::int64_t>(2, b_nrows);
                for (intType row = 0; row < printed_rows; ++row) {
                    for (intType j = b_rowptr[row] - b_ind; j < b_rowptr[row + 1] - b_ind; ++j) {
                        intType col = b_colind[j];
                        dataType val  = b_vals[j];
                        std::cout << "B(" << row + b_ind << ", " << col << ") = " << val
                                  << std::endl;
                    }
                }
            };
            cgh.host_task(kernel);
        });

        q.wait_and_throw();

    }
    catch (sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception:\n" << e.what()
                  << std::endl;
        good = false;
    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception:\n" << e.what() << std::endl;
        good = false;
    }

    q.wait();

    // backup cleaning of matrix handle and others for if exceptions happened
    if (descr) oneapi::mkl::sparse::release_omatconvert_descr(q, descr, {}).wait();
    oneapi::mkl::sparse::release_matrix_handle(q, &cooA, {}).wait();
    oneapi::mkl::sparse::release_matrix_handle(q, &csrB, {}).wait();

    cleanup_arrays<dataType, intType>(data_ptr_vec, int_ptr_vec, i64_ptr_vec,
                                void_ptr_vec, q);

    return good ? 0 : 1;
}

//
// Description of example setup, apis used and supported floating point type
// precisions
//
void print_example_banner()
{
    // Prints the example title, the sparse APIs it exercises and the
    // supported floating point precisions to std::cout.

    std::cout << "" << std::endl;
    std::cout << "#############################################################"
                 "###########"
              << std::endl;
    std::cout << "# Sparse Matrix Conversion Example: " << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "#    A(COO) -> B(CSR)" << std::endl;
    std::cout << "# " << std::endl;
    // A is the COO source and B the CSR result of the conversion.
    std::cout << "# where A and B are sparse matrices in COO and CSR formats"
                 " respectively" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "# Using apis:" << std::endl;
    std::cout << "#   sparse::omatconvert_buffer_size" << std::endl;
    std::cout << "#   sparse::omatconvert_analyze" << std::endl;
    std::cout << "#   sparse::omatconvert_get_nnz" << std::endl;
    std::cout << "#   sparse::omatconvert" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "#   sparse::init_matrix_handle" << std::endl;
    std::cout << "#   sparse::set_csr_data" << std::endl;
    std::cout << "#   sparse::set_coo_data" << std::endl;
    std::cout << "#   sparse::release_matrix_handle" << std::endl;
    std::cout << "#   sparse::init_omatconvert_descr" << std::endl;
    std::cout << "#   sparse::release_omatconvert_descr" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "# Supported floating point type precisions:" << std::endl;
    std::cout << "#   float" << std::endl;
    std::cout << "#   double" << std::endl;
    std::cout << "#   std::complex<float>" << std::endl;
    std::cout << "#   std::complex<double>" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "#############################################################"
                 "###########"
              << std::endl;
    std::cout << std::endl;
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//
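//  For each device selected, run_sparse_blas_example is run with every
//  supported data type (double precision types only when the device reports
//  double precision support). Per-type failures are OR-ed into the exit
//  status; an exception escaping to this driver level abandons the current
//  device and moves on to the next one.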
//

int main(int argc, char **argv)
{

    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    int status = 0;
    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        try {
            sycl::device my_dev;
            bool my_dev_is_found = false;
            get_sycl_device(my_dev, my_dev_is_found, *it);

            if (my_dev_is_found) {
                std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";

                // Catch asynchronous exceptions
                auto exception_handler = [](sycl::exception_list exceptions) {
                    for (std::exception_ptr const &e : exceptions) {
                        try {
                            std::rethrow_exception(e);
                        }
                        catch (sycl::exception const &e) {
                            std::cout << "Caught asynchronous SYCL exception: \n"
                                << e.what() << std::endl;
                        }
                    }
                };

                sycl::queue q(my_dev, exception_handler);

                std::cout << "\tRunning with single precision real data type:" << std::endl;
                status |= run_sparse_blas_example<float, std::int32_t>(q);

                if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                    std::cout << "\tRunning with double precision real data type:" << std::endl;
                    status |= run_sparse_blas_example<double, std::int32_t>(q);
                }

                std::cout << "\tRunning with single precision complex data type:" << std::endl;
                status |= run_sparse_blas_example<std::complex<float>, std::int32_t>(q);

                if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                    std::cout << "\tRunning with double precision complex data type:" << std::endl;
                    status |= run_sparse_blas_example<std::complex<double>, std::int32_t>(q);
                }

            }
            else {
#ifdef FAIL_ON_MISSING_DEVICES
                std::cout << "No " << sycl_device_names[*it]
                    << " devices found; Fail on missing devices "
                    "is enabled.\n";
                return 1;
#else
                std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                    << sycl_device_names[*it] << " tests.\n";
#endif
            }
        }
        catch (sycl::exception const &e) {
            std::cout << "\t\tCaught SYCL exception at driver level: \n" << e.what() << std::endl;
            continue; // stop with device, but move on to other devices
        }
        catch (std::exception const &e) {
            std::cout << "\t\tCaught std exception at driver level: \n" << e.what() << std::endl;
            continue; // stop with device, but move on to other devices
        }


    } // for loop over devices

    mkl_free_buffers();
    return status;
}

