Java源码示例:org.nd4j.linalg.api.ops.executioner.OpExecutioner

示例1
/**
 * Exchanges the contents of two vectors (BLAS {@code ?swap}).
 * <p>
 * If either operand is sparse the work is delegated to the sparse BLAS wrapper. Dense DOUBLE data
 * goes to {@code dswap}; any other dense data is validated as FLOAT and sent to {@code sswap}.
 *
 * @param x first vector
 * @param y second vector
 */
@Override
public void swap(INDArray x, INDArray y) {
    boolean profileEverything = Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, x, y);
    }

    boolean anySparse = x.isSparse() || y.isSparse();
    if (anySparse) {
        Nd4j.getSparseBlasWrapper().level1().swap(x, y);
        return;
    }

    if (x.data().dataType() != DataBuffer.Type.DOUBLE) {
        // Non-double dense data must be FLOAT; validateDataType rejects anything else.
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, x, y);
        sswap(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
        return;
    }

    DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, x, y);
    dswap(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
}
 
示例2
/**
 * symv computes a matrix-vector product for a symmetric matrix:
 * {@code y := alpha*a*x + beta*y}.
 * Here a is an n-by-n symmetric matrix; x and y are n-element vectors, alpha and beta are scalars.
 * <p>
 * Dispatches to {@code dsymv} for DOUBLE data and otherwise to {@code ssymv}, with alpha and beta
 * narrowed to {@code float}. n is taken from {@code X.length()} and the leading dimension of A from
 * {@code A.size(0)}.
 *
 * @param order BLAS order flag, passed through to the native call
 * @param Uplo  BLAS upper/lower flag, passed through to the native call
 * @param alpha scalar applied to {@code A*X}
 * @param A     symmetric matrix
 * @param X     input vector
 * @param beta  scalar applied to the incoming {@code Y}
 * @param Y     input/output vector
 * @throws IllegalArgumentException if {@code X.length()} or {@code A.size(0)} exceeds
 *         {@code Integer.MAX_VALUE} and so cannot be passed as a BLAS {@code int}
 */
@Override
public void symv(char order, char Uplo, double alpha, INDArray A, INDArray X, double beta, INDArray Y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);

    // The native routines take int sizes; fail fast instead of silently truncating the (int) casts below.
    if (X.length() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("symv: array dimensions exceed Integer.MAX_VALUE (X.length() = "
                        + X.length() + ", A.size(0) = " + A.size(0) + ")");
    }

    if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, X, Y);
        dsymv(order, Uplo, (int) X.length(), alpha, A, (int) A.size(0), X, X.majorStride(), beta, Y, Y.majorStride());
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, X, Y);
        ssymv(order, Uplo, (int) X.length(), (float) alpha, A, (int) A.size(0), X, X.majorStride(), (float) beta, Y,
                        Y.majorStride());
    }

    OpExecutionerUtil.checkForAny(Y);
}
 
