/*******************************************************************************
* Copyright (C) 2024 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

#include "ComputeSYMGS_MKL.hpp"

#include "VeryBasicProfiler.hpp"

#ifndef HPCG_NO_MPI
#include "../ExchangeHalo.hpp"
#include <mpi.h>
#include "../Geometry.hpp"
#include <cstdlib>
#endif

// Profiling hooks used around every stage in this file.
//
// With BASIC_PROFILING defined they forward to optData->profiler, so an
// identifier named `optData` must be in scope wherever they are expanded.
// END_PROFILE_WAIT additionally blocks on `event` before closing the timer.
// Without BASIC_PROFILING they expand to an empty statement.
//
// Each macro is wrapped in do { } while (0) so that it behaves as exactly one
// statement (and requires the caller's trailing ';'), which keeps it safe
// inside unbraced if/else bodies.
#ifdef BASIC_PROFILING
#define BEGIN_PROFILE(n) do { optData->profiler->begin(n); } while (0)
#define END_PROFILE(n) do { optData->profiler->end(n); } while (0)
#define END_PROFILE_WAIT(n, event) do { (event).wait(); optData->profiler->end(n); } while (0)
#else
#define BEGIN_PROFILE(n) do { } while (0)
#define END_PROFILE(n) do { } while (0)
#define END_PROFILE_WAIT(n, event) do { } while (0)
#endif

// Submits one symmetric Gauss-Seidel sweep on x, assembled from oneMKL sparse
// primitives (trmv / gemv / trsv) and small element-wise SYCL kernels.
//
// Parameters:
//   main_queue  queue every operation is submitted to
//   A           supplies localNumberOfRows and geom->size; also handed to
//               ExchangeHalo when more than one rank is present
//   optData     holds the scratch vectors dtmp, dtmp2, dtmp3, dtmp4, the
//               diagonal `diags`, the boundary map `bmap` / `nrow_b`, and the
//               profiler used by the *_PROFILE macros
//   hMatrixA    handle used for the triangular products and solves
//   hMatrixB    handle used for the extra gemv in the multi-rank path
//   r           right-hand side (read only)
//   x           current iterate; overwritten by the final upper solve
//   deps        events the first submitted operations wait on
//
// Returns the event of the final upper triangular solve that writes x.values.
//
// When BASIC_PROFILING is enabled every stage is waited on before the next
// one is submitted (see END_PROFILE_WAIT).
sycl::event run_SYMGS_onemkl(sycl::queue &main_queue,
                             const SparseMatrix & A,
                             struct optData *optData,
                             oneapi::mkl::sparse::matrix_handle_t hMatrixA,
                             oneapi::mkl::sparse::matrix_handle_t hMatrixB,
                             const Vector &r,
                             Vector &x,
                             const std::vector<sycl::event> &deps)
{
    namespace sparse = oneapi::mkl::sparse;
    namespace mkl = oneapi::mkl;

    const local_int_t nrow = A.localNumberOfRows;

    // Read every raw data pointer on the host, once. The kernels below capture
    // only these pointers by value, so device code never dereferences the
    // optData struct itself and never copies a whole Vector into its closure.
    const double *xv = x.values;
    const double *rv = r.values;
    double *dtmp  = optData->dtmp;
    double *dtmp3 = optData->dtmp3;
    double *dtmp4 = optData->dtmp4;
    const double *diags = optData->diags;

    sycl::event result_ev;
    if (A.geom->size > 1) {

#ifndef HPCG_NO_MPI
        // exchange x.values with MPI neighbors
        BEGIN_PROFILE("SYMGS:halo");
        sycl::event halo_ev = ExchangeHalo(A, x, main_queue, deps);
        std::vector<sycl::event> halo_deps({halo_ev});
        END_PROFILE_WAIT("SYMGS:halo", halo_ev);
#else
        const std::vector<sycl::event>& halo_deps = deps;
#endif

        // dtmp4 <- (D + U) * x.values + 0.0 * dtmp4
        // Waits on `deps` only, not on the halo exchange.
        // NOTE(review): presumably hMatrixA touches only locally owned entries
        // of x, which is why it need not wait for the halo -- confirm.
        BEGIN_PROFILE("SYMGS:trmvU");
        auto trmvU_ev = sparse::trmv(main_queue, mkl::uplo::upper, mkl::transpose::nontrans,
                mkl::diag::nonunit, 1.0, hMatrixA, x.values, 0.0, optData->dtmp4, deps);
        END_PROFILE_WAIT("SYMGS:trmvU", trmvU_ev);

        // dtmp2 <- B * x.values + 0.0 * dtmp2  (waits on the halo, not on trmvU_ev)
        BEGIN_PROFILE("SYMGS:gemv");
        auto gemvB_ev = sparse::gemv(main_queue, mkl::transpose::nontrans, 1.0, hMatrixB,
                x.values, 0.0, optData->dtmp2, halo_deps);
        END_PROFILE_WAIT("SYMGS:gemv", gemvB_ev);

        // Independent of gemvB_ev.
        //
        // dtmp4 <- dtmp4 - D * x.values   (removes the diagonal term, leaving U * x.values)
        // dtmp  <- r.values - dtmp4       (using the updated dtmp4)
        BEGIN_PROFILE("SYMGS:rhs1");
        auto rhs1_ev = main_queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on(trmvU_ev);
            auto kernel = [=] (sycl::item<1> row) {
                const double tmp = dtmp4[row] - diags[row] * xv[row];
                dtmp4[row] = tmp;
                dtmp[row]  = rv[row] - tmp;
            };
            cgh.parallel_for<class SYMGS_rhs1_onemkl>(sycl::range<1>(nrow), kernel);
        });
        END_PROFILE_WAIT("SYMGS:rhs1", rhs1_ev);

        // dtmp(bmap) <- dtmp(bmap) - dtmp2
        BEGIN_PROFILE("SYMGS:rhs1_nonloc");
        const local_int_t nrow_b = optData->nrow_b;
        const double *dtmp2 = optData->dtmp2;
        const local_int_t *bmap = optData->bmap;
        auto rhs1_nonloc_ev = main_queue.submit([&](sycl::handler &cgh) {
            // needs both the gemv result (dtmp2) and the local rhs (dtmp)
            cgh.depends_on({gemvB_ev, rhs1_ev});
            auto kernel = [=] (sycl::item<1> row) {
                dtmp[bmap[row]] -= dtmp2[row];
            };
            cgh.parallel_for<class SYMGS_rhs1_nonloc_onemkl>(sycl::range<1>(nrow_b), kernel);
        });
        END_PROFILE_WAIT("SYMGS:rhs1_nonloc", rhs1_nonloc_ev);

        // dtmp3 <- (L + D)^{-1} * dtmp
        BEGIN_PROFILE("SYMGS:trsvL");
        auto trsvL_ev = sparse::trsv(main_queue, mkl::uplo::lower, mkl::transpose::nontrans,
                mkl::diag::nonunit, 1.0, hMatrixA, optData->dtmp, optData->dtmp3, {rhs1_ev, rhs1_nonloc_ev});
        END_PROFILE_WAIT("SYMGS:trsvL", trsvL_ev);

        // dtmp3 <- D * dtmp3 + dtmp4
        BEGIN_PROFILE("SYMGS:rhs2");
        auto rhs2_ev = main_queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on(trsvL_ev);
            auto kernel = [=] (sycl::item<1> row) {
                dtmp3[row] = diags[row]*dtmp3[row] + dtmp4[row];
            };
            cgh.parallel_for<class SYMGS_rhs2_onemkl>(sycl::range<1>(nrow), kernel);
        });
        END_PROFILE_WAIT("SYMGS:rhs2", rhs2_ev);

        // x.values <- (D + U)^{-1} * dtmp3
        BEGIN_PROFILE("SYMGS:trsvU");
        result_ev = sparse::trsv(main_queue, mkl::uplo::upper, mkl::transpose::nontrans,
                mkl::diag::nonunit, 1.0, hMatrixA, optData->dtmp3, x.values, {rhs2_ev});
        END_PROFILE_WAIT("SYMGS:trsvU", result_ev);

    }
    else { // A.geom->size == 1

        // dtmp4 <- (D + U) * x.values + 0.0 * dtmp4
        BEGIN_PROFILE("SYMGS:trmvU");
        auto trmvU_ev = sparse::trmv(main_queue, mkl::uplo::upper, mkl::transpose::nontrans,
                mkl::diag::nonunit, 1.0, hMatrixA, x.values, 0.0, optData->dtmp4, deps);
        END_PROFILE_WAIT("SYMGS:trmvU", trmvU_ev);

        // dtmp4 <- dtmp4 - D * x.values   (removes the diagonal term)
        // dtmp  <- r.values - dtmp4       (using the updated dtmp4)
        BEGIN_PROFILE("SYMGS:rhs_update");
        auto rhs1_ev = main_queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on(trmvU_ev);
            auto kernel = [=] (sycl::item<1> row) {
                const double tmp = dtmp4[row] - diags[row] * xv[row];
                dtmp4[row] = tmp;
                dtmp[row]  = rv[row] - tmp;
            };
            cgh.parallel_for<class SYMGS1_rhs1_onemkl>(sycl::range<1>(nrow), kernel);
        });
        END_PROFILE_WAIT("SYMGS:rhs_update", rhs1_ev);

        // dtmp3 <- (L + D)^{-1} * dtmp
        BEGIN_PROFILE("SYMGS:trsvL");
        auto trsvL_ev = sparse::trsv(main_queue, mkl::uplo::lower, mkl::transpose::nontrans,
                mkl::diag::nonunit, 1.0, hMatrixA, optData->dtmp, optData->dtmp3, {rhs1_ev});
        END_PROFILE_WAIT("SYMGS:trsvL", trsvL_ev);

        // dtmp3 <- D * dtmp3 + dtmp4
        BEGIN_PROFILE("SYMGS:rhs2");
        auto rhs2_ev = main_queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on(trsvL_ev);
            auto kernel = [=] (sycl::item<1> row) {
                dtmp3[row] = diags[row]*dtmp3[row] + dtmp4[row];
            };
            cgh.parallel_for<class SYMGS1_rhs2_onemkl>(sycl::range<1>(nrow), kernel);
        });
        END_PROFILE_WAIT("SYMGS:rhs2", rhs2_ev);

        // x.values <- (D + U)^{-1} * dtmp3
        BEGIN_PROFILE("SYMGS:trsvU");
        result_ev = sparse::trsv(main_queue, mkl::uplo::upper, mkl::transpose::nontrans,
                mkl::diag::nonunit, 1.0, hMatrixA, optData->dtmp3, x.values, {rhs2_ev});
        END_PROFILE_WAIT("SYMGS:trsvU", result_ev);

    } // endif A.geom->size > 1

    return result_ev;

} // run_SYMGS_onemkl




// Submits a fused symmetric Gauss-Seidel sweep plus matrix-vector product:
// x is obtained from r by a lower solve, a diagonal scaling and an upper
// solve, and y is then accumulated from pieces already available
// (dtmp3 - D*x, plus (L + D)*x, plus the hMatrixB contribution when more than
// one rank is present).
//
// Parameters:
//   main_queue  queue every operation is submitted to
//   A           supplies localNumberOfRows and geom->size; also handed to
//               ExchangeHalo when more than one rank is present
//   optData     holds the scratch vector dtmp3, the diagonal `diags`, the
//               boundary map `bmap` / `nrow_b`, and the profiler used by the
//               *_PROFILE macros
//   hMatrixA    handle used for the triangular solves and the lower trmv
//   hMatrixB    handle used for the gemv in the multi-rank path
//   r           right-hand side (read only)
//   x           output of the sweep (written by the upper solve)
//   y           output of the product (written by the y_* stages)
//   deps        events the first solve waits on
//
// Returns the event of the last operation that writes y.values: the boundary
// update when A.geom->size > 1, otherwise the lower trmv.
//
// When BASIC_PROFILING is enabled every stage is waited on before the next
// one is submitted (see END_PROFILE_WAIT).
sycl::event run_SYMGS_MV_onemkl(sycl::queue &main_queue,
                                const SparseMatrix & A,
                                struct optData *optData,
                                oneapi::mkl::sparse::matrix_handle_t hMatrixA,
                                oneapi::mkl::sparse::matrix_handle_t hMatrixB,
                                const Vector &r,
                                Vector &x,
                                Vector &y,
                                const std::vector<sycl::event> &deps)
{
    namespace sparse = oneapi::mkl::sparse;
    namespace mkl = oneapi::mkl;

    const local_int_t nrow = A.localNumberOfRows;

    // Read every raw data pointer on the host, once. The kernels below capture
    // only these pointers by value, so device code never dereferences the
    // optData struct itself and never copies a whole Vector into its closure.
    const double *xv = x.values;
    double *yv = y.values;
    double *dtmp3 = optData->dtmp3;
    const double *diags = optData->diags;

    // dtmp3 <- (L + D)^{-1} * r.values
    BEGIN_PROFILE("SYMGS_MV:trsvL");
    auto trsvL_ev = sparse::trsv(main_queue, mkl::uplo::lower, mkl::transpose::nontrans,
            mkl::diag::nonunit, 1.0, hMatrixA, r.values, optData->dtmp3, deps);
    END_PROFILE_WAIT("SYMGS_MV:trsvL", trsvL_ev);

    // dtmp3 <- D * dtmp3
    BEGIN_PROFILE("SYMGS_MV:rhs1");
    auto rhs1_ev = main_queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(trsvL_ev);
        auto kernel = [=] (sycl::item<1> row) {
            dtmp3[row] = diags[row] * dtmp3[row];
        };
        cgh.parallel_for<class SYMGS_MV_rhs1_onemkl>(sycl::range<1>(nrow), kernel);
    });
    END_PROFILE_WAIT("SYMGS_MV:rhs1", rhs1_ev);

    // x.values <- (D + U)^{-1} * dtmp3
    BEGIN_PROFILE("SYMGS_MV:trsvU");
    auto trsvU_ev = sparse::trsv(main_queue, mkl::uplo::upper, mkl::transpose::nontrans,
            mkl::diag::nonunit, 1.0, hMatrixA, optData->dtmp3, x.values, {rhs1_ev});
    END_PROFILE_WAIT("SYMGS_MV:trsvU", trsvU_ev);

    // y.values <- dtmp3 - D * x.values
    BEGIN_PROFILE("SYMGS_MV:y_part");
    auto y_part_ev = main_queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(trsvU_ev);
        auto kernel = [=] (sycl::item<1> row) {
            yv[row] = dtmp3[row] - diags[row] * xv[row];
        };
        cgh.parallel_for<class SYMGS_MV_y_part_onemkl>(sycl::range<1>(nrow), kernel);
    });
    END_PROFILE_WAIT("SYMGS_MV:y_part", y_part_ev);

    // y.values <- (L + D) * x.values + y.values
    BEGIN_PROFILE("SYMGS_MV:trmvL");
    auto trmvL_ev = sparse::trmv(main_queue, mkl::uplo::lower, mkl::transpose::nontrans,
            mkl::diag::nonunit, 1.0, hMatrixA, x.values, 1.0, y.values, {y_part_ev});
    END_PROFILE_WAIT("SYMGS_MV:trmvL", trmvL_ev);

    if(A.geom->size > 1)
    {

#ifndef HPCG_NO_MPI
        // x.values exchange with neighbors; only needs the final x from trsvU
        BEGIN_PROFILE("SYMGS_MV:halo");
        sycl::event halo_ev = ExchangeHalo(A, x, main_queue, {trsvU_ev});
        END_PROFILE_WAIT("SYMGS_MV:halo", halo_ev);
#else
        sycl::event halo_ev; // this should never be used, (no mpi + size > 1), so it is safe, but this
                             // loses the deps chain
#endif

        // The gemv below needs the exchanged x (halo_ev). It also overwrites
        // dtmp3, which the y_part kernel still reads, so it must wait for
        // y_part_ev as well.
        //
        // dtmp3 <- B * x.values + 0 * dtmp3
        BEGIN_PROFILE("SYMGS_MV:gemvB");
        auto gemvB_ev = sparse::gemv(main_queue, mkl::transpose::nontrans, 1.0, hMatrixB,
                x.values, 0.0, optData->dtmp3, {y_part_ev, halo_ev});
        END_PROFILE_WAIT("SYMGS_MV:gemvB", gemvB_ev);

        const local_int_t nrow_b = optData->nrow_b;
        const local_int_t *bmap = optData->bmap;

        // depends on dtmp3 (gemvB_ev) and y.values (trmvL_ev)
        // y.values(bmap) <- y.values(bmap) + dtmp3
        BEGIN_PROFILE("SYMGS_MV:y_nonloc");
        auto result_ev = main_queue.submit([&](sycl::handler &cgh) {
            cgh.depends_on({trmvL_ev, gemvB_ev});
            auto kernel = [=] (sycl::item<1> row) {
                yv[bmap[row]] += dtmp3[row];
            };
            cgh.parallel_for<class SYMGS_MV_y_nonloc_onemkl>(sycl::range<1>(nrow_b), kernel);
        });
        END_PROFILE_WAIT("SYMGS_MV:y_nonloc", result_ev);

        return result_ev;
    }
    else { // A.geom->size == 1
        return trmvL_ev;
    }

} // run_SYMGS_MV_onemkl
