Java源码示例:org.nd4j.linalg.api.ops.executioner.OpExecutioner
示例1
/**
 * Exchanges the contents of two vectors.
 * <p>
 * When either operand is sparse the call is handed to the sparse BLAS wrapper. Otherwise the
 * double-precision routine runs if {@code x} is backed by a DOUBLE buffer and the
 * single-precision routine runs in every other case.
 *
 * @param x first vector; its buffer type selects the routine
 * @param y second vector
 */
@Override
public void swap(INDArray x, INDArray y) {
    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, x, y);
    }

    final boolean bothDense = !x.isSparse() && !y.isSparse();
    if (!bothDense) {
        Nd4j.getSparseBlasWrapper().level1().swap(x, y);
        return;
    }

    final boolean useDouble = x.data().dataType() == DataBuffer.Type.DOUBLE;
    // Validate against the type of the routine selected below.
    DefaultOpExecutioner.validateDataType(useDouble ? DataBuffer.Type.DOUBLE : DataBuffer.Type.FLOAT, x, y);
    if (useDouble) {
        dswap(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
    } else {
        sswap(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
    }
}
示例2
/**
 * symv computes a matrix-vector product for a symmetric matrix:
 * y := alpha*a*x + beta*y.
 * Here a is an n-by-n symmetric matrix; x and y are n-element vectors, alpha and beta are scalars.
 * <p>
 * dsymv is used when {@code X} is backed by a DOUBLE buffer; otherwise ssymv is used and the
 * scalars are narrowed to float.
 *
 * @param order layout flag, forwarded unchanged to the BLAS routine
 * @param Uplo  BLAS uplo flag, forwarded unchanged
 * @param alpha scalar applied to the product
 * @param A     the symmetric matrix; {@code A.size(0)} is passed as the leading dimension
 * @param X     input vector; its length is passed as n
 * @param beta  scalar applied to {@code Y}
 * @param Y     vector passed to the routine as y; checked via {@code OpExecutionerUtil.checkForAny}
 * @throws IllegalArgumentException if the length of {@code X} or the first dimension of
 *                                  {@code A} does not fit in an int
 */
@Override
public void symv(char order, char Uplo, double alpha, INDArray A, INDArray X, double beta, INDArray Y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    // The BLAS entry points take int sizes; fail loudly instead of truncating a long.
    if (X.length() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("symv: array dimensions exceed Integer.MAX_VALUE");
    }
    final int n = (int) X.length();
    final int lda = (int) A.size(0);

    if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, X, Y);
        dsymv(order, Uplo, n, alpha, A, lda, X, X.majorStride(), beta, Y, Y.majorStride());
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, X, Y);
        ssymv(order, Uplo, n, (float) alpha, A, lda, X, X.majorStride(), (float) beta, Y, Y.majorStride());
    }
    OpExecutionerUtil.checkForAny(Y);
}
示例3
/**
 * tpsv solves a system of linear equations whose coefficients are in a triangular packed matrix.
 * <p>
 * dtpsv is used when {@code X} is backed by a DOUBLE buffer, stpsv otherwise.
 *
 * @param order  layout flag, forwarded unchanged to the BLAS routine
 * @param Uplo   BLAS uplo flag, forwarded unchanged
 * @param TransA BLAS transpose flag, forwarded unchanged
 * @param Diag   BLAS diag flag, forwarded unchanged
 * @param Ap     the packed triangular matrix
 * @param X      the vector passed to the routine; its length is passed as n and it is
 *               checked afterwards via {@code OpExecutionerUtil.checkForAny}
 * @throws IllegalArgumentException if the length of {@code X} does not fit in an int
 */
@Override
public void tpsv(char order, char Uplo, char TransA, char Diag, INDArray Ap, INDArray X) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, Ap, X);
    }

    // The BLAS entry points take an int n; fail loudly instead of truncating a long.
    if (X.length() > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("tpsv: length of X exceeds Integer.MAX_VALUE");
    }
    final int n = (int) X.length();

    if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, X, Ap);
        dtpsv(order, Uplo, TransA, Diag, n, Ap, X, X.majorStride());
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, Ap, X);
        stpsv(order, Uplo, TransA, Diag, n, Ap, X, X.majorStride());
    }
    OpExecutionerUtil.checkForAny(X);
}
示例4
/**
 * gbmv computes a matrix-vector product using a general band matrix and performs one of the following matrix-vector operations:
 * y := alpha*a*x + beta*y for trans = 'N'or'n';
 * y := alpha*a'*x + beta*y for trans = 'T'or't';
 * y := alpha*conjg(a')*x + beta*y for trans = 'C'or'c'.
 * Here a is an m-by-n band matrix with ku superdiagonals and kl subdiagonals, x and y are vectors, alpha and beta are scalars.
 * <p>
 * dgbmv is used when {@code A} is backed by a DOUBLE buffer; otherwise sgbmv is used and the
 * scalars are narrowed to float. The size check is applied on both paths.
 *
 * @param order  layout flag, forwarded unchanged to the BLAS routine
 * @param TransA BLAS transpose flag, forwarded unchanged
 * @param KL     number of subdiagonals
 * @param KU     number of superdiagonals
 * @param alpha  scalar applied to the product
 * @param A      the band matrix; rows, columns and {@code size(0)} are passed as m, n and lda
 * @param X      input vector
 * @param beta   scalar applied to {@code Y}
 * @param Y      vector passed to the routine as y; checked via {@code OpExecutionerUtil.checkForAny}
 * @throws ND4JArraySizeException if a dimension of {@code A} does not fit in an int
 */
@Override
public void gbmv(char order, char TransA, int KL, int KU, double alpha, INDArray A, INDArray X, double beta,
                INDArray Y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    // Guard the int narrowing for BOTH precisions; previously only the double path was checked.
    if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) {
        throw new ND4JArraySizeException();
    }
    final int m = (int) A.rows();
    final int n = (int) A.columns();
    final int lda = (int) A.size(0);

    if (A.data().dataType() == DataType.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y);
        dgbmv(order, TransA, m, n, KL, KU, alpha, A, lda, X, X.stride(-1), beta, Y, Y.stride(-1));
    } else {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X, Y);
        sgbmv(order, TransA, m, n, KL, KU, (float) alpha, A, lda, X, X.stride(-1), (float) beta, Y,
                        Y.stride(-1));
    }
    OpExecutionerUtil.checkForAny(Y);
}
示例5
/**
 * trmm: triangular matrix-matrix product, delegated to the BLAS dtrmm / strmm routine.
 * <p>
 * dtrmm is used when {@code A} is backed by a DOUBLE buffer; otherwise strmm is used and
 * {@code alpha} is narrowed to float. The rows and columns of {@code A} are passed as the
 * routine's m and n, and {@code size(0)} of {@code A} and {@code B} as the leading dimensions.
 * <p>
 * NOTE(review): {@code C} is validated and checked after the call but is never handed to the
 * BLAS routine itself - confirm which array callers expect to hold the result.
 *
 * @param Order  layout flag, forwarded unchanged to the BLAS routine
 * @param Side   BLAS side flag, forwarded unchanged
 * @param Uplo   BLAS uplo flag, forwarded unchanged
 * @param TransA BLAS transpose flag, forwarded unchanged
 * @param Diag   BLAS diag flag, forwarded unchanged
 * @param alpha  scalar multiplier
 * @param A      the triangular matrix; its buffer type selects the routine
 * @param B      the second matrix operand
 * @param C      array that is type-validated and checked via {@code OpExecutionerUtil.checkForAny}
 * @throws ND4JArraySizeException if a dimension that must be narrowed does not fit in an int
 */