示例3
/**
 * tpsv solves a system of linear equations whose coefficients are in a triangular packed matrix
 * (BLAS {@code ?tpsv}); the solution overwrites {@code X}.
 * <p>
 * n is taken from {@code X.length()}. DOUBLE data is sent to {@code dtpsv}; anything else is
 * validated as FLOAT and sent to {@code stpsv}.
 *
 * @param order  BLAS order flag, passed through to the native call
 * @param Uplo   BLAS upper/lower flag, passed through to the native call
 * @param TransA BLAS transpose flag, passed through to the native call
 * @param Diag   BLAS unit/non-unit diagonal flag, passed through to the native call
 * @param Ap     packed triangular matrix
 * @param X      right-hand side on input, solution on output
 * @throws IllegalArgumentException if {@code X.length()} exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void tpsv(char order, char Uplo, char TransA, char Diag, INDArray Ap, INDArray X) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, Ap, X);

    // The native routines take an int n; reject lengths that the (int) cast would truncate.
    if (X.length() > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("tpsv: X.length() = " + X.length() + " exceeds Integer.MAX_VALUE");
    }

    if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, X, Ap);
        dtpsv(order, Uplo, TransA, Diag, (int) X.length(), Ap, X, X.majorStride());
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, Ap, X);
        stpsv(order, Uplo, TransA, Diag, (int) X.length(), Ap, X, X.majorStride());
    }

    OpExecutionerUtil.checkForAny(X);
}
 
示例4
/**
 * gbmv computes a matrix-vector product using a general band matrix and performs one of the following matrix-vector operations:
 * y := alpha*a*x + beta*y  for trans = 'N'or'n';
 * y := alpha*a'*x + beta*y  for trans = 'T'or't';
 * y := alpha*conjg(a')*x + beta*y  for trans = 'C'or'c'.
 * Here a is an m-by-n band matrix with ku superdiagonals and kl subdiagonals, x and y are vectors, alpha and beta are scalars.
 * <p>
 * m and n are taken from {@code A.rows()} and {@code A.columns()}, and the leading dimension from
 * {@code A.size(0)}. DOUBLE data goes to {@code dgbmv}; anything else is validated as FLOAT and
 * goes to {@code sgbmv}.
 *
 * @param order  BLAS order flag, passed through to the native call
 * @param TransA BLAS transpose flag, passed through to the native call
 * @param KL     number of subdiagonals
 * @param KU     number of superdiagonals
 * @param alpha  scalar applied to {@code op(A)*X}
 * @param A      band matrix
 * @param X      input vector
 * @param beta   scalar applied to the incoming {@code Y}
 * @param Y      input/output vector
 * @throws ND4JArraySizeException if a dimension of {@code A} exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void gbmv(char order, char TransA, int KL, int KU, double alpha, INDArray A, INDArray X, double beta,
                INDArray Y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);

    // Both precisions cast these sizes to int, so the overflow check must guard both branches
    // (previously only the DOUBLE branch was protected).
    if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE)
        throw new ND4JArraySizeException();

    if (A.data().dataType() == DataType.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y);
        dgbmv(order, TransA, (int) A.rows(), (int) A.columns(), KL, KU, alpha, A, (int) A.size(0), X, X.stride(-1), beta, Y,
                        Y.stride(-1));
    } else {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X, Y);
        sgbmv(order, TransA, (int) A.rows(), (int) A.columns(), KL, KU, (float) alpha, A, (int) A.size(0), X, X.stride(-1),
                        (float) beta, Y, Y.stride(-1));
    }

    OpExecutionerUtil.checkForAny(Y);
}
 
示例5
/**
 * trmm computes a matrix-matrix product where one operand is triangular (BLAS {@code ?trmm}):
 * {@code B := alpha*op(A)*B} for Side = 'L'/'l', or {@code B := alpha*B*op(A)} for Side = 'R'/'r',
 * where B is an m-by-n general matrix and A is a triangular matrix. The result overwrites {@code B}.
 * <p>
 * Per the BLAS contract, m and n are the dimensions of {@code B} (A is m-by-m or n-by-n depending
 * on Side), so they are taken from {@code B.rows()} / {@code B.columns()}.
 *
 * @param Order  BLAS order flag, passed through to the native call
 * @param Side   whether op(A) multiplies B from the left or the right
 * @param Uplo   BLAS upper/lower flag, passed through to the native call
 * @param TransA BLAS transpose flag for A, passed through to the native call
 * @param Diag   BLAS unit/non-unit diagonal flag, passed through to the native call
 * @param alpha  scalar multiplier
 * @param A      triangular matrix
 * @param B      input matrix, overwritten with the result
 * @param C      not passed to the native routine; only profiled and data-type validated
 * @throws ND4JArraySizeException if a dimension passed to BLAS exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void trmm(char Order, char Side, char Uplo, char TransA, char Diag, double alpha, INDArray A, INDArray B,
                INDArray C) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, A, B, C);

    if (B.rows() > Integer.MAX_VALUE || B.columns() > Integer.MAX_VALUE ||
        A.size(0) > Integer.MAX_VALUE || B.size(0) > Integer.MAX_VALUE) {
        throw new ND4JArraySizeException();
    }

    if (A.data().dataType() == DataType.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, B, C);
        dtrmm(Order, Side, Uplo, TransA, Diag, (int) B.rows(), (int) B.columns(), alpha, A, (int) A.size(0), B, (int) B.size(0));
    } else {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, B, C);
        strmm(Order, Side, Uplo, TransA, Diag, (int) B.rows(), (int) B.columns(), (float) alpha, A, (int) A.size(0), B, (int) B.size(0));
    }

    // The native call writes its result into B (C is never handed to it), so B is what must be checked.
    OpExecutionerUtil.checkForAny(B);
}
 
示例6
/**
 * Computes the sum of the absolute values of the elements of {@code arr} (BLAS {@code ?asum}).
 * <p>
 * Sparse arrays are delegated to the sparse BLAS wrapper before any profiling happens. Dense arrays
 * are dispatched by data type: DOUBLE to {@code dasum}, FLOAT to {@code sasum}, and anything else is
 * validated as HALF and sent to {@code hasum}.
 *
 * @param arr the vector to reduce
 * @return the sum of absolute values of the elements of {@code arr}
 */
@Override
public double asum(INDArray arr) {

    if (arr.isSparse()) {
        return Nd4j.getSparseBlasWrapper().level1().asum(arr);
    }
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, arr);

    if (arr.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, arr);
        return dasum(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    } else if (arr.data().dataType() == DataBuffer.Type.FLOAT) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, arr);
        return sasum(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, arr);
        return hasum(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    }
}
 
示例7
/**
 * spr performs a rank-1 update of an n-by-n packed symmetric matrix a:
 * {@code a := alpha*x*x' + a}.
 * <p>
 * n is taken from {@code X.length()}. Both {@code X} and the packed matrix {@code Ap} are checked
 * for the expected data type before the native call, matching the other level-2 routines.
 *
 * @param order BLAS order flag, passed through to the native call
 * @param Uplo  BLAS upper/lower flag, passed through to the native call
 * @param alpha scalar multiplier
 * @param X     input vector
 * @param Ap    packed symmetric matrix, updated in place
 * @throws ND4JArraySizeException if {@code X.length()} exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void spr(char order, char Uplo, double alpha, INDArray X, INDArray Ap) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, Ap, X);

    if (X.length() > Integer.MAX_VALUE)
        throw new ND4JArraySizeException();

    if (X.data().dataType() == DataType.DOUBLE) {
        // Ap is written by the native call, so its buffer type must match X's as well.
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Ap);
        dspr(order, Uplo, (int) X.length(), alpha, X, X.stride(-1), Ap);
    } else {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, X, Ap);
        sspr(order, Uplo, (int) X.length(), (float) alpha, X, X.stride(-1), Ap);
    }

    OpExecutionerUtil.checkForAny(Ap);
}
 
示例8
/**
 * syr2 performs a rank-2 update of an n-by-n symmetric matrix a (BLAS {@code ?syr2}):
 * {@code a := alpha*x*y' + alpha*y*x' + a}.
 * <p>
 * n is taken from {@code X.length()} and the leading dimension of A from {@code A.size(0)}.
 *
 * @param order BLAS order flag, passed through to the native call
 * @param Uplo  BLAS upper/lower flag, passed through to the native call
 * @param alpha scalar multiplier
 * @param X     first input vector
 * @param Y     second input vector
 * @param A     symmetric matrix, updated in place
 * @throws IllegalArgumentException if {@code X.length()} or {@code A.size(0)} exceeds
 *         {@code Integer.MAX_VALUE}
 */
