/*******************************************************************************
* Copyright (C) 2023 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

#pragma once

#include <cassert>
#include <cstdint>
#include <memory>
#include <stdexcept>

#include "mkl.h"
#include "oneapi/mkl.hpp"

#include "include/common_for_sparse_examples.hpp"

/*
 * Note:
 *  In this header file, we provide serial reference implementations for
 *  validating the output of the corresponding oneapi::mkl::sparse APIs,
 *  used as part of the examples.
 *  A do_csr_transpose() helper is also provided so that the reference
 *  implementations can easily apply a transpose or conjugate-transpose
 *  operation to the input sparse matrix.
 */

// Builds the CSR representation (ia_t, ja_t, a_t) of the transpose of the
// a_nrows x a_ncols CSR matrix (ia, ja, a). Input and output both use index
// base a_ind. opA only selects whether values are conjugated
// (conjtrans); the structure is transposed whatever its value.
//
// The caller must preallocate the outputs: ia_t with a_ncols + 1 entries, and
// ja_t / a_t with nnz(A) entries. Because the rows of A are visited in
// increasing order, the column indices within each row of the result come out
// in increasing order.
//
// When structOnlyFlag is true, only ia_t and ja_t are filled. In that case a
// and a_t are never accessed and may be null.
//
// NOTE(review): opVal() is defined outside this header (presumably in
// common_for_sparse_examples.hpp). It is assumed to return the conjugate of
// its argument when isConj is true, and the value unchanged otherwise.
// Confirm against its definition.
template <typename dataType, typename intType>
void do_csr_transpose(const oneapi::mkl::transpose opA,
                      intType *ia_t,
                      intType *ja_t,
                      dataType *a_t,
                      std::int64_t a_nrows,
                      std::int64_t a_ncols,
                      intType a_ind,
                      intType *ia,
                      intType *ja,
                      dataType *a,
                      const bool structOnlyFlag = false)
{
    assert(ia_t != nullptr);
    assert(ja_t != nullptr);
    assert(ia   != nullptr);
    assert(ja   != nullptr);
    // The value arrays are only required when values are transposed as well.
    assert(structOnlyFlag || a_t != nullptr);
    assert(structOnlyFlag || a   != nullptr);

    const bool isConj = (opA == oneapi::mkl::transpose::conjtrans);

    // initialize ia_t to zero
    for (intType i = 0; i < a_ncols + 1; ++i) {
        ia_t[i] = 0;
    }

    // Pass 1: count the entries in each column of A (== each row of A^T).
    // The count for column c is stored in ia_t[c + 1].
    for (intType i = 0; i < a_nrows; ++i) {
        const intType st = ia[i] - a_ind;
        const intType en = ia[i + 1] - a_ind;
        for (intType j = st; j < en; ++j) {
            const intType col = ja[j] - a_ind;
            ia_t[col + 1]++;
        }
    }
    // Prefix sum starting at a_ind: ia_t[c] becomes the (a_ind-based) start
    // of row c of A^T, and ia_t[a_ncols] becomes nnz + a_ind.
    ia_t[0] = a_ind;
    for (intType i = 0; i < a_ncols; ++i) {
        ia_t[i + 1] += ia_t[i];
    }

    // Pass 2: scatter each entry of A into its row of A^T. ia_t[col] is used
    // as a running insertion cursor, so after this loop ia_t[c] temporarily
    // holds the start of row c + 1.
    for (intType i = 0; i < a_nrows; ++i) {
        const intType st = ia[i] - a_ind;
        const intType en = ia[i + 1] - a_ind;
        for (intType j = st; j < en; ++j) {
            const intType col      = ja[j] - a_ind;
            const intType j_in_a_t = ia_t[col] - a_ind;
            ia_t[col]++;
            ja_t[j_in_a_t] = i + a_ind;
            if (!structOnlyFlag) {
                const dataType val = a[j];
                a_t[j_in_a_t]    = opVal(val, isConj);
            }
        }
    }

    // Undo the cursor shift from pass 2 so that ia_t again holds the row
    // starts of A^T.
    for (intType i = a_ncols; i > 0; --i) {
        ia_t[i] = ia_t[i - 1];
    }
    ia_t[0] = a_ind;
}