@Override
public void trmm(char Order, char Side, char Uplo, char TransA, char Diag, double alpha, INDArray A, INDArray B,
                INDArray C) {
    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, A, B, C);
    }

    final boolean sizesFitInt = A.rows() <= Integer.MAX_VALUE
            && A.columns() <= Integer.MAX_VALUE
            && A.size(0) <= Integer.MAX_VALUE
            && B.size(0) <= Integer.MAX_VALUE;
    if (!sizesFitInt) {
        throw new ND4JArraySizeException();
    }

    final boolean useDouble = A.data().dataType() == DataType.DOUBLE;
    DefaultOpExecutioner.validateDataType(useDouble ? DataType.DOUBLE : DataType.FLOAT, A, B, C);

    final int m = (int) A.rows();
    final int n = (int) A.columns();
    final int lda = (int) A.size(0);
    final int ldb = (int) B.size(0);
    if (useDouble) {
        dtrmm(Order, Side, Uplo, TransA, Diag, m, n, alpha, A, lda, B, ldb);
    } else {
        strmm(Order, Side, Uplo, TransA, Diag, m, n, (float) alpha, A, lda, B, ldb);
    }
    OpExecutionerUtil.checkForAny(C);
}
示例6
/**
 * Computes the sum of magnitudes of all vector elements.
 * <p>
 * Sparse input is delegated to the sparse BLAS wrapper before any profiling takes place.
 * Dense input is dispatched on its buffer type: DOUBLE to dasum, FLOAT to sasum, and every
 * other type to hasum.
 *
 * @param arr the vector to reduce
 * @return the value returned by the selected asum routine
 */
@Override
public double asum(INDArray arr) {
    if (arr.isSparse()) {
        return Nd4j.getSparseBlasWrapper().level1().asum(arr);
    }

    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, arr);
    }

    final DataBuffer.Type bufferType = arr.data().dataType();
    if (bufferType == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, arr);
        return dasum(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    }
    if (bufferType == DataBuffer.Type.FLOAT) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, arr);
        return sasum(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
    }
    // Anything that is neither DOUBLE nor FLOAT is treated as half precision.
    DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, arr);
    return hasum(arr.length(), arr, BlasBufferUtil.getBlasStride(arr));
}
示例7
/**
 * spr performs a rank-1 update of an n-by-n packed symmetric matrix a:
 * a := alpha*x*x' + a.
 * <p>
 * dspr is used when {@code X} is backed by a DOUBLE buffer; otherwise sspr is used and
 * {@code alpha} is narrowed to float. Only {@code X} is type-validated.
 *
 * @param order layout flag, forwarded unchanged to the BLAS routine
 * @param Uplo  BLAS uplo flag, forwarded unchanged
 * @param alpha scalar multiplier
 * @param X     the vector x; its length is passed as n
 * @param Ap    the packed matrix; checked afterwards via {@code OpExecutionerUtil.checkForAny}
 * @throws ND4JArraySizeException if the length of {@code X} does not fit in an int
 */
@Override
public void spr(char order, char Uplo, double alpha, INDArray X, INDArray Ap) {
    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, Ap, X);
    }

    final boolean lengthFitsInt = X.length() <= Integer.MAX_VALUE;
    if (!lengthFitsInt) {
        throw new ND4JArraySizeException();
    }

    final boolean useDouble = X.data().dataType() == DataType.DOUBLE;
    DefaultOpExecutioner.validateDataType(useDouble ? DataType.DOUBLE : DataType.FLOAT, X);

    final int n = (int) X.length();
    if (useDouble) {
        dspr(order, Uplo, n, alpha, X, X.stride(-1), Ap);
    } else {
        sspr(order, Uplo, n, (float) alpha, X, X.stride(-1), Ap);
    }
    OpExecutionerUtil.checkForAny(Ap);
}
示例8
/**
 * syr2: symmetric rank-2 update, delegated to the BLAS dsyr2 / ssyr2 routine.
 * <p>
 * dsyr2 is used when {@code X} is backed by a DOUBLE buffer; otherwise ssyr2 is used and
 * {@code alpha} is narrowed to float.
 *
 * @param order layout flag, forwarded unchanged to the BLAS routine
 * @param Uplo  BLAS uplo flag, forwarded unchanged
 * @param alpha scalar multiplier
 * @param X     first vector; its length is passed as n
 * @param Y     second vector
 * @param A     the matrix passed to the routine; {@code A.size(0)} is passed as the leading
 *              dimension and it is checked afterwards via {@code OpExecutionerUtil.checkForAny}
 * @throws IllegalArgumentException if the length of {@code X} or the first dimension of
 *                                  {@code A} does not fit in an int
 */
@Override
public void syr2(char order, char Uplo, double alpha, INDArray X, INDArray Y, INDArray A) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    // The BLAS entry points take int sizes; fail loudly instead of truncating a long.
    if (X.length() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("syr2: array dimensions exceed Integer.MAX_VALUE");
    }
    final int n = (int) X.length();
    final int lda = (int) A.size(0);

    if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, X, Y);
        dsyr2(order, Uplo, n, alpha, X, X.majorStride(), Y, Y.majorStride(), A, lda);
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, X, Y);
        ssyr2(order, Uplo, n, (float) alpha, X, X.majorStride(), Y, Y.majorStride(), A, lda);
    }
    OpExecutionerUtil.checkForAny(A);
}
示例9
/**
 * gemm performs a matrix-matrix operation
 * c := alpha*op(a)*op(b) + beta*c,
 * where c is an m-by-n matrix,
 * op(a) is an m-by-k matrix,
 * op(b) is a k-by-n matrix.
 * <p>
 * The routine is selected from the buffer type of {@code A}: DOUBLE uses dgemm, FLOAT uses
 * sgemm, anything else uses hgemm (the scalars are narrowed to float for the latter two).
 * The transpose flags, dimensions, leading dimensions and the A/B operands handed to BLAS
 * all come from {@link GemmParams}; the {@code TransA}/{@code TransB} arguments of this
 * method are not consulted.
 *
 * @param Order  layout flag, forwarded unchanged to the BLAS routine
 * @param TransA not used; the flag from {@code GemmParams} is passed instead
 * @param TransB not used; the flag from {@code GemmParams} is passed instead
 * @param alpha  scalar applied to the product op(a)*op(b)
 * @param A      left operand; its buffer type selects the routine
 * @param B      right operand
 * @param beta   scalar applied to the existing contents of {@code C}
 * @param C      matrix passed to the routine as c; checked via {@code OpExecutionerUtil.checkForAny}
 */
@Override
public void gemm(char Order, char TransA, char TransB, double alpha, INDArray A, INDArray B, double beta,
                INDArray C) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(true, A, B, C);
    }

    GemmParams params = new GemmParams(A, B, C);

    // alpha and beta are forwarded as supplied; they used to be replaced by the constants 1 and 0,
    // which silently ignored the caller's scaling factors.
    if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, params.getA(), params.getB(), params.getC());
        dgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(), alpha,
                        params.getA(), params.getLda(), params.getB(), params.getLdb(), beta, C, params.getLdc());
    } else if (A.data().dataType() == DataBuffer.Type.FLOAT) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, params.getA(), params.getB(), params.getC());
        sgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(),
                        (float) alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta,
                        C, params.getLdc());
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, params.getA(), params.getB(), params.getC());
        hgemm(Order, params.getTransA(), params.getTransB(), params.getM(), params.getN(), params.getK(),
                        (float) alpha, params.getA(), params.getLda(), params.getB(), params.getLdb(), (float) beta,
                        C, params.getLdc());
    }
    OpExecutionerUtil.checkForAny(C);
}
示例10
/**
 * sbmv computes a matrix-vector product using a symmetric band matrix:
 * y := alpha*a*x + beta*y.
 * Here a is an n-by-n symmetric band matrix with k superdiagonals, x and y are n-element vectors, alpha and beta are scalars.
 * <p>
 * dsbmv is used when {@code X} is backed by a DOUBLE buffer; otherwise ssbmv is used and the
 * scalars are narrowed to float. The length of {@code X} is passed as n and the column count
 * of {@code A} as k.
 *
 * @param order layout flag, forwarded unchanged to the BLAS routine
 * @param Uplo  BLAS uplo flag, forwarded unchanged
 * @param alpha scalar applied to the product
 * @param A     the band matrix; {@code A.size(0)} is passed as the leading dimension
 * @param X     input vector; its buffer type selects the routine
 * @param beta  scalar applied to {@code Y}
 * @param Y     vector passed to the routine as y; checked via {@code OpExecutionerUtil.checkForAny}
 * @throws ND4JArraySizeException if a size that must be narrowed does not fit in an int
 */
@Override
public void sbmv(char order, char Uplo, double alpha, INDArray A, INDArray X, double beta, INDArray Y) {
    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    final boolean sizesFitInt = X.length() <= Integer.MAX_VALUE
            && A.columns() <= Integer.MAX_VALUE
            && A.size(0) <= Integer.MAX_VALUE;
    if (!sizesFitInt) {
        throw new ND4JArraySizeException();
    }

    final boolean useDouble = X.data().dataType() == DataType.DOUBLE;
    DefaultOpExecutioner.validateDataType(useDouble ? DataType.DOUBLE : DataType.FLOAT, A, X, Y);

    final int n = (int) X.length();
    final int k = (int) A.columns();
    final int lda = (int) A.size(0);
    if (useDouble) {
        dsbmv(order, Uplo, n, k, alpha, A, lda, X, X.stride(-1), beta, Y, Y.stride(-1));
    } else {
        ssbmv(order, Uplo, n, k, (float) alpha, A, lda, X, X.stride(-1), (float) beta, Y, Y.stride(-1));
    }
    OpExecutionerUtil.checkForAny(Y);
}
示例11
/**
 * symm: matrix-matrix product with a symmetric operand, delegated to the BLAS
 * dsymm / ssymm routine.
 * <p>
 * dsymm is used when {@code A} is backed by a DOUBLE buffer; otherwise ssymm is used and the
 * scalars are narrowed to float. The rows and columns of {@code C} are passed as the routine's
 * m and n, and {@code size(0)} of each matrix as its leading dimension.
 *
 * @param Order layout flag, forwarded unchanged to the BLAS routine
 * @param Side  BLAS side flag, forwarded unchanged
 * @param Uplo  BLAS uplo flag, forwarded unchanged
 * @param alpha scalar applied to the product
 * @param A     first matrix operand; its buffer type selects the routine
 * @param B     second matrix operand
 * @param beta  scalar applied to {@code C}
 * @param C     matrix passed to the routine as c; checked via {@code OpExecutionerUtil.checkForAny}
 * @throws ND4JArraySizeException if a dimension that must be narrowed does not fit in an int
 */
@Override
public void symm(char Order, char Side, char Uplo, double alpha, INDArray A, INDArray B, double beta, INDArray C) {
    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, A, B, C);
    }

    final boolean sizesFitInt = C.rows() <= Integer.MAX_VALUE
            && C.columns() <= Integer.MAX_VALUE
            && A.size(0) <= Integer.MAX_VALUE
            && B.size(0) <= Integer.MAX_VALUE
            && C.size(0) <= Integer.MAX_VALUE;
    if (!sizesFitInt) {
        throw new ND4JArraySizeException();
    }

    final boolean useDouble = A.data().dataType() == DataType.DOUBLE;
    DefaultOpExecutioner.validateDataType(useDouble ? DataType.DOUBLE : DataType.FLOAT, A, B, C);

    final int m = (int) C.rows();
    final int n = (int) C.columns();
    final int lda = (int) A.size(0);
    final int ldb = (int) B.size(0);
    if (useDouble) {
        dsymm(Order, Side, Uplo, m, n, alpha, A, lda, B, ldb, beta, C, (int) C.size(0));
    } else {
        ssymm(Order, Side, Uplo, m, n, (float) alpha, A, lda, B, ldb, (float) beta, C, (int) C.size(0));
    }
    OpExecutionerUtil.checkForAny(C);
}
示例12
/**
 * tpsv solves a system of linear equations whose coefficients are in a triangular packed matrix.
 * <p>
 * dtpsv is used when {@code X} is backed by a DOUBLE buffer, stpsv otherwise.
 *
 * @param order  layout flag, forwarded unchanged to the BLAS routine
 * @param Uplo   BLAS uplo flag, forwarded unchanged
 * @param TransA BLAS transpose flag, forwarded unchanged
 * @param Diag   BLAS diag flag, forwarded unchanged
 * @param Ap     the packed triangular matrix
 * @param X      the vector passed to the routine; its length is passed as n and it is
 *               checked afterwards via {@code OpExecutionerUtil.checkForAny}
 * @throws ND4JArraySizeException if the length of {@code X} does not fit in an int
 */
@Override
public void tpsv(char order, char Uplo, char TransA, char Diag, INDArray Ap, INDArray X) {
    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, Ap, X);
    }

    final boolean lengthFitsInt = X.length() <= Integer.MAX_VALUE;
    if (!lengthFitsInt) {
        throw new ND4JArraySizeException();
    }
    final int n = (int) X.length();

    final boolean useDouble = X.data().dataType() == DataType.DOUBLE;
    if (useDouble) {
        // Operand order for validation intentionally mirrors the original: X first on this path.
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, X, Ap);
        dtpsv(order, Uplo, TransA, Diag, n, Ap, X, X.stride(-1));
    } else {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, Ap, X);
        stpsv(order, Uplo, TransA, Diag, n, Ap, X, X.stride(-1));
    }
    OpExecutionerUtil.checkForAny(X);
}
示例13
/**
 * Computes a vector-scalar product and adds the result to a vector.
 * <p>
 * A sparse {@code x} combined with a non-sparse {@code y} is delegated to the sparse BLAS
 * wrapper. Every other combination is dispatched on the buffer type of {@code x}: DOUBLE to
 * daxpy, FLOAT to saxpy, anything else to haxpy ({@code alpha} is narrowed to float for the
 * latter two).
 *
 * @param n     element count, forwarded unchanged to the selected routine
 * @param alpha scalar multiplier
 * @param x     the vector that is scaled
 * @param y     the vector passed to the routine as y
 */
@Override
public void axpy(long n, double alpha, INDArray x, INDArray y) {
    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, x, y);
    }

    if (x.isSparse() && !y.isSparse()) {
        Nd4j.getSparseBlasWrapper().level1().axpy(n, alpha, x, y);
        return;
    }

    final DataBuffer.Type bufferType = x.data().dataType();
    if (bufferType == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, x, y);
        daxpy(n, alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
        return;
    }
    if (bufferType == DataBuffer.Type.FLOAT) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, x, y);
        saxpy(n, (float) alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
        return;
    }
    // Anything that is neither DOUBLE nor FLOAT is treated as half precision.
    DefaultOpExecutioner.validateDataType(DataBuffer.Type.HALF, x, y);
    haxpy(n, (float) alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
}
示例14
/**
 * ?spr2 performs a rank-2 update of an n-by-n packed symmetric matrix a:
 * a := alpha*x*y' + alpha*y*x' + a.
 * <p>
 * dspr2 is used when {@code X} is backed by a DOUBLE buffer; otherwise sspr2 is used and
 * {@code alpha} is narrowed to float.
 *
 * @param order layout flag, forwarded unchanged to the BLAS routine
 * @param Uplo  BLAS uplo flag, forwarded unchanged
 * @param alpha scalar multiplier
 * @param X     the vector x; its length is passed as n
 * @param Y     the vector y
 * @param A     the packed matrix; checked afterwards via {@code OpExecutionerUtil.checkForAny}
 * @throws IllegalArgumentException if the length of {@code X} does not fit in an int
 */
@Override
public void spr2(char order, char Uplo, double alpha, INDArray X, INDArray Y, INDArray A) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    // The BLAS entry points take an int n; fail loudly instead of truncating a long.
    if (X.length() > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("spr2: length of X exceeds Integer.MAX_VALUE");
    }
    final int n = (int) X.length();

    if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, X, Y);
        dspr2(order, Uplo, n, alpha, X, X.majorStride(), Y, Y.majorStride(), A);
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, X, Y);
        sspr2(order, Uplo, n, (float) alpha, X, X.majorStride(), Y, Y.majorStride(), A);
    }
    OpExecutionerUtil.checkForAny(A);
}
示例15
/**
 * ?spr2 performs a rank-2 update of an n-by-n packed symmetric matrix a:
 * a := alpha*x*y' + alpha*y*x' + a.
 * <p>
 * dspr2 is used when {@code X} is backed by a DOUBLE buffer; otherwise sspr2 is used and
 * {@code alpha} is narrowed to float.
 *
 * @param order layout flag, forwarded unchanged to the BLAS routine
 * @param Uplo  BLAS uplo flag, forwarded unchanged
 * @param alpha scalar multiplier
 * @param X     the vector x; its length is passed as n
 * @param Y     the vector y
 * @param A     the packed matrix; checked afterwards via {@code OpExecutionerUtil.checkForAny}
 * @throws ND4JArraySizeException if the length of {@code X} does not fit in an int
 */
@Override
public void spr2(char order, char Uplo, double alpha, INDArray X, INDArray Y, INDArray A) {
    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    final boolean lengthFitsInt = X.length() <= Integer.MAX_VALUE;
    if (!lengthFitsInt) {
        throw new ND4JArraySizeException();
    }

    final boolean useDouble = X.data().dataType() == DataType.DOUBLE;
    DefaultOpExecutioner.validateDataType(useDouble ? DataType.DOUBLE : DataType.FLOAT, A, X, Y);

    final int n = (int) X.length();
    if (useDouble) {
        dspr2(order, Uplo, n, alpha, X, X.stride(-1), Y, Y.stride(-1), A);
    } else {
        sspr2(order, Uplo, n, (float) alpha, X, X.stride(-1), Y, Y.stride(-1), A);
    }
    OpExecutionerUtil.checkForAny(A);
}
示例16
/**
 * Complex gemm: performs a matrix-matrix operation
 * c := alpha*op(a)*op(b) + beta*c,
 * where c is an m-by-n matrix,
 * op(a) is an m-by-k matrix,
 * op(b) is a k-by-n matrix.
 * <p>
 * zgemm is used when {@code A} is backed by a DOUBLE buffer, cgemm otherwise. Dimensions and
 * leading dimensions come from {@link GemmParams}. The array handed over in the first operand
 * slot is {@code B} when {@code A} has C ordering (otherwise {@code A}); the array in the
 * second slot is {@code A} when {@code B} has C ordering (otherwise {@code B}).
 *
 * @param Order  layout flag, forwarded unchanged to the BLAS routine
 * @param TransA BLAS transpose flag for the first operand, forwarded unchanged
 * @param TransB BLAS transpose flag for the second operand, forwarded unchanged
 * @param alpha  complex scalar applied to the product
 * @param A      first matrix operand; its buffer type selects the routine
 * @param B      second matrix operand
 * @param beta   complex scalar applied to {@code C}
 * @param C      matrix passed to the routine as c
 */
@Override
public void gemm(char Order, char TransA, char TransB, IComplexNumber alpha, IComplexNDArray A, IComplexNDArray B,
                IComplexNumber beta, IComplexNDArray C) {
    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(true, A, B, C);
    }

    GemmParams params = new GemmParams(A, B, C);

    // Operand selection depends on each array's own ordering, exactly as in the inline form.
    final IComplexNDArray firstOperand = A.ordering() == NDArrayFactory.C ? B : A;
    final IComplexNDArray secondOperand = B.ordering() == NDArrayFactory.C ? A : B;

    final boolean singlePrecision = A.data().dataType() != DataBuffer.Type.DOUBLE;
    if (singlePrecision) {
        cgemm(Order, TransA, TransB, params.getM(), params.getN(), params.getK(), alpha.asFloat(),
                        firstOperand, params.getLda(), secondOperand, params.getLdb(), beta.asFloat(), C,
                        params.getLdc());
    } else {
        zgemm(Order, TransA, TransB, params.getM(), params.getN(), params.getK(), alpha.asDouble(),
                        firstOperand, params.getLda(), secondOperand, params.getLdb(), beta.asDouble(), C,
                        params.getLdc());
    }
}
示例17
/**
 * performs a rank-1 update of a general m-by-n matrix a:
 * a := alpha*x*y' + a.
 * <p>
 * dger is used when {@code X} is backed by a DOUBLE buffer; otherwise sger is used and
 * {@code alpha} is narrowed to float. The size check is applied on both paths.
 *
 * @param order layout flag, forwarded unchanged to the BLAS routine
 * @param alpha scalar multiplier
 * @param X     the vector x; its buffer type selects the routine
 * @param Y     the vector y
 * @param A     the matrix passed to the routine as a; rows, columns and {@code size(0)} are
 *              passed as m, n and lda, and it is checked afterwards via
 *              {@code OpExecutionerUtil.checkForAny}
 * @throws ND4JArraySizeException if a dimension of {@code A} does not fit in an int
 */
@Override
public void ger(char order, double alpha, INDArray X, INDArray Y, INDArray A) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    // Guard the int narrowing for BOTH precisions; previously only the double path was checked.
    if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) {
        throw new ND4JArraySizeException();
    }
    final int m = (int) A.rows();
    final int n = (int) A.columns();
    final int lda = (int) A.size(0);

    if (X.data().dataType() == DataType.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y);
        dger(order, m, n, alpha, X, X.stride(-1), Y, Y.stride(-1), A, lda);
    } else {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X, Y);
        sger(order, m, n, (float) alpha, X, X.stride(-1), Y, Y.stride(-1), A, lda);
    }
    OpExecutionerUtil.checkForAny(A);
}
示例18
/**
 * Computes a vector-scalar product and adds the result to a vector.
 * <p>
 * The routine is selected from the buffer type of {@code x}: DOUBLE uses daxpy, FLOAT uses
 * saxpy, anything else uses haxpy ({@code alpha} is narrowed to float for the latter two).
 *
 * @param n     element count, forwarded unchanged to the selected routine
 * @param alpha scalar multiplier
 * @param x     the vector that is scaled
 * @param y     the vector passed to the routine as y
 */
@Override
public void axpy(long n, double alpha, INDArray x, INDArray y) {
    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, x, y);
    }

    final DataType bufferType = x.data().dataType();
    if (bufferType == DataType.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataType.DOUBLE, x, y);
        daxpy(n, alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
        return;
    }
    if (bufferType == DataType.FLOAT) {
        DefaultOpExecutioner.validateDataType(DataType.FLOAT, x, y);
        saxpy(n, (float) alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
        return;
    }
    // Anything that is neither DOUBLE nor FLOAT is treated as half precision.
    DefaultOpExecutioner.validateDataType(DataType.HALF, x, y);
    haxpy(n, (float) alpha, x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
}
示例19
/**
 * trsv solves a system of linear equations whose coefficients are in a triangular matrix.
 * <p>
 * dtrsv is used when {@code X} is backed by a DOUBLE buffer, strsv otherwise.
 * <p>
 * NOTE(review): the value passed as the routine's n is {@code A.length()} (the total element
 * count of {@code A}), not a single dimension - confirm against the BLAS trsv contract.
 *
 * @param order  layout flag, forwarded unchanged to the BLAS routine
 * @param Uplo   BLAS uplo flag, forwarded unchanged
 * @param TransA BLAS transpose flag, forwarded unchanged
 * @param Diag   BLAS diag flag, forwarded unchanged
 * @param A      the triangular matrix; {@code A.size(0)} is passed as the leading dimension
 * @param X      the vector passed to the routine; checked afterwards via
 *               {@code OpExecutionerUtil.checkForAny}
 * @throws ND4JArraySizeException if the length or first dimension of {@code A} does not fit in an int
 */
@Override
public void trsv(char order, char Uplo, char TransA, char Diag, INDArray A, INDArray X) {
    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, A, X);
    }

    final boolean sizesFitInt = A.length() <= Integer.MAX_VALUE && A.size(0) <= Integer.MAX_VALUE;
    if (!sizesFitInt) {
        throw new ND4JArraySizeException();
    }

    final boolean useDouble = X.data().dataType() == DataType.DOUBLE;
    DefaultOpExecutioner.validateDataType(useDouble ? DataType.DOUBLE : DataType.FLOAT, A, X);

    final int n = (int) A.length();
    final int lda = (int) A.size(0);
    if (useDouble) {
        dtrsv(order, Uplo, TransA, Diag, n, A, lda, X, X.stride(-1));
    } else {
        strsv(order, Uplo, TransA, Diag, n, A, lda, X, X.stride(-1));
    }
    OpExecutionerUtil.checkForAny(X);
}
示例20
/**
 * syr performs a rank-1 update of an n-by-n symmetric matrix a:
 * a := alpha*x*x' + a.
 * <p>
 * dsyr is used when {@code X} is backed by a DOUBLE buffer; otherwise ssyr is used and
 * {@code alpha} is narrowed to float.
 *
 * @param order layout flag, forwarded unchanged to the BLAS routine
 * @param Uplo  BLAS uplo flag, forwarded unchanged
 * @param N     not used by this implementation; the length of {@code X} is passed as n instead
 * @param alpha scalar multiplier
 * @param X     the vector x
 * @param A     the matrix passed to the routine as a; {@code A.size(0)} is passed as the leading
 *              dimension and it is checked afterwards via {@code OpExecutionerUtil.checkForAny}
 * @throws IllegalArgumentException if the length of {@code X} or the first dimension of
 *                                  {@code A} does not fit in an int
 */
@Override
public void syr(char order, char Uplo, int N, double alpha, INDArray X, INDArray A) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, A, X);
    }

    // The BLAS entry points take int sizes; fail loudly instead of truncating a long.
    if (X.length() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("syr: array dimensions exceed Integer.MAX_VALUE");
    }
    final int n = (int) X.length();
    final int lda = (int) A.size(0);

    if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, X);
        dsyr(order, Uplo, n, alpha, X, X.majorStride(), A, lda);
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, X);
        ssyr(order, Uplo, n, (float) alpha, X, X.majorStride(), A, lda);
    }
    OpExecutionerUtil.checkForAny(A);
}
示例21
/**
 * syrk performs a rank-n update of an n-by-n symmetric matrix c, that is, one of the following operations:
 * c := alpha*a*a' + beta*c for trans = 'N'or'n'
 * c := alpha*a'*a + beta*c for trans = 'T'or't','C'or'c',
 * where c is an n-by-n symmetric matrix;
 * a is an n-by-k matrix, if trans = 'N'or'n',
 * a is a k-by-n matrix, if trans = 'T'or't','C'or'c'.
 * <p>
 * dsyrk is used when {@code A} is backed by a DOUBLE buffer; otherwise ssyrk is used and the
 * scalars are narrowed to float. The row count of {@code C} is passed as n.
 * <p>
 * NOTE(review): the routine's k argument is the constant 1 regardless of the shape of
 * {@code A} - confirm whether this is intended.
 *
 * @param Order layout flag, forwarded unchanged to the BLAS routine
 * @param Uplo  BLAS uplo flag, forwarded unchanged
 * @param Trans BLAS transpose flag, forwarded unchanged
 * @param alpha scalar applied to the product
 * @param A     the matrix a; {@code A.size(0)} is passed as its leading dimension
 * @param beta  scalar applied to {@code C}
 * @param C     matrix passed to the routine as c; {@code C.size(0)} is passed as its leading
 *              dimension and it is checked afterwards via {@code OpExecutionerUtil.checkForAny}
 * @throws IllegalArgumentException if a dimension that must be narrowed does not fit in an int
 */
@Override
public void syrk(char Order, char Uplo, char Trans, double alpha, INDArray A, double beta, INDArray C) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, A, C);
    }

    // The BLAS entry points take int sizes; fail loudly instead of truncating a long.
    if (C.rows() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE || C.size(0) > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("syrk: array dimensions exceed Integer.MAX_VALUE");
    }
    final int n = (int) C.rows();
    final int lda = (int) A.size(0);
    final int ldc = (int) C.size(0);

    if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, C);
        dsyrk(Order, Uplo, Trans, n, 1, alpha, A, lda, beta, C, ldc);
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, C);
        ssyrk(Order, Uplo, Trans, n, 1, (float) alpha, A, lda, (float) beta, C, ldc);
    }
    OpExecutionerUtil.checkForAny(C);
}
示例22
/**
 * ?trsm solves one of the following matrix equations:
 * op(a)*x = alpha*b or x*op(a) = alpha*b,
 * where x and b are m-by-n general matrices, and a is triangular;
 * op(a) must be an m-by-m matrix, if side = 'L'or'l'
 * op(a) must be an n-by-n matrix, if side = 'R'or'r'.
 * For the definition of op(a), see Matrix Arguments.
 * The routine overwrites x on b.
 * <p>
 * dtrsm is used when {@code A} is backed by a DOUBLE buffer; otherwise strsm is used and
 * {@code alpha} is narrowed to float. The rows and columns of {@code A} are passed as the
 * routine's m and n.
 *
 * @param Order  layout flag, forwarded unchanged to the BLAS routine
 * @param Side   BLAS side flag, forwarded unchanged
 * @param Uplo   BLAS uplo flag, forwarded unchanged
 * @param TransA BLAS transpose flag, forwarded unchanged
 * @param Diag   BLAS diag flag, forwarded unchanged
 * @param alpha  scalar multiplier
 * @param A      the triangular matrix; {@code A.size(0)} is passed as its leading dimension
 * @param B      the matrix b; {@code B.size(0)} is passed as its leading dimension and it is
 *               checked afterwards via {@code OpExecutionerUtil.checkForAny}
 * @throws IllegalArgumentException if a dimension that must be narrowed does not fit in an int
 */
@Override
public void trsm(char Order, char Side, char Uplo, char TransA, char Diag, double alpha, INDArray A, INDArray B) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, A, B);
    }

    // The BLAS entry points take int sizes; fail loudly instead of truncating a long.
    if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE
                    || A.size(0) > Integer.MAX_VALUE || B.size(0) > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("trsm: array dimensions exceed Integer.MAX_VALUE");
    }
    final int m = (int) A.rows();
    final int n = (int) A.columns();
    final int lda = (int) A.size(0);
    final int ldb = (int) B.size(0);

    if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, B);
        dtrsm(Order, Side, Uplo, TransA, Diag, m, n, alpha, A, lda, B, ldb);
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, B);
        strsm(Order, Side, Uplo, TransA, Diag, m, n, (float) alpha, A, lda, B, ldb);
    }
    OpExecutionerUtil.checkForAny(B);
}
示例23
/**
 * Stores {@code value} at linear index {@code i}.
 * <p>
 * Scalars are written at {0, 0} (and reported to the profiler unless profiling is DISABLED or
 * SCOPE_PANIC); row vectors at {0, i}; column vectors at {i, 0}. For every other shape the
 * linear index is converted to subscripts according to the array's ordering and the
 * subscript-based {@code putScalar} overload is called.
 * <p>
 * NOTE(review): a negative index is offset by {@code rank()}, not by the element count -
 * confirm that this is the intended wrap-around for a linear index.
 *
 * @param i     linear index; negative values are shifted by {@code rank()}
 * @param value the value to store
 * @return this array for the scalar and vector cases, otherwise whatever the subscript-based
 *         overload returns
 */