@Override
public void syr2(char order, char Uplo, double alpha, INDArray X, INDArray Y, INDArray A) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);

    // The native routines take int sizes; fail fast instead of silently truncating the casts below.
    if (X.length() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("syr2: array dimensions exceed Integer.MAX_VALUE (X.length() = "
                        + X.length() + ", A.size(0) = " + A.size(0) + ")");
    }

    if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, X, Y);
        dsyr2(order, Uplo, (int) X.length(), alpha, X, X.majorStride(), Y, Y.majorStride(), A, (int) A.size(0));
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, X, Y);
        ssyr2(order, Uplo, (int) X.length(), (float) alpha, X, X.majorStride(), Y, Y.majorStride(), A, (int) A.size(0));
    }

    OpExecutionerUtil.checkForAny(A);
}
 
示例9
/**
 * gemm performs a matrix-matrix operation
 * {@code c := alpha*op(a)*op(b) + beta*c},
 * where c is an m-by-n matrix,
 * op(a) is an m-by-k matrix,
 * op(b) is a k-by-n matrix.
 * <p>
 * Dimensions, leading dimensions and the transpose flags actually sent to BLAS come from
 * {@link GemmParams}; the {@code TransA}/{@code TransB} arguments are not forwarded directly.
 * DOUBLE data goes to {@code dgemm}, FLOAT to {@code sgemm}, and anything else is validated as
 * HALF and goes to {@code hgemm}.
 *
 * @param Order  BLAS order flag, passed through to the native call
 * @param TransA requested transpose of A (see note above)
 * @param TransB requested transpose of B (see note above)
 * @param alpha  scalar applied to {@code op(A)*op(B)}
 * @param A      left operand
 * @param B      right operand
 * @param beta   scalar applied to the incoming {@code C}
 * @param C      input/output matrix
 */
@Override
public void gemm(char Order, char TransA, char TransB, double alpha, INDArray A, INDArray B, double beta,
                INDArray C) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(true, A, B, C);

    GemmParams params = new GemmParams(A, B, C);

    // alpha and beta are forwarded to BLAS; previously they were hard-coded to 1 and 0,
    // which silently turned every call into C := op(A)*op(B).
    if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, params.getA(), params.getB(), params.getC());
        dgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), alpha,
                        params.getA(), params.getLda(), params.getB(), params.getLdb(), beta, C, params.getLdc());
    } else if (A.data().dataType() == DataBuffer.Type.FLOAT) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, params.getA(), params.getB(), params.getC());
        sgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), (float) alpha,
                        params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta, C, params.getLdc());
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, params.getA(), params.getB(), params.getC());
        hgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), (float) alpha,
                        params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta, C, params.getLdc());
    }

    OpExecutionerUtil.checkForAny(C);
}
 
