

#include<RcppArmadillo.h>
#include<cmath>
#include<stdexcept>

// Pairwise cost matrix under the sup (Chebyshev) distance.
//
// Entry (i, j) is (max_k |x(i, k) - y(j, k)|)^q: the largest coordinate-wise
// absolute difference between row i of x and row j of y, raised to q.
//
// x : n-by-d matrix, one point per row.
// y : m-by-d matrix, one point per row.
// q : exponent applied to each distance.
// Returns an n-by-m matrix of costs.
// Throws std::invalid_argument if x and y differ in their number of columns.
arma::mat costMatrixSup(const arma::mat& x, const arma::mat& y, const double q)
{
    // Validate before the OpenMP loop: an exception must not escape a parallel
    // region (in practice it calls std::terminate and takes down R), so a
    // dimension mismatch has to be reported here rather than by Armadillo
    // inside the loop.
    if (x.n_cols != y.n_cols)
    {
        throw std::invalid_argument("costMatrixSup: x and y must have the same number of columns");
    }

    const int n = static_cast<int>(x.n_rows);
    const int m = static_cast<int>(y.n_rows);

    arma::mat costm = arma::mat(n, m, arma::fill::none);

    // With zero columns every distance is 0. arma::max() of an empty row
    // raises an error (or yields NaN when Armadillo debug checks are off),
    // so this case is handled outside the parallel loop.
    if (x.n_cols == 0)
    {
        costm.fill(std::pow(0.0, q));
        return costm;
    }

    #pragma omp parallel for
    for (int i = 0; i < n; ++i)
    {
        for (int j = 0; j < m; ++j)
        {
            costm(i, j) = std::pow(arma::max(arma::abs(x.row(i) - y.row(j))), q);
        }
    }

    return costm;
}

// Pairwise cost matrix under the Lp distance, raised to the power q.
//
// Entry (i, j) is (sum_k |x(i, k) - y(j, k)|^p)^(q / p), i.e. the Lp distance
// between row i of x and row j of y raised to q. When std::isinf(p) holds,
// the computation is delegated to costMatrixSup (the sup distance).
//
// x : n-by-d matrix, one point per row.
// y : m-by-d matrix, one point per row.
// p : order of the distance; an infinite value selects the sup distance.
// q : exponent applied to each distance.
// Returns an n-by-m matrix of costs.
// Throws std::invalid_argument if x and y differ in their number of columns.
// [[Rcpp::export(rng = false)]]
arma::mat costMatrixLp(const arma::mat& x, const arma::mat& y, const double p, const double q)
{
    // Check up front, outside any OpenMP region: if Armadillo detected the
    // size mismatch inside the parallel loop, the exception could not escape
    // the region (in practice std::terminate, aborting R) instead of becoming
    // an R error.
    if (x.n_cols != y.n_cols)
    {
        throw std::invalid_argument("costMatrixLp: x and y must have the same number of columns");
    }

    if (std::isinf(p))
    {
        return costMatrixSup(x, y, q);
    }

    const int n = static_cast<int>(x.n_rows);
    const int m = static_cast<int>(y.n_rows);

    arma::mat costm = arma::mat(n, m, arma::fill::none);
    // (sum |d|^p)^(q/p) == ||d||_p^q, computed with a single outer pow.
    const double u = q / p;

    #pragma omp parallel for
    for (int i = 0; i < n; ++i)
    {
        for (int j = 0; j < m; ++j)
        {
            costm(i, j) = std::pow(arma::accu(arma::pow(arma::abs(x.row(i) - y.row(j)), p)), u);
        }
    }

    return costm;
}

