/*
 * Decompiled with CFR 0.152.
 */
package org.ojalgo.tensor;

import java.util.Arrays;
import java.util.Objects;
import org.ojalgo.function.FunctionSet;
import org.ojalgo.scalar.Scalar;
import org.ojalgo.structure.Access1D;
import org.ojalgo.structure.Access2D;
import org.ojalgo.structure.AccessAnyD;
import org.ojalgo.structure.FactoryAnyD;
import org.ojalgo.structure.MutateAnyD;

public final class TensorFactoryAnyD<N extends Comparable<N>, T extends MutateAnyD>
implements FactoryAnyD<T> {
    private final FactoryAnyD<T> myFactory;

    public static <N extends Comparable<N>, T extends MutateAnyD> TensorFactoryAnyD<N, T> of(FactoryAnyD<T> factory) {
        return new TensorFactoryAnyD<N, T>(factory);
    }

    TensorFactoryAnyD(FactoryAnyD<T> factory) {
        this.myFactory = factory;
    }

    public T blocks(AccessAnyD<N> ... tensors) {
        int rank = 1;
        for (AccessAnyD<N> tensor : tensors) {
            rank = Math.max(rank, tensor.shape().length);
        }
        long[] structure = new long[rank];
        for (int r = 0; r < structure.length; ++r) {
            long count = 0L;
            for (int t = 0; t < tensors.length; ++t) {
                count += tensors[t].count(r);
            }
            structure[r] = count;
        }
        MutateAnyD retVal = (MutateAnyD)this.myFactory.make(structure);
        long[] offset = new long[rank];
        long[] outRef = new long[rank];
        for (AccessAnyD tensor : tensors) {
            tensor.loopAll(inRef -> {
                double value = tensor.doubleValue(inRef);
                System.arraycopy(offset, 0, outRef, 0, offset.length);
                for (int i = 0; i < inRef.length; ++i) {
                    int n = i;
                    outRef[n] = outRef[n] + inRef[i];
                }
                retVal.set(outRef, value);
            });
            for (int i = 0; i < offset.length; ++i) {
                int n = i;
                offset[n] = offset[n] + tensor.count(i);
            }
        }
        return (T)retVal;
    }

    public T copy(Access1D<N> elements) {
        MutateAnyD retVal = (MutateAnyD)this.myFactory.make(elements.count());
        for (long i = 0L; i < elements.count(); ++i) {
            retVal.set(i, (Comparable<?>)elements.get(i));
        }
        return (T)retVal;
    }

    public T copy(Access2D<N> elements) {
        MutateAnyD retVal = (MutateAnyD)this.myFactory.make(elements.countRows(), elements.countColumns());
        for (long i = 0L; i < elements.count(); ++i) {
            retVal.set(i, (Comparable<?>)elements.get(i));
        }
        return (T)retVal;
    }

    public T copy(AccessAnyD<N> elements) {
        MutateAnyD retVal = (MutateAnyD)this.myFactory.make(elements);
        for (long i = 0L; i < elements.count(); ++i) {
            retVal.set(i, (Comparable<?>)elements.get(i));
        }
        return (T)retVal;
    }

    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!super.equals(obj) || !(obj instanceof TensorFactoryAnyD)) {
            return false;
        }
        TensorFactoryAnyD other = (TensorFactoryAnyD)obj;
        return !(this.myFactory == null ? other.myFactory != null : !this.myFactory.equals(other.myFactory));
    }

    public FunctionSet<N> function() {
        return this.myFactory.function();
    }

    public int hashCode() {
        int prime = 31;
        int result = super.hashCode();
        result = 31 * result + (this.myFactory == null ? 0 : this.myFactory.hashCode());
        return result;
    }

    @Override
    public T make(long ... structure) {
        return (T)((MutateAnyD)this.myFactory.make(structure));
    }

    public T power(Access1D<N> vector, int exponent) {
        Object[] vectors = new Access1D[exponent];
        Arrays.fill(vectors, vector);
        return this.product((Access1D<?>[])vectors);
    }

    public T product(Access1D<?> ... vectors) {
        int rank = vectors.length;
        long[] shape = new long[rank];
        for (int i = 0; i < rank; ++i) {
            shape[i] = vectors[i].count();
        }
        MutateAnyD retVal = (MutateAnyD)this.myFactory.make(shape);
        retVal.loopAll(ref -> {
            double val = 1.0;
            for (int d = 0; d < ref.length; ++d) {
                val *= vectors[d].doubleValue(ref[d]);
            }
            retVal.set(ref, val);
        });
        return (T)retVal;
    }

    public Scalar.Factory<N> scalar() {
        return this.myFactory.scalar();
    }

    public T sum(Access1D<N> ... vectors) {
        long dimensions = 0L;
        for (Access1D<N> vector : vectors) {
            dimensions += vector.count();
        }
        MutateAnyD retVal = (MutateAnyD)this.myFactory.make(dimensions);
        long offset = 0L;
        for (Access1D<N> vector : vectors) {
            long limit = vector.count();
            int i = 0;
            while ((long)i < limit) {
                retVal.set(offset + (long)i, vector.doubleValue(i));
                ++i;
            }
            offset += limit;
        }
        return (T)retVal;
    }
}