示例10
/**
 * Symmetric band matrix-vector product (BLAS {@code ?sbmv}): {@code y := alpha*a*x + beta*y},
 * where a is an n-by-n symmetric band matrix with k superdiagonals and x, y are n-element vectors.
 * <p>
 * The native call receives {@code X.length()} as n, {@code A.columns()} as k and {@code A.size(0)}
 * as the leading dimension. NOTE(review): confirm that {@code A.columns()} is the intended k for
 * the band-storage layout used by callers.
 *
 * @param order BLAS order flag, passed through to the native call
 * @param Uplo  BLAS upper/lower flag, passed through to the native call
 * @param alpha scalar applied to {@code A*X}
 * @param A     band matrix
 * @param X     input vector
 * @param beta  scalar applied to the incoming {@code Y}
 * @param Y     input/output vector
 * @throws ND4JArraySizeException if any size handed to BLAS exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void sbmv(char order, char Uplo, double alpha, INDArray A, INDArray X, double beta, INDArray Y) {
    boolean profileEverything = Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    boolean exceedsInt = X.length() > Integer.MAX_VALUE
                    || A.columns() > Integer.MAX_VALUE
                    || A.size(0) > Integer.MAX_VALUE;
    if (exceedsInt) {
        throw new ND4JArraySizeException();
    }

    if (X.data().dataType() != DataType.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X, Y);
        ssbmv(order, Uplo, (int) X.length(), (int) A.columns(), (float) alpha, A, (int) A.size(0),
                        X, X.stride(-1), (float) beta, Y, Y.stride(-1));
    } else {
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y);
        dsbmv(order, Uplo, (int) X.length(), (int) A.columns(), alpha, A, (int) A.size(0),
                        X, X.stride(-1), beta, Y, Y.stride(-1));
    }

    OpExecutionerUtil.checkForAny(Y);
}
 
示例11
/**
 * symm computes a matrix-matrix product where one operand is symmetric (BLAS {@code ?symm}):
 * {@code c := alpha*a*b + beta*c} for Side = 'L'/'l', or
 * {@code c := alpha*b*a + beta*c} for Side = 'R'/'r',
 * where a is a symmetric matrix, and b and c are m-by-n matrices.
 * <p>
 * m and n are taken from {@code C.rows()} / {@code C.columns()}; the leading dimensions from
 * {@code size(0)} of A, B and C. DOUBLE data goes to {@code dsymm}; anything else is validated as
 * FLOAT and goes to {@code ssymm}.
 *
 * @param Order BLAS order flag, passed through to the native call
 * @param Side  whether the symmetric matrix multiplies from the left or the right
 * @param Uplo  BLAS upper/lower flag, passed through to the native call
 * @param alpha scalar applied to the product
 * @param A     symmetric matrix
 * @param B     general matrix
 * @param beta  scalar applied to the incoming {@code C}
 * @param C     input/output matrix
 * @throws ND4JArraySizeException if any size handed to BLAS exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void symm(char Order, char Side, char Uplo, double alpha, INDArray A, INDArray B, double beta, INDArray C) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, A, B, C);

    if (C.rows() > Integer.MAX_VALUE || C.columns() > Integer.MAX_VALUE ||
        A.size(0) > Integer.MAX_VALUE || B.size(0) > Integer.MAX_VALUE || C.size(0) > Integer.MAX_VALUE) {
        throw new ND4JArraySizeException();
    }

    if (A.data().dataType() == DataType.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, B, C);
        dsymm(Order, Side, Uplo, (int) C.rows(), (int) C.columns(), alpha, A, (int) A.size(0), B, (int) B.size(0), beta, C, (int) C.size(0));
    } else {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, B, C);
        ssymm(Order, Side, Uplo, (int) C.rows(), (int) C.columns(), (float) alpha, A, (int) A.size(0), B, (int) B.size(0), (float) beta, C,
                (int) C.size(0));
    }

    OpExecutionerUtil.checkForAny(C);
}
 
示例12
/**
 * Solves a triangular system whose coefficient matrix is stored in packed form (BLAS {@code ?tpsv}).
 * The solution overwrites {@code X}; n is taken from {@code X.length()}.
 *
 * @param order  BLAS order flag, passed through to the native call
 * @param Uplo   BLAS upper/lower flag, passed through to the native call
 * @param TransA BLAS transpose flag, passed through to the native call
 * @param Diag   BLAS unit/non-unit diagonal flag, passed through to the native call
 * @param Ap     packed triangular matrix
 * @param X      right-hand side on input, solution on output
 * @throws ND4JArraySizeException if {@code X.length()} exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void tpsv(char order, char Uplo, char TransA, char Diag, INDArray Ap, INDArray X) {
    boolean profileEverything = Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, Ap, X);
    }

    if (X.length() > Integer.MAX_VALUE) {
        throw new ND4JArraySizeException();
    }

    if (X.data().dataType() != DataType.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, Ap, X);
        stpsv(order, Uplo, TransA, Diag, (int) X.length(), Ap, X, X.stride(-1));
    } else {
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Ap);
        dtpsv(order, Uplo, TransA, Diag, (int) X.length(), Ap, X, X.stride(-1));
    }

    OpExecutionerUtil.checkForAny(X);
}
 
示例13
/**
 * Computes {@code y := alpha*x + y} over {@code n} elements (BLAS {@code ?axpy}).
 * <p>
 * A sparse {@code x} paired with a dense {@code y} is handed to the sparse BLAS wrapper. Every
 * other combination uses the dense routines, selected by x's data type: DOUBLE to {@code daxpy},
 * FLOAT to {@code saxpy}, and anything else is validated as HALF and sent to {@code haxpy}.
 *
 * @param n     number of elements to process
 * @param alpha scalar multiplier for {@code x}
 * @param x     input vector
 * @param y     input/output vector
 */
@Override
public void axpy(long n, double alpha, INDArray x, INDArray y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, x, y);
    }

    if (x.isSparse() && !y.isSparse()) {
        Nd4j.getSparseBlasWrapper().level1().axpy(n, alpha, x, y);
        return;
    }

    DataBuffer.Type kind = x.data().dataType();
    if (kind == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, x, y);
        daxpy(n, alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
        return;
    }
    if (kind == DataBuffer.Type.FLOAT) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, x, y);
        saxpy(n, (float) alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
        return;
    }
    DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, x, y);
    haxpy(n, (float) alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
}
 
示例14
/**
 * ?spr2 performs a rank-2 update of an n-by-n packed symmetric matrix a:
 * {@code a := alpha*x*y' + alpha*y*x' + a}.
 * <p>
 * n is taken from {@code X.length()}. DOUBLE data goes to {@code dspr2}; anything else is
 * validated as FLOAT and goes to {@code sspr2}.
 *
 * @param order BLAS order flag, passed through to the native call
 * @param Uplo  BLAS upper/lower flag, passed through to the native call
 * @param alpha scalar multiplier
 * @param X     first input vector
 * @param Y     second input vector
 * @param A     packed symmetric matrix, updated in place
 * @throws IllegalArgumentException if {@code X.length()} exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void spr2(char order, char Uplo, double alpha, INDArray X, INDArray Y, INDArray A) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);

    // The native routines take an int n; reject lengths that the (int) cast would truncate.
    if (X.length() > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("spr2: X.length() = " + X.length() + " exceeds Integer.MAX_VALUE");
    }

    if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, X, Y);
        dspr2(order, Uplo, (int) X.length(), alpha, X, X.majorStride(), Y, Y.majorStride(), A);
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, X, Y);
        sspr2(order, Uplo, (int) X.length(), (float) alpha, X, X.majorStride(), Y, Y.majorStride(), A);
    }

    OpExecutionerUtil.checkForAny(A);
}
 
示例15
/**
 * Rank-2 update of a packed symmetric n-by-n matrix (BLAS {@code ?spr2}):
 * {@code a := alpha*x*y' + alpha*y*x' + a}, with n taken from {@code X.length()}.
 *
 * @param order BLAS order flag, passed through to the native call
 * @param Uplo  BLAS upper/lower flag, passed through to the native call
 * @param alpha scalar multiplier
 * @param X     first input vector
 * @param Y     second input vector
 * @param A     packed symmetric matrix, updated in place
 * @throws ND4JArraySizeException if {@code X.length()} exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void spr2(char order, char Uplo, double alpha, INDArray X, INDArray Y, INDArray A) {
    boolean profileEverything = Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    if (X.length() > Integer.MAX_VALUE) {
        throw new ND4JArraySizeException();
    }

    if (X.data().dataType() != DataType.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X, Y);
        sspr2(order, Uplo, (int) X.length(), (float) alpha, X, X.stride(-1), Y, Y.stride(-1), A);
    } else {
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y);
        dspr2(order, Uplo, (int) X.length(), alpha, X, X.stride(-1), Y, Y.stride(-1), A);
    }

    OpExecutionerUtil.checkForAny(A);
}
 
示例16
/**
 * Complex general matrix-matrix multiply, {@code c := alpha*op(a)*op(b) + beta*c}: {@code zgemm}
 * for DOUBLE data, {@code cgemm} otherwise. m, n, k and the leading dimensions come from
 * {@link GemmParams}.
 * <p>
 * NOTE(review): the first operand slot receives {@code B} whenever {@code A} is C-ordered, and the
 * second slot receives {@code A} whenever {@code B} is C-ordered. If exactly one of the two is
 * C-ordered, the same matrix ends up in both slots. Confirm this is intended.
 *
 * @param Order  BLAS order flag, passed through to the native call
 * @param TransA BLAS transpose flag for the first operand, passed through
 * @param TransB BLAS transpose flag for the second operand, passed through
 * @param alpha  complex scalar applied to the product
 * @param A      left operand
 * @param B      right operand
 * @param beta   complex scalar applied to the incoming {@code C}
 * @param C      input/output matrix
 */
@Override
public void gemm(char Order, char TransA, char TransB, IComplexNumber alpha, IComplexNDArray A, IComplexNDArray B,
                IComplexNumber beta, IComplexNDArray C) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(true, A, B, C);
    }

    GemmParams params = new GemmParams(A, B, C);

    IComplexNDArray firstOperand = A.ordering() == NDArrayFactory.C ? B : A;
    IComplexNDArray secondOperand = B.ordering() == NDArrayFactory.C ? A : B;

    if (A.data().dataType() != DataBuffer.Type.DOUBLE) {
        cgemm(Order, TransA, TransB, params.getM(), params.getN(), params.getK(), alpha.asFloat(),
                        firstOperand, params.getLda(), secondOperand, params.getLdb(), beta.asFloat(), C,
                        params.getLdc());
        return;
    }

    zgemm(Order, TransA, TransB, params.getM(), params.getN(), params.getK(), alpha.asDouble(),
                    firstOperand, params.getLda(), secondOperand, params.getLdb(), beta.asDouble(), C,
                    params.getLdc());
}
 
示例17
/**
 * performs a rank-1 update of a general m-by-n matrix a:
 * {@code a := alpha*x*y' + a}.
 * <p>
 * m and n are taken from {@code A.rows()} / {@code A.columns()} and the leading dimension from
 * {@code A.size(0)}. DOUBLE data goes to {@code dger}; anything else is validated as FLOAT and
 * goes to {@code sger}.
 *
 * @param order BLAS order flag, passed through to the native call
 * @param alpha scalar multiplier
 * @param X     m-element vector
 * @param Y     n-element vector
 * @param A     matrix updated in place
 * @throws ND4JArraySizeException if a dimension of {@code A} exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void ger(char order, double alpha, INDArray X, INDArray Y, INDArray A) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);

    // Both precisions cast these sizes to int, so guard them before dispatch
    // (previously only the DOUBLE branch was protected).
    if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE)
        throw new ND4JArraySizeException();

    if (X.data().dataType() == DataType.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y);
        dger(order, (int) A.rows(), (int) A.columns(), alpha, X, X.stride(-1), Y, Y.stride(-1), A, (int) A.size(0));
    } else {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X, Y);
        sger(order, (int) A.rows(), (int) A.columns(), (float) alpha, X, X.stride(-1), Y, Y.stride(-1), A, (int) A.size(0));
    }

    OpExecutionerUtil.checkForAny(A);
}
 
示例18
/**
 * Computes {@code y := alpha*x + y} over {@code n} elements (BLAS {@code ?axpy}).
 * <p>
 * The routine is selected by x's data type: DOUBLE to {@code daxpy}, FLOAT to {@code saxpy}, and
 * anything else is validated as HALF and sent to {@code haxpy}.
 *
 * @param n     number of elements to process
 * @param alpha scalar multiplier for {@code x}
 * @param x     input vector
 * @param y     input/output vector
 */
@Override
public void axpy(long n, double alpha, INDArray x, INDArray y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, x, y);
    }

    DataType kind = x.data().dataType();
    if (kind == DataType.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x, y);
        daxpy(n, alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
        return;
    }
    if (kind == DataType.FLOAT) {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, x, y);
        saxpy(n, (float) alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
        return;
    }
    DefaultOpExecutioner.validateDataType(DataType.HALF, x, y);
    haxpy(n, (float) alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
}
 
示例19
/**
 * trsv solves a system of linear equations whose coefficients are in a triangular matrix
 * (BLAS {@code ?trsv}); the solution overwrites {@code X}.
 * <p>
 * The BLAS argument n is the order of the n-by-n matrix A, which equals the length of x. It is
 * taken from {@code X.length()}, as in the other level-2 routines. Previously {@code A.length()}
 * (n*n) was passed here. The leading dimension comes from {@code A.size(0)}.
 *
 * @param order  BLAS order flag, passed through to the native call
 * @param Uplo   BLAS upper/lower flag, passed through to the native call
 * @param TransA BLAS transpose flag, passed through to the native call
 * @param Diag   BLAS unit/non-unit diagonal flag, passed through to the native call
 * @param A      triangular matrix
 * @param X      right-hand side on input, solution on output
 * @throws ND4JArraySizeException if {@code X.length()} or {@code A.size(0)} exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void trsv(char order, char Uplo, char TransA, char Diag, INDArray A, INDArray X) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, A, X);

    if (X.length() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE)
        throw new ND4JArraySizeException();

    if (X.data().dataType() == DataType.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X);
        dtrsv(order, Uplo, TransA, Diag, (int) X.length(), A, (int) A.size(0), X, X.stride(-1));
    } else {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X);
        strsv(order, Uplo, TransA, Diag, (int) X.length(), A, (int) A.size(0), X, X.stride(-1));
    }

    OpExecutionerUtil.checkForAny(X);
}
 
示例20
/**
 * syr performs a rank-1 update of an n-by-n symmetric matrix a:
 * {@code a := alpha*x*x' + a}.
 * <p>
 * n is taken from {@code X.length()}, and the leading dimension of A from {@code A.size(0)}.
 *
 * @param order BLAS order flag, passed through to the native call
 * @param Uplo  BLAS upper/lower flag, passed through to the native call
 * @param N     currently ignored; n is derived from {@code X.length()}
 * @param alpha scalar multiplier
 * @param X     input vector
 * @param A     symmetric matrix, updated in place
 * @throws IllegalArgumentException if {@code X.length()} or {@code A.size(0)} exceeds
 *         {@code Integer.MAX_VALUE}
 */
@Override
public void syr(char order, char Uplo, int N, double alpha, INDArray X, INDArray A) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, A, X);

    // The native routines take int sizes; fail fast instead of silently truncating the casts below.
    if (X.length() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("syr: array dimensions exceed Integer.MAX_VALUE (X.length() = "
                        + X.length() + ", A.size(0) = " + A.size(0) + ")");
    }

    if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, X);
        dsyr(order, Uplo, (int) X.length(), alpha, X, X.majorStride(), A, (int) A.size(0));
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, X);
        ssyr(order, Uplo, (int) X.length(), (float) alpha, X, X.majorStride(), A, (int) A.size(0));
    }

    OpExecutionerUtil.checkForAny(A);
}
 
示例21
/**
 * syrk performs a rank-k update of an n-by-n symmetric matrix c, that is, one of the following operations:
 * {@code c := alpha*a*a' + beta*c}  for trans = 'N'or'n'
 * {@code c := alpha*a'*a + beta*c}  for trans = 'T'or't','C'or'c',
 * where c is an n-by-n symmetric matrix;
 * a is an n-by-k matrix, if trans = 'N'or'n',
 * a is a k-by-n matrix, if trans = 'T'or't','C'or'c'.
 * <p>
 * n is taken from {@code C.rows()}. k is derived from A and {@code Trans} as described above:
 * {@code A.columns()} when not transposed, {@code A.rows()} otherwise. Previously k was hard-coded
 * to 1, so only the first column/row of A contributed.
 *
 * @param Order BLAS order flag, passed through to the native call
 * @param Uplo  BLAS upper/lower flag, passed through to the native call
 * @param Trans BLAS transpose flag for A
 * @param alpha scalar applied to the product
 * @param A     input matrix
 * @param beta  scalar applied to the incoming {@code C}
 * @param C     symmetric input/output matrix
 * @throws IllegalArgumentException if any size handed to BLAS exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void syrk(char Order, char Uplo, char Trans, double alpha, INDArray A, double beta, INDArray C) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, A, C);

    // A is n-by-k for 'N'/'n' and k-by-n otherwise (BLAS ?syrk contract).
    boolean noTranspose = Trans == 'N' || Trans == 'n';
    long k = noTranspose ? A.columns() : A.rows();

    if (C.rows() > Integer.MAX_VALUE || k > Integer.MAX_VALUE ||
        A.size(0) > Integer.MAX_VALUE || C.size(0) > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("syrk: array dimensions exceed Integer.MAX_VALUE");
    }

    if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, C);
        dsyrk(Order, Uplo, Trans, (int) C.rows(), (int) k, alpha, A, (int) A.size(0), beta, C, (int) C.size(0));
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, C);
        ssyrk(Order, Uplo, Trans, (int) C.rows(), (int) k, (float) alpha, A, (int) A.size(0), (float) beta, C, (int) C.size(0));
    }

    OpExecutionerUtil.checkForAny(C);
}
 
示例22
/**
 * ?trsm solves one of the following matrix equations:
 * {@code op(a)*x = alpha*b}  or  {@code x*op(a) = alpha*b},
 * where x and b are m-by-n general matrices, and a is triangular;
 * op(a) must be an m-by-m matrix, if side = 'L'or'l'
 * op(a) must be an n-by-n matrix, if side = 'R'or'r'.
 * For the definition of op(a), see Matrix Arguments.
 * The routine overwrites x on b.
 * <p>
 * Because m and n describe b (not a), they are taken from {@code B.rows()} / {@code B.columns()}.
 *
 * @param Order  BLAS order flag, passed through to the native call
 * @param Side   whether op(A) is applied from the left or the right
 * @param Uplo   BLAS upper/lower flag, passed through to the native call
 * @param TransA BLAS transpose flag for A, passed through to the native call
 * @param Diag   BLAS unit/non-unit diagonal flag, passed through to the native call
 * @param alpha  scalar multiplier for B
 * @param A      triangular matrix
 * @param B      right-hand sides on input, solution on output
 * @throws IllegalArgumentException if any size handed to BLAS exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void trsm(char Order, char Side, char Uplo, char TransA, char Diag, double alpha, INDArray A, INDArray B) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, A, B);

    // The native routines take int sizes; fail fast instead of silently truncating the casts below.
    if (B.rows() > Integer.MAX_VALUE || B.columns() > Integer.MAX_VALUE ||
        A.size(0) > Integer.MAX_VALUE || B.size(0) > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("trsm: array dimensions exceed Integer.MAX_VALUE");
    }

    if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, B);
        dtrsm(Order, Side, Uplo, TransA, Diag, (int) B.rows(), (int) B.columns(), alpha, A, (int) A.size(0), B, (int) B.size(0));
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, B);
        strsm(Order, Side, Uplo, TransA, Diag, (int) B.rows(), (int) B.columns(), (float) alpha, A, (int) A.size(0), B, (int) B.size(0));
    }

    OpExecutionerUtil.checkForAny(B);
}
 
示例23
/**
 * Sets the element at linear index {@code i} to {@code value}.
 * <p>
 * A negative {@code i} is interpreted relative to the end of the array ({@code i + length()}).
 * Scalars, row vectors and column vectors are written directly via {@code addOrUpdate}. Other
 * shapes convert the linear index to per-dimension indices using the array's ordering and
 * delegate to {@code putScalar(long[], double)}.
 *
 * @param i     linear index (may be negative, see above)
 * @param value value to store
 * @return this array
 */
@Override
public INDArray putScalar(long i, double value) {
    // A linear index wraps around the element count, not the rank (previously used rank()).
    if (i < 0)
        i += length();
    if (isScalar()) {
        if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.DISABLED && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.SCOPE_PANIC)
            OpProfiler.getInstance().processScalarCall();

        addOrUpdate(new long[] {0, 0}, value);
        return this;
    }
    if (isRowVector()) {
        addOrUpdate(new long[] {0, i}, value);
        return this;
    } else if (isColumnVector()) {
        addOrUpdate(new long[] {i, 0}, value);
        return this;
    }
    long[] indexes = ordering() == 'c' ? Shape.ind2subC(this, i) : Shape.ind2sub(this, i);
    return putScalar(indexes, value);
}
 
示例24
/**
 * Returns the element at linear index {@code i} as a {@code double}.
 * <p>
 * Index 0 is read straight from the data buffer. Other indices are converted to per-dimension
 * indices using the array's ordering and looked up via {@code getDouble(long[])}.
 *
 * @param i linear index, {@code 0 <= i < length()}
 * @return the element value
 * @throws IllegalArgumentException if {@code i} is negative or {@code >= length()}
 */
@Override
public double getDouble(long i) {
    // Reject negative indices up front rather than passing them into ind2sub.
    if (i < 0) {
        throw new IllegalArgumentException("Unable to get negative linear index " + i);
    }
    if (i >= length()) {
        throw new IllegalArgumentException("Unable to get linear index >= " + length());
    }

    if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.DISABLED && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.SCOPE_PANIC)
        OpProfiler.getInstance().processScalarCall();

    if (i == 0)
        return data().getDouble(i);

    long[] dimensions = ordering() == 'c' ? Shape.ind2subC(this, i) : Shape.ind2sub(this, i);
    Shape.assertShapeLessThan(dimensions, shape());
    return getDouble(dimensions);
}
 
示例25
/**
 * Rank-2 update of an n-by-n symmetric matrix (BLAS {@code ?syr2}):
 * {@code a := alpha*x*y' + alpha*y*x' + a}.
 * n is taken from {@code X.length()} and the leading dimension of A from {@code A.size(0)}.
 *
 * @param order BLAS order flag, passed through to the native call
 * @param Uplo  BLAS upper/lower flag, passed through to the native call
 * @param alpha scalar multiplier
 * @param X     first input vector
 * @param Y     second input vector
 * @param A     symmetric matrix, updated in place
 * @throws ND4JArraySizeException if {@code X.length()} or {@code A.size(0)} exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void syr2(char order, char Uplo, double alpha, INDArray X, INDArray Y, INDArray A) {
    boolean profileEverything = Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    boolean exceedsInt = X.length() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE;
    if (exceedsInt) {
        throw new ND4JArraySizeException();
    }

    if (X.data().dataType() != DataType.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X, Y);
        ssyr2(order, Uplo, (int) X.length(), (float) alpha, X, X.stride(-1), Y, Y.stride(-1), A, (int) A.size(0));
    } else {
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y);
        dsyr2(order, Uplo, (int) X.length(), alpha, X, X.stride(-1), Y, Y.stride(-1), A, (int) A.size(0));
    }

    OpExecutionerUtil.checkForAny(A);
}
 
示例26
/**
 * gemv computes a matrix-vector product using a general matrix and performs one of the following matrix-vector operations:
 * {@code y := alpha*a*x + beta*y}  for trans = 'N'or'n';
 * {@code y := alpha*a'*x + beta*y}  for trans = 'T'or't';
 * {@code y := alpha*conjg(a')*x + beta*y}  for trans = 'C'or'c'.
 * Here a is an m-by-n general (not band) matrix, x and y are vectors, alpha and beta are scalars.
 * <p>
 * m, n, the leading dimension and the vector increments come from {@link GemvParameters}. DOUBLE
 * data goes to {@code zgemv}; anything else goes to {@code cgemv}.
 *
 * @param order  BLAS order flag, passed through to the native call
 * @param transA BLAS transpose flag for A, passed through to the native call
 * @param alpha  complex scalar applied to {@code op(A)*X}
 * @param A      general matrix
 * @param X      input vector
 * @param beta   complex scalar applied to the incoming {@code Y}
 * @param Y      input/output vector
 */
@Override
public void gemv(char order, char transA, IComplexNumber alpha, IComplexNDArray A, IComplexNDArray X,
                IComplexNumber beta, IComplexNDArray Y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);

    GemvParameters parameters = new GemvParameters(A, X, Y);

    if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
        zgemv(order, transA, parameters.getM(), parameters.getN(), alpha.asDouble(), A, parameters.getLda(), X,
                        parameters.getIncx(), beta.asDouble(), Y, parameters.getIncy());
    } else {
        cgemv(order, transA, parameters.getM(), parameters.getN(), alpha.asFloat(), A, parameters.getLda(), X,
                        parameters.getIncx(), beta.asFloat(), Y, parameters.getIncy());
    }
}
 
示例27
/**
 * Copies the contents of vector {@code x} into vector {@code y} (BLAS {@code ?copy}).
 * <p>
 * If either operand is sparse the work is delegated to the sparse BLAS wrapper. Dense DOUBLE data
 * goes to {@code dcopy}; any other dense data is validated as FLOAT and sent to {@code scopy}.
 *
 * @param x source vector
 * @param y destination vector
 */
@Override
public void copy(INDArray x, INDArray y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, x, y);

    if (x.isSparse() || y.isSparse()) {
        Nd4j.getSparseBlasWrapper().level1().copy(x, y);
        return;
    }
    if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, x, y);
        dcopy(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, x, y);
        scopy(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
    }
}
 
示例28
/**
 * performs a rank-1 update of a general m-by-n matrix a:
 * {@code a := alpha*x*y' + a}.
 * <p>
 * m and n are taken from {@code A.rows()} / {@code A.columns()} and the leading dimension from
 * {@code A.size(0)}.
 *
 * @param order BLAS order flag, passed through to the native call
 * @param alpha scalar multiplier
 * @param X     m-element vector
 * @param Y     n-element vector
 * @param A     matrix updated in place
 * @throws IllegalArgumentException if a dimension of {@code A} exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void ger(char order, double alpha, INDArray X, INDArray Y, INDArray A) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);

    // The native routines take int sizes; fail fast instead of silently truncating the casts below.
    if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("ger: dimensions of A exceed Integer.MAX_VALUE");
    }

    if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, X, Y);
        dger(order, (int) A.rows(), (int) A.columns(), alpha, X, X.majorStride(), Y, Y.majorStride(), A, (int) A.size(0));
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, X, Y);
        sger(order, (int) A.rows(), (int) A.columns(), (float) alpha, X, X.majorStride(), Y, Y.majorStride(), A, (int) A.size(0));
    }

    OpExecutionerUtil.checkForAny(A);
}
 
示例29
/**
 * sbmv computes a matrix-vector product using a symmetric band matrix:
 * {@code y := alpha*a*x + beta*y}.
 * Here a is an n-by-n symmetric band matrix with k superdiagonals, x and y are n-element vectors, alpha and beta are scalars.
 * <p>
 * The native call receives {@code X.length()} as n, {@code A.columns()} as k and {@code A.size(0)}
 * as the leading dimension.
 *
 * @param order BLAS order flag, passed through to the native call
 * @param Uplo  BLAS upper/lower flag, passed through to the native call
 * @param alpha scalar applied to {@code A*X}
 * @param A     band matrix
 * @param X     input vector
 * @param beta  scalar applied to the incoming {@code Y}
 * @param Y     input/output vector
 * @throws IllegalArgumentException if any size handed to BLAS exceeds {@code Integer.MAX_VALUE}
 */
@Override
public void sbmv(char order, char Uplo, double alpha, INDArray A, INDArray X, double beta, INDArray Y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL)
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);

    // The native routines take int sizes; fail fast instead of silently truncating the casts below.
    if (X.length() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("sbmv: array dimensions exceed Integer.MAX_VALUE");
    }

    if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, X, Y);
        dsbmv(order, Uplo, (int) X.length(), (int) A.columns(), alpha, A, (int) A.size(0), X, X.majorStride(), beta, Y,
                (int) Y.majorStride());
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, X, Y);
        ssbmv(order, Uplo, (int) X.length(), (int) A.columns(), (float) alpha, A, (int) A.size(0), X, X.majorStride(), (float) beta,
                        Y, Y.majorStride());
    }

    OpExecutionerUtil.checkForAny(Y);
}
 
示例30
/**
 * Builds a tiny SameDiff graph ({@code sum(ones(4) + 1)}), converts it to FlatBuffers with
 * profiling disabled, and writes the serialized bytes to
 * {@code ../../libnd4j/tests/resources/adam_sum.fb}.
 * NOTE(review): the output path is relative to the working directory; the test is marked @Ignore.
 */
@Test
@Ignore
public void testConversion() throws Exception {
    SameDiff graph = SameDiff.create();
    INDArray ones = Nd4j.ones(4);
    SDVariable onesVar = graph.var("ones", ones);
    SDVariable plusOne = onesVar.addi(1.0);
    // Adds the full reduction to the graph; the returned variable is not needed afterwards.
    graph.sum(plusOne, Integer.MAX_VALUE);

    val config = ExecutorConfiguration.builder()
                    .profilingMode(OpExecutioner.ProfilingMode.DISABLED)
                    .executionMode(ExecutionMode.SEQUENTIAL)
                    .outputMode(OutputMode.IMPLICIT)
                    .build();

    val graphExecutioner = new NativeGraphExecutioner();
    ByteBuffer flat = graphExecutioner.convertToFlatBuffers(graph, config);

    val start = flat.position();
    val bytes = flat.array();

    try (val fileOut = new FileOutputStream("../../libnd4j/tests/resources/adam_sum.fb");
         val dataOut = new DataOutputStream(fileOut)) {
        dataOut.write(bytes, start, bytes.length - start);
    }

    // Disabled verification, kept for reference:
    //INDArray[] res = executioner.executeGraph(sameDiff);
    //assertEquals(8.0, res[0].getDouble(0), 1e-5);
    /*
    INDArray output = null;
    for(int i = 0; i < 5; i++) {
        output = sameDiff.execAndEndResult(ops);
        System.out.println("Ones " + ones);
        System.out.println(output);
    }

    assertEquals(Nd4j.valueArrayOf(4,7),ones);
    assertEquals(28,output.getDouble(0),1e-1);
    */
}