@Override
public INDArray putScalar(long i, double value) {
    final long index = i < 0 ? i + rank() : i;

    if (isScalar()) {
        final OpExecutioner.ProfilingMode mode = Nd4j.getExecutioner().getProfilingMode();
        final boolean profilingActive = mode != OpExecutioner.ProfilingMode.DISABLED
                && mode != OpExecutioner.ProfilingMode.SCOPE_PANIC;
        if (profilingActive) {
            OpProfiler.getInstance().processScalarCall();
        }
        addOrUpdate(new long[] {0, 0}, value);
        return this;
    }

    final long[] vectorPosition;
    if (isRowVector()) {
        vectorPosition = new long[] {0, index};
    } else if (isColumnVector()) {
        vectorPosition = new long[] {index, 0};
    } else {
        // General case: translate the linear index using the array's own ordering.
        final long[] subscripts;
        if (ordering() == 'c') {
            subscripts = Shape.ind2subC(this, index);
        } else {
            subscripts = Shape.ind2sub(this, index);
        }
        return putScalar(subscripts, value);
    }
    addOrUpdate(vectorPosition, value);
    return this;
}
示例24
/**
 * Reads the element at linear index {@code i} as a double.
 * <p>
 * Index 0 is read straight from the backing buffer. Any other index is converted to
 * subscripts according to the array's ordering, checked against the shape, and read through
 * the subscript-based {@code getDouble} overload. The call is reported to the profiler unless
 * profiling is DISABLED or SCOPE_PANIC.
 *
 * @param i linear index of the element
 * @return the element value
 * @throws IllegalArgumentException if {@code i} is greater than or equal to {@code length()}
 */
@Override
public double getDouble(long i) {
    final boolean beyondEnd = i >= length();
    if (beyondEnd) {
        throw new IllegalArgumentException("Unable to get linear index >= " + length());
    }

    final OpExecutioner.ProfilingMode mode = Nd4j.getExecutioner().getProfilingMode();
    final boolean profilingActive = mode != OpExecutioner.ProfilingMode.DISABLED
            && mode != OpExecutioner.ProfilingMode.SCOPE_PANIC;
    if (profilingActive) {
        OpProfiler.getInstance().processScalarCall();
    }

    if (i == 0) {
        return data().getDouble(i);
    }

    final long[] subscripts;
    if (ordering() == 'c') {
        subscripts = Shape.ind2subC(this, i);
    } else {
        subscripts = Shape.ind2sub(this, i);
    }
    Shape.assertShapeLessThan(subscripts, shape());
    return getDouble(subscripts);
}
示例25
/**
 * syr2: symmetric rank-2 update, delegated to the BLAS dsyr2 / ssyr2 routine.
 * <p>
 * dsyr2 is used when {@code X} is backed by a DOUBLE buffer; otherwise ssyr2 is used and
 * {@code alpha} is narrowed to float.
 *
 * @param order layout flag, forwarded unchanged to the BLAS routine
 * @param Uplo  BLAS uplo flag, forwarded unchanged
 * @param alpha scalar multiplier
 * @param X     first vector; its length is passed as n
 * @param Y     second vector
 * @param A     the matrix passed to the routine; {@code A.size(0)} is passed as the leading
 *              dimension and it is checked afterwards via {@code OpExecutionerUtil.checkForAny}
 * @throws ND4JArraySizeException if the length of {@code X} or the first dimension of
 *                                {@code A} does not fit in an int
 */
@Override
public void syr2(char order, char Uplo, double alpha, INDArray X, INDArray Y, INDArray A) {
    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    final boolean sizesFitInt = X.length() <= Integer.MAX_VALUE && A.size(0) <= Integer.MAX_VALUE;
    if (!sizesFitInt) {
        throw new ND4JArraySizeException();
    }

    final boolean useDouble = X.data().dataType() == DataType.DOUBLE;
    DefaultOpExecutioner.validateDataType(useDouble ? DataType.DOUBLE : DataType.FLOAT, A, X, Y);

    final int n = (int) X.length();
    if (useDouble) {
        dsyr2(order, Uplo, n, alpha, X, X.stride(-1), Y, Y.stride(-1), A, (int) A.size(0));
    } else {
        ssyr2(order, Uplo, n, (float) alpha, X, X.stride(-1), Y, Y.stride(-1), A, (int) A.size(0));
    }
    OpExecutionerUtil.checkForAny(A);
}
示例26
/**
 * Complex gemv: computes a matrix-vector product using a general matrix and performs one of the
 * following matrix-vector operations:
 * y := alpha*a*x + beta*y for trans = 'N'or'n';
 * y := alpha*a'*x + beta*y for trans = 'T'or't';
 * y := alpha*conjg(a')*x + beta*y for trans = 'C'or'c'.
 * Here a is an m-by-n matrix, x and y are vectors, alpha and beta are scalars.
 * <p>
 * zgemv is used when {@code A} is backed by a DOUBLE buffer, cgemv otherwise. Dimensions,
 * leading dimension and increments come from {@link GemvParameters}.
 *
 * @param order  layout flag, forwarded unchanged to the BLAS routine
 * @param transA BLAS transpose flag, forwarded unchanged
 * @param alpha  complex scalar applied to the product
 * @param A      the matrix; its buffer type selects the routine
 * @param X      input vector
 * @param beta   complex scalar applied to {@code Y}
 * @param Y      vector passed to the routine as y
 */
@Override
public void gemv(char order, char transA, IComplexNumber alpha, IComplexNDArray A, IComplexNDArray X,
                IComplexNumber beta, IComplexNDArray Y) {
    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    GemvParameters parameters = new GemvParameters(A, X, Y);

    final boolean singlePrecision = A.data().dataType() != DataBuffer.Type.DOUBLE;
    if (singlePrecision) {
        cgemv(order, transA, parameters.getM(), parameters.getN(), alpha.asFloat(), A, parameters.getLda(), X,
                        parameters.getIncx(), beta.asFloat(), Y, parameters.getIncy());
    } else {
        zgemv(order, transA, parameters.getM(), parameters.getN(), alpha.asDouble(), A, parameters.getLda(), X,
                        parameters.getIncx(), beta.asDouble(), Y, parameters.getIncy());
    }
}
示例27
/**
 * Copies a vector, delegating to the BLAS dcopy / scopy routine with {@code x} as the first
 * operand and {@code y} as the second.
 * <p>
 * When either operand is sparse the call is handed to the sparse BLAS wrapper. Otherwise dcopy
 * runs if {@code x} is backed by a DOUBLE buffer and scopy runs in every other case.
 *
 * @param x first vector operand; its buffer type selects the routine
 * @param y second vector operand
 */
@Override
public void copy(INDArray x, INDArray y) {
    final boolean profileEverything =
            Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL;
    if (profileEverything) {
        OpProfiler.getInstance().processBlasCall(false, x, y);
    }

    final boolean bothDense = !x.isSparse() && !y.isSparse();
    if (!bothDense) {
        Nd4j.getSparseBlasWrapper().level1().copy(x, y);
        return;
    }

    final boolean useDouble = x.data().dataType() == DataBuffer.Type.DOUBLE;
    // Validate against the type of the routine selected below.
    DefaultOpExecutioner.validateDataType(useDouble ? DataBuffer.Type.DOUBLE : DataBuffer.Type.FLOAT, x, y);
    if (useDouble) {
        dcopy(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
    } else {
        scopy(x.length(), x, BlasBufferUtil.getBlasStride(x), y, BlasBufferUtil.getBlasStride(y));
    }
}
示例28
/**
 * performs a rank-1 update of a general m-by-n matrix a:
 * a := alpha*x*y' + a.
 * <p>
 * dger is used when {@code X} is backed by a DOUBLE buffer; otherwise sger is used and
 * {@code alpha} is narrowed to float.
 *
 * @param order layout flag, forwarded unchanged to the BLAS routine
 * @param alpha scalar multiplier
 * @param X     the vector x; its buffer type selects the routine
 * @param Y     the vector y
 * @param A     the matrix passed to the routine as a; rows, columns and {@code size(0)} are
 *              passed as m, n and lda, and it is checked afterwards via
 *              {@code OpExecutionerUtil.checkForAny}
 * @throws IllegalArgumentException if a dimension of {@code A} does not fit in an int
 */
@Override
public void ger(char order, double alpha, INDArray X, INDArray Y, INDArray A) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    // The BLAS entry points take int sizes; fail loudly instead of truncating a long.
    if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("ger: array dimensions exceed Integer.MAX_VALUE");
    }
    final int m = (int) A.rows();
    final int n = (int) A.columns();
    final int lda = (int) A.size(0);

    if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, X, Y);
        dger(order, m, n, alpha, X, X.majorStride(), Y, Y.majorStride(), A, lda);
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, X, Y);
        sger(order, m, n, (float) alpha, X, X.majorStride(), Y, Y.majorStride(), A, lda);
    }
    OpExecutionerUtil.checkForAny(A);
}
示例29
/**
 * sbmv computes a matrix-vector product using a symmetric band matrix:
 * y := alpha*a*x + beta*y.
 * Here a is an n-by-n symmetric band matrix with k superdiagonals, x and y are n-element vectors, alpha and beta are scalars.
 * <p>
 * dsbmv is used when {@code X} is backed by a DOUBLE buffer; otherwise ssbmv is used and the
 * scalars are narrowed to float. The length of {@code X} is passed as n and the column count
 * of {@code A} as k.
 *
 * @param order layout flag, forwarded unchanged to the BLAS routine
 * @param Uplo  BLAS uplo flag, forwarded unchanged
 * @param alpha scalar applied to the product
 * @param A     the band matrix; {@code A.size(0)} is passed as the leading dimension
 * @param X     input vector; its buffer type selects the routine
 * @param beta  scalar applied to {@code Y}
 * @param Y     vector passed to the routine as y; checked via {@code OpExecutionerUtil.checkForAny}
 * @throws IllegalArgumentException if a size that must be narrowed does not fit in an int
 */
@Override
public void sbmv(char order, char Uplo, double alpha, INDArray A, INDArray X, double beta, INDArray Y) {
    if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.ALL) {
        OpProfiler.getInstance().processBlasCall(false, A, X, Y);
    }

    // The BLAS entry points take int sizes; fail loudly instead of truncating a long.
    if (X.length() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) {
        throw new IllegalArgumentException("sbmv: array dimensions exceed Integer.MAX_VALUE");
    }
    final int n = (int) X.length();
    final int k = (int) A.columns();
    final int lda = (int) A.size(0);

    if (X.data().dataType() == DataBuffer.Type.DOUBLE) {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.DOUBLE, A, X, Y);
        dsbmv(order, Uplo, n, k, alpha, A, lda, X, X.majorStride(), beta, Y, Y.majorStride());
    } else {
        DefaultOpExecutioner.validateDataType(DataBuffer.Type.FLOAT, A, X, Y);
        ssbmv(order, Uplo, n, k, (float) alpha, A, lda, X, X.majorStride(), (float) beta, Y, Y.majorStride());
    }
    OpExecutionerUtil.checkForAny(Y);
}
示例30
/**
 * Builds a small SameDiff graph (ones(4), in-place add of 1.0, then a sum), converts it to a
 * FlatBuffers byte buffer through {@code NativeGraphExecutioner}, and writes the bytes from the
 * buffer's current position onward to the libnd4j test-resources file. The test is ignored and
 * makes no assertions; the former execution checks remain below as commented-out code.
 */
@Test
@Ignore
public void testConversion() throws Exception {
    SameDiff sameDiff = SameDiff.create();
    INDArray ones = Nd4j.ones(4);
    SDVariable onesVariable = sameDiff.var("ones", ones);
    SDVariable incremented = onesVariable.addi(1.0);
    // Not read afterwards; the call itself registers the sum with the graph.
    SDVariable summed = sameDiff.sum(incremented, Integer.MAX_VALUE);

    final NativeGraphExecutioner graphExecutioner = new NativeGraphExecutioner();
    final ExecutorConfiguration configuration = ExecutorConfiguration.builder()
                    .profilingMode(OpExecutioner.ProfilingMode.DISABLED)
                    .executionMode(ExecutionMode.SEQUENTIAL)
                    .outputMode(OutputMode.IMPLICIT)
                    .build();
    final ByteBuffer flatBuffer = graphExecutioner.convertToFlatBuffers(sameDiff, configuration);

    final int start = flatBuffer.position();
    final byte[] bytes = flatBuffer.array();

    try (FileOutputStream fileOut = new FileOutputStream("../../libnd4j/tests/resources/adam_sum.fb");
                    DataOutputStream dataOut = new DataOutputStream(fileOut)) {
        dataOut.write(bytes, start, bytes.length - start);
    }

    //INDArray[] res = executioner.executeGraph(sameDiff);
    //assertEquals(8.0, res[0].getDouble(0), 1e-5);
    /*
    INDArray output = null;
    for(int i = 0; i < 5; i++) {
    output = sameDiff.execAndEndResult(ops);
    System.out.println("Ones " + ones);
    System.out.println(output);
    }
    assertEquals(Nd4j.valueArrayOf(4,7),ones);
    assertEquals(28,output.getDouble(0),1e-1);
    */
}