// Note:
//  Currently, this function assumes opB == nontrans.
template <typename dataType, typename intType>
void prepare_reference_gemm_data(intType *ia,
                                 intType *ja,
                                 dataType *a,
                                 std::int64_t a_nrows,
                                 std::int64_t a_ncols,
                                 std::int64_t a_nnz,
                                 intType a_ind,
                                 oneapi::mkl::transpose opA,
                                 dataType alpha,
                                 dataType beta,
                                 oneapi::mkl::layout layout,
                                 dataType *b,
                                 std::int64_t columns, // columns == b_ncols == c_ncols.
                                 std::int64_t ldb,
                                 dataType *c_ref,
                                 std::int64_t ldc)
{
    intType opa_nrows = (opA == oneapi::mkl::transpose::nontrans) ? a_nrows : a_ncols;
    intType opa_ncols = (opA == oneapi::mkl::transpose::nontrans) ? a_ncols : a_nrows;

    // prepare op(A) locally
    intType *iopa = nullptr;
    intType *jopa = nullptr;
    dataType *opa = nullptr;
    if (opA == oneapi::mkl::transpose::nontrans) {
        iopa = ia;
        jopa = ja;
        opa = a;
    }
    else if ( opA == oneapi::mkl::transpose::trans ||
              opA == oneapi::mkl::transpose::conjtrans )
    {
        iopa = (intType *)mkl_malloc((opa_nrows + 1) * sizeof(intType), 64);
        jopa = (intType *)mkl_malloc((a_nnz) * sizeof(intType), 64);
        opa  = (dataType *)mkl_malloc((a_nnz) * sizeof(dataType), 64);
        if (!iopa || !jopa || !opa)
            throw std::runtime_error("failed to allocate memory for one of iopa, jopa or opa in prepare_reference_gemm_data");

        do_csr_transpose(opA, iopa, jopa, opa, a_nrows, a_ncols, a_ind, ia, ja, a);

    }
    else {
        throw std::runtime_error("unsupported transpose_val (opA) in prepare_reference_gemm_data()");
    }

    intType c_nrows = opa_nrows;
    intType c_ncols = columns;

    //
    // do GEMM operation (note that alpha or beta could be (dataType)0.0)
    //
    //  c_ref <- alpha * op(A) * b + beta * c_ref
    //
    if (layout == oneapi::mkl::layout::col_major) {
        for (intType col = 0; col < c_ncols; col++) {
            for (intType row = 0; row < c_nrows; row++) {
                /* ldb >= opa_ncols (== b_nrows), ldc >= c_nrows */
                dataType tmp = 0.0;
                for (intType i = iopa[row] - a_ind; i < iopa[row + 1] - a_ind; i++) {
                    tmp += opa[i] * b[(jopa[i] - a_ind) + col * ldb];
                }
                c_ref[row + col * ldc] = alpha * tmp + beta * c_ref[row + col * ldc];
            }
        }
    }
    else {
        for (intType row = 0; row < c_nrows; row++) {
            for (intType col = 0; col < c_ncols; col++) {
                /* ldb >= b_ncols, ldc >= c_ncols */
                dataType tmp = 0.0;
                for (intType i = iopa[row] - a_ind; i < iopa[row + 1] - a_ind; i++) {
                    tmp += opa[i] * b[(jopa[i] - a_ind) * ldb + col];
                }
                c_ref[row * ldc + col] = alpha * tmp + beta * c_ref[row * ldc + col];
            }
        }
    }

    // clean up local allocated memory for op(A)
    if (opA == oneapi::mkl::transpose::nontrans) {
        // nothing to cleanup
    }
    else if ( opA == oneapi::mkl::transpose::trans ||
              opA == oneapi::mkl::transpose::conjtrans )
    {
        mkl_free(iopa);
        mkl_free(jopa);
        mkl_free(opa);
    }
    else {
        throw std::runtime_error("unsupported transpose_val (opA) in prepare_reference_gemm_data()");
    }

}