// Symmetric pairwise cost matrix of the rows of x under the sup distance.
//
// Entry (i, j) is (max_k |x(i, k) - x(j, k)|)^q. Only the upper triangle is
// computed; each value is mirrored into (j, i), and the diagonal is set to 0.
//
// x : n-by-d matrix, one point per row.
// q : exponent applied to each distance.
// Returns an n-by-n symmetric matrix of costs.
arma::mat costMatrixSymmetricSup(const arma::mat& x, const double q)
{
    const int n = static_cast<int>(x.n_rows);

    arma::mat costm = arma::mat(n, n, arma::fill::none);

    // With zero columns every off-diagonal distance is 0. arma::max() of an
    // empty row raises an error (or yields NaN when Armadillo debug checks are
    // off), and an exception must not escape the parallel region below.
    if (x.n_cols == 0)
    {
        costm.fill(std::pow(0.0, q));
        costm.diag().zeros();
        return costm;
    }

    // Row i does n - i - 1 evaluations, so a static split leaves the threads
    // owning the first rows with most of the work; dynamic scheduling balances it.
    // Writes never overlap: thread i only touches (i, i), (i, j > i), (j > i, i).
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < n; ++i)
    {
        costm(i, i) = 0;

        for (int j = i + 1; j < n; ++j)
        {
            costm(i, j) = std::pow(arma::max(arma::abs(x.row(i) - x.row(j))), q);
            costm(j, i) = costm(i, j);
        }
    }

    return costm;
}

// Symmetric pairwise cost matrix of the rows of x under the Lp distance.
//
// Entry (i, j) is (sum_k |x(i, k) - x(j, k)|^p)^(q / p). Only the upper
// triangle is computed; each value is mirrored into (j, i), and the diagonal
// is set to 0. When std::isinf(p) holds, the computation is delegated to
// costMatrixSymmetricSup.
//
// x : n-by-d matrix, one point per row.
// p : order of the distance; an infinite value selects the sup distance.
// q : exponent applied to each distance.
// Returns an n-by-n symmetric matrix of costs.
// [[Rcpp::export(rng = false)]]
arma::mat costMatrixSymmetricLp(const arma::mat& x, const double p, const double q)
{
    if (std::isinf(p))
    {
        return costMatrixSymmetricSup(x, q);
    }

    const int n = static_cast<int>(x.n_rows);

    arma::mat costm = arma::mat(n, n, arma::fill::none);
    // (sum |d|^p)^(q/p) == ||d||_p^q, computed with a single outer pow.
    const double u = q / p;

    // Row i does n - i - 1 evaluations, so a static split leaves the threads
    // owning the first rows with most of the work; dynamic scheduling balances it.
    // Writes never overlap: thread i only touches (i, i), (i, j > i), (j > i, i).
    #pragma omp parallel for schedule(dynamic)
    for (int i = 0; i < n; ++i)
    {
        costm(i, i) = 0;

        for (int j = i + 1; j < n; ++j)
        {
            costm(i, j) = std::pow(arma::accu(arma::pow(arma::abs(x.row(i) - x.row(j)), p)), u);
            costm(j, i) = costm(i, j);
        }
    }

    return costm;
}

// Checks the triangle inequality on a square matrix x:
// returns false as soon as some triple (i, j, k) has
//     x(i, j) - tol > x(i, k) + x(k, j),
// and true if no such triple exists. Comparisons involving NaN are false, so
// NaN entries never count as violations.
//
// x   : square N-by-N matrix.
// tol : slack; a violation must exceed it to be reported.
// Throws std::invalid_argument if x is not square.
// [[Rcpp::export(rng = false)]]
bool satisfiesTriangleInequality(const arma::mat& x, const double tol) {
    if (x.n_rows != x.n_cols) {
        throw std::invalid_argument("satisfiesTriangleInequality: x must be a square matrix");
    }

    const arma::uword N = x.n_rows;

    // xt(k, i) == x(i, k). Transposing once lets the innermost loop read both
    // x(i, k) and x(k, j) from contiguous columns without per-pair temporaries.
    const arma::mat xt = x.t();

    for (arma::uword j = 0; j < N; ++j) {
        const double* col_j = x.colptr(j);    // col_j[k] == x(k, j)
        for (arma::uword i = 0; i < N; ++i) {
            const double* row_i = xt.colptr(i); // row_i[k] == x(i, k)
            const double bound = x(i, j) - tol;
            for (arma::uword k = 0; k < N; ++k) {
                if (bound > row_i[k] + col_j[k]) {
                    return false;
                }
            }
        }
    }
    return true;
}
