Unverified commit a8c0fb4e authored by Guo Sheng, committed by GitHub

Add cholesky_op (#23543)

* Add cholesky_op forward part. test=develop

* Complete cholesky_op forward part. test=develop

* Add cholesky_op backward part. test=develop

* Complete cholesky_op backward part. test=develop

* Refine cholesky_op error check and docs. test=develop

* Add grad_check unit test for cholesky_op. test=develop

* Fix sample code in cholesky doc. test=develop

* Refine some error messages of cholesky_op. test=develop

* Refine some error messages of cholesky_op. test=develop

* Remove unused input in cholesky_grad. test=develop

* Remove unused input in cholesky_grad. test=develop

* Fix stream for cusolverDnSetStream. test=develop

* Update PADDLE_ENFORCE_CUDA_SUCCESS from cholesky_op to adapt to latest code.
test=develop

* Add CUSOLVER ERROR in enforce.h
test=develop

* Fix the missing return value in cholesky. test=develop
Parent 461e6a01
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cholesky_op.h"
namespace paddle {
namespace operators {
using framework::OpKernelType;
using framework::Tensor;
class CholeskyOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Cholesky");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Cholesky");
auto dims = ctx->GetInputDim("X");
auto rank = dims.size();
PADDLE_ENFORCE_GE(rank, 2,
platform::errors::InvalidArgument(
"The Input(X) should have at least 2 dimensions. But "
"received a %d dimension tensor.",
rank));
PADDLE_ENFORCE_EQ(
dims[rank - 2], dims[rank - 1],
platform::errors::InvalidArgument(
"The inner-most 2 dimensions of Input(X) all should be symmetric "
"positive-definite matrices and have the same size. But received "
"X's shape[-2] = %d and shape[-1] = %d.",
dims[rank - 2], dims[rank - 1]));
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
};
class CholeskyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), The input tensor of cholesky op. Its shape should be "
"[*, M, M] where * is zero or more batch dimensions, and matrices "
"on the inner-most 2 dimensions all should be symmetric "
"positive-definite.");
AddOutput("Out",
"(Tensor), The output tensor of cholesky op. It has the same "
"shape as the input, and it is composed of upper-triangular or "
"lower-triangular Cholesky factors of each of the individual "
"matrices.");
AddAttr<bool>("upper",
"(bool, default false), flag indicating whether to return "
"upper or lower triangular matrices. Default: False")
.SetDefault(false);
AddComment(R"DOC(
Cholesky Operator.
Computes the Cholesky decomposition of one symmetric positive-definite matrix
or batches of symmetric positive-definite matrices.
)DOC");
}
};
class CholeskyGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "CholeskyGrad");
OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Out")), "Input",
"Out@GRAD", "CholeskyGrad");
auto dims = ctx->GetInputDim("Out");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, dims);
}
}
};
template <typename T>
class CholeskyGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("Out", this->Output("Out"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(cholesky, ops::CholeskyOp, ops::CholeskyOpMaker,
ops::CholeskyGradOpMaker<paddle::framework::OpDesc>,
ops::CholeskyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(cholesky_grad, ops::CholeskyGradOp);
REGISTER_OP_CPU_KERNEL(cholesky, ops::CholeskyCPUKernel<float>,
ops::CholeskyCPUKernel<double>);
REGISTER_OP_CPU_KERNEL(
cholesky_grad,
ops::CholeskyGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::CholeskyGradKernel<paddle::platform::CPUDeviceContext, double>);
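For orientation (not part of the patch), a minimal NumPy sketch of the semantics described by CholeskyOpMaker above: `upper=False` returns the lower factor L with A = L·Lᵀ, and `upper=True` returns U = Lᵀ with A = Uᵀ·U. The helper name `reference_cholesky` is invented here for illustration only.

```python
import numpy as np

def reference_cholesky(x, upper=False):
    """NumPy reference for the cholesky op: x has shape [*, M, M] and every
    inner-most matrix is assumed symmetric positive-definite."""
    l = np.linalg.cholesky(x)              # batched lower factor
    if upper:
        # A = U^T U with U upper-triangular, i.e. U = L^T
        return np.swapaxes(l, -1, -2)
    return l                               # A = L L^T

# tiny batched example, constructed the same way as in the unit test below
rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 3))
spd = a @ np.swapaxes(a, -1, -2) + 3 * np.eye(3)   # make each matrix SPD
l = reference_cholesky(spd)
assert np.allclose(l @ np.swapaxes(l, -1, -2), spd)
```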
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <algorithm>
#include <vector>
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/cholesky_op.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
namespace paddle {
namespace operators {
template <typename T>
class CholeskyGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
bool upper = context.Attr<bool>("upper");
auto& dims = x->dims();
int batch_count = 1;
for (int i = 0; i < dims.size() - 2; i++) {
batch_count *= dims[i];
}
int m = dims[dims.size() - 1];
int tensor_size = batch_count * m * m;
const auto* x_data = x->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
// matrices are assumed to be stored in column-major order in cusolver
cublasFillMode_t uplo =
upper ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
// potrf works in-place, so first copy the relevant triangular part of the
// input matrices to the output and set the other triangular part to 0
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
tensor_size);
if (upper) {
MatrixBandPartFunctor<T> matrix_band_part_functor(
m, m, /* num_lower_diags */ 0, /* num_upper_diags */ m, x_data,
out_data);
for_range(matrix_band_part_functor);
} else {
MatrixBandPartFunctor<T> matrix_band_part_functor(
m, m, /* num_lower_diags */ m, /* num_upper_diags */ 0, x_data,
out_data);
for_range(matrix_band_part_functor);
}
// TODO(guosheng): Add callback to check info
auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_count);
auto* info_ptr = reinterpret_cast<int*>(info->ptr());
#if CUDA_VERSION >= 9020
if (batch_count > 1) {
std::vector<T*> output_ptrs;
for (int i = 0; i < batch_count; i++) {
output_ptrs.emplace_back(out_data + i * m * m);
}
thrust::device_vector<T*> dev_output_ptrs(output_ptrs.begin(),
output_ptrs.end());
PotrfBatched(dev_ctx, uplo, m,
thrust::raw_pointer_cast(dev_output_ptrs.data()), m,
info_ptr, batch_count);
// TODO(guosheng): There seems to be a bug in cusolver potrfBatched which
// requires clearing the upper triangle of the output. Remove this
// workaround once the bug is fixed.
if (!upper) {
MatrixBandPartFunctor<T> matrix_band_part_functor(
m, m, /* num_lower_diags */ m, /* num_upper_diags */ 0, out_data,
out_data);
for_range(matrix_band_part_functor);
}
} else {
#endif
for (int i = 0; i < batch_count; i++) {
Potrf(dev_ctx, uplo, m, out_data + i * m * m, m, info_ptr + i);
}
#if CUDA_VERSION >= 9020
}
#endif
}
void Potrf(const platform::CUDADeviceContext& dev_ctx, cublasFillMode_t uplo,
int n, T* A, int lda, int* info) const;
void PotrfBatched(const platform::CUDADeviceContext& dev_ctx,
cublasFillMode_t uplo, int n, T* Aarray[], int lda,
int* info_array, int batch_size) const;
};
#define FUNC_WITH_TYPES(m) m(float, S) m(double, D)
#define POTRF_INSTANCE(T, C) \
template <> \
void CholeskyGPUKernel<T>::Potrf(const platform::CUDADeviceContext& dev_ctx, \
cublasFillMode_t uplo, int n, T* A, \
int lda, int* info) const { \
auto handle = dev_ctx.cusolver_dn_handle(); \
int workspace_size = 0; \
PADDLE_ENFORCE_CUDA_SUCCESS( \
platform::dynload::cusolverDn##C##potrf_bufferSize( \
handle, uplo, n, A, lda, &workspace_size)); \
auto workspace = memory::Alloc(dev_ctx, workspace_size); \
T* workspace_ptr = reinterpret_cast<T*>(workspace->ptr()); \
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDn##C##potrf( \
handle, uplo, n, A, lda, workspace_ptr, workspace_size, info)); \
}
FUNC_WITH_TYPES(POTRF_INSTANCE);
#if CUDA_VERSION >= 9020
#define POTRF_BATCH_INSTANCE(T, C) \
template <> \
void CholeskyGPUKernel<T>::PotrfBatched( \
const platform::CUDADeviceContext& dev_ctx, cublasFillMode_t uplo, \
int n, T* Aarray[], int lda, int* info_array, int batch_size) const { \
auto handle = dev_ctx.cusolver_dn_handle(); \
PADDLE_ENFORCE_CUDA_SUCCESS( \
platform::dynload::cusolverDn##C##potrfBatched( \
handle, uplo, n, Aarray, lda, info_array, batch_size)); \
}
FUNC_WITH_TYPES(POTRF_BATCH_INSTANCE);
#endif
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(cholesky, ops::CholeskyGPUKernel<float>,
ops::CholeskyGPUKernel<double>);
REGISTER_OP_CUDA_KERNEL(
cholesky_grad,
ops::CholeskyGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::CholeskyGradKernel<paddle::platform::CUDADeviceContext, double>);
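A note on the `uplo` flip in CholeskyGPUKernel above (passing CUBLAS_FILL_MODE_UPPER when a lower factor is wanted, and vice versa): cuSOLVER assumes column-major storage, and reading a row-major buffer as column-major is equivalent to transposing it, so requesting the opposite fill mode yields the desired factor in row-major layout. A small NumPy sketch of that buffer-reinterpretation argument, for illustration only:

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.standard_normal((4, 4))
spd = a @ a.T + 4 * np.eye(4)               # symmetric positive-definite

l = np.linalg.cholesky(spd)                 # lower factor in the row-major view
buf = np.ascontiguousarray(l)               # the raw row-major bytes

# Reinterpret the same bytes as a column-major (Fortran-order) matrix:
# this is what a column-major library such as cuSOLVER sees.
col_major_view = np.ndarray(buf.shape, buf.dtype, buf, order="F")

# The column-major view is the transpose, i.e. the *upper* factor, which is
# why the kernel asks for CUBLAS_FILL_MODE_UPPER to get a lower factor out.
assert np.allclose(col_major_view, l.T)
```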
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <numeric>
#include <vector>
#include "Eigen/Cholesky"
#include "Eigen/Core"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class CholeskyCPUKernel : public framework::OpKernel<T> {
public:
// different from EigenMatrix in framework/eigen.h
using EigenMatrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using InputMatrixMap = Eigen::Map<const EigenMatrix>;
using OutputMatrixMap = Eigen::Map<EigenMatrix>;
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
bool upper = context.Attr<bool>("upper");
auto& dims = x->dims();
int batch_count = 1;
for (int i = 0; i < dims.size() - 2; i++) {
batch_count *= dims[i];
}
auto m = dims[dims.size() - 1];
const auto* x_data = x->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
// Cholesky decomposition for each matrix; could potentially use multiple threads
for (int i = 0; i < batch_count; i++) {
auto input = InputMatrixMap(x_data + i * m * m, m, m);
auto output = OutputMatrixMap(out_data + i * m * m, m, m);
if (upper) {
Eigen::LLT<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>,
Eigen::UpLoType::Upper>
llt_decomposition(input);
PADDLE_ENFORCE_EQ(
llt_decomposition.info(), Eigen::Success,
platform::errors::InvalidArgument(
"Cholesky decomposition was not successful. The input matrice "
"might not be not be positive definite."));
output = llt_decomposition.matrixU();
} else {
Eigen::LLT<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>,
Eigen::UpLoType::Lower>
llt_decomposition(input);
PADDLE_ENFORCE_EQ(
llt_decomposition.info(), Eigen::Success,
platform::errors::InvalidArgument(
"Cholesky decomposition was not successful. The input matrice "
"might not be not be positive definite."));
output = llt_decomposition.matrixL();
}
}
}
};
/*! Use these functors to implement tril, triu, diagonal and other operators */
template <typename T>
struct EyeFunctor {
EyeFunctor(const int m, const int n, T* output)
: m_(m), n_(n), output_(output) {}
HOSTDEVICE void operator()(size_t index) const {
const int global_row = index / n_;
const int col = index - global_row * n_;
const int batch = global_row / m_;
const int row = global_row - batch * m_;
output_[index] = col == row ? static_cast<T>(1) : static_cast<T>(0);
}
const int m_, n_;
T* output_;
};
template <typename T>
struct MatrixBandPartFunctor {
/*! Copy the input values inside a central band to the output and set the
 * values outside that band to 0. That is:
 * output[i, j, ..., m, n] = in_band(m, n) * input[i, j, ..., m, n]
 * where: in_band(m, n) = (num_lower < 0 || (m - n) <= num_lower) &&
 * (num_upper < 0 || (n - m) <= num_upper)
 */
MatrixBandPartFunctor(const int m, const int n, const int num_lower_diags,
const int num_upper_diags, const T* input, T* output)
: m_(m),
n_(n),
num_lower_diags_(num_lower_diags),
num_upper_diags_(num_upper_diags),
input_(input),
output_(output) {}
HOSTDEVICE void operator()(size_t index) const {
const int col = index % n_;
const int row = (index / n_) % m_;
const int band_start = (num_lower_diags_ < 0 ? 0 : row - num_lower_diags_);
const int band_end =
(num_upper_diags_ < 0 ? n_ : row + num_upper_diags_ + 1);
if (col < band_start || col >= band_end) {
output_[index] = static_cast<T>(0);
} else {
output_[index] = input_[index];
}
}
const int m_, n_, num_lower_diags_, num_upper_diags_;
const T* input_;
T* output_;
};
template <typename T>
struct MatrixSetDiagFunctor {
/*! Overwrite the specified diagonals of the output with the values in diag.
 * The diagonals can be a central band specified by num_diags and
 * upper_diag_index, where upper_diag_index=0 refers to the main diagonal, a
 * positive value means a superdiagonal and a negative value means a
 * subdiagonal. When it is a band, `diag` has shape
 * [i, j, ..., num_diags, max_diag_len] and the num_diags diagonals are laid
 * out from top to bottom. Otherwise it has shape [i, j, ..., max_diag_len].
 */
MatrixSetDiagFunctor(const int m, const int n, const int num_diags,
const int max_diag_len, const int upper_diag_index,
const T* diag, T* output)
: m_(m),
n_(n),
num_diags_(num_diags),
max_diag_len_(max_diag_len),
upper_diag_index_(upper_diag_index),
diag_(diag),
output_(output) {}
HOSTDEVICE void operator()(size_t index) const {
const int batch_and_diag_index = index / max_diag_len_;
const int index_in_the_diagonal =
index - batch_and_diag_index * max_diag_len_;
const int batch = batch_and_diag_index / num_diags_;
const int diag_index_in_input = batch_and_diag_index - batch * num_diags_;
// diag_index=0 refers to the main diagonal
const int diag_index = upper_diag_index_ - diag_index_in_input;
// shift down for subdiagonal if diag_index < 0
const int y_index =
index_in_the_diagonal + (0 > -diag_index ? 0 : -diag_index);
// shift right for superdiagonal if diag_index > 0
const int x_index =
index_in_the_diagonal + (0 > diag_index ? 0 : diag_index);
// Upper-bound checks for diagonals shorter than max_diag_len.
// y_index and x_index are nonnegative by construction.
if (y_index < m_ && x_index < n_) {
const int out_index = batch * m_ * n_ + y_index * n_ + x_index;
output_[out_index] = diag_[index];
}
}
const int m_, n_, num_diags_, max_diag_len_, upper_diag_index_;
const T* diag_;
T* output_;
};
template <typename T>
struct MatrixDiagPartFunctor {
/*! Similar to MatrixSetDiagFunctor but returns the diagonals. diag_index=0
 * refers to the main diagonal, a positive value means a superdiagonal and a
 * negative value means a subdiagonal */
MatrixDiagPartFunctor(const int m, const int n, const int num_diags,
const int max_diag_len, const int upper_diag_index,
const T padding, const T* input, T* output)
: m_(m),
n_(n),
num_diags_(num_diags),
max_diag_len_(max_diag_len),
upper_diag_index_(upper_diag_index),
padding_(padding),
input_(input),
output_(output) {}
HOSTDEVICE void operator()(size_t index) const {
const int batch_and_mapped_diag_index = index / max_diag_len_;
const int index_in_the_diagonal =
index - batch_and_mapped_diag_index * max_diag_len_;
const int batch = batch_and_mapped_diag_index / num_diags_;
const int mapped_diag_index =
batch_and_mapped_diag_index - batch * num_diags_;
// diag_index=0 refers to the main diagonal
const int diag_index = upper_diag_index_ - mapped_diag_index;
// shift down for subdiagonal if diag_index < 0
const int y_index =
index_in_the_diagonal + (0 > -diag_index ? 0 : -diag_index);
// shift right for superdiagonal if diag_index > 0
const int x_index =
index_in_the_diagonal + (0 > diag_index ? 0 : diag_index);
if (y_index < m_ && x_index < n_) {
output_[index] = input_[batch * m_ * n_ + y_index * n_ + x_index];
} else {
output_[index] = padding_;
}
}
const int m_, n_, num_diags_, max_diag_len_, upper_diag_index_;
const T padding_;
const T* input_;
T* output_;
};
template <typename T>
struct MatrixBandPartScaleEndFunctor {
/*! Compared with MatrixBandPartFunctor, it scales the values at the end of
 * each band row. It can be used to fuse the following operations, which
 * output a lower-triangular matrix with its diagonal scaled:
 * 1. diag = matrix_diag_part(middle)
 * 2. middle = matrix_set_diag(middle, diag * scalar)
 * 3. middle = matrix_band_part(middle, -1, 0)
 */
MatrixBandPartScaleEndFunctor(const int m, const int n,
const int num_lower_diags,
const int num_upper_diags, const T scale,
const T* input, T* output)
: m_(m),
n_(n),
num_lower_diags_(num_lower_diags),
num_upper_diags_(num_upper_diags),
scale_(scale),
input_(input),
output_(output) {}
HOSTDEVICE void operator()(size_t index) const {
const int col = index % n_;
const int row = (index / n_) % m_;
const int band_start = (num_lower_diags_ < 0 ? 0 : row - num_lower_diags_);
const int band_end =
(num_upper_diags_ < 0 ? n_ : row + num_upper_diags_ + 1);
if (col < band_start || col >= band_end) {
output_[index] = 0;
} else if (col == band_end - 1) {
output_[index] = scale_ * input_[index];
} else {
output_[index] = input_[index];
}
}
const int m_, n_, num_lower_diags_, num_upper_diags_;
const T scale_;
const T* input_;
T* output_;
};
template <typename T>
struct AddtoScaleFunctor {
AddtoScaleFunctor(const T scale, const T* input, T* output)
: scale_(scale), input_(input), output_(output) {}
HOSTDEVICE void operator()(size_t index) const {
output_[index] += input_[index];
output_[index] *= scale_;
}
const T scale_;
const T* input_;
T* output_;
};
template <typename DeviceContext, typename T>
class CholeskyGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Input<Tensor>("Out");
auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
auto* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
bool upper = context.Attr<bool>("upper");
auto& dims = out->dims();
int batch_count = 1;
for (int i = 0; i < dims.size() - 2; i++) {
batch_count *= dims[i];
}
auto m = dims[dims.size() - 1];
int tensor_size = batch_count * m * m;
auto& dev_ctx = context.template device_context<DeviceContext>();
std::vector<int> axis(dims.size() - 2);
std::iota(axis.begin(), axis.end(), 0);
axis.insert(axis.end(), {dims.size() - 1, dims.size() - 2});
Tensor l, l_grad;
if (upper) {
l.mutable_data<T>(dims, context.GetPlace());
l_grad.mutable_data<T>(dims, context.GetPlace());
TransCompute<DeviceContext, T>(dims.size(), dev_ctx, *out, &l, axis);
TransCompute<DeviceContext, T>(dims.size(), dev_ctx, *out_grad, &l_grad,
axis);
} else {
l = *out;
l_grad = *out_grad;
}
auto* l_data = l.data<T>();
/*! refer to Iain Murray (2016); arXiv 1602.07527 */
/*! phi = matmul(L.transpose(-1, -2), grad) */
Tensor middle;
auto* middle_data = middle.mutable_data<T>(dims, context.GetPlace());
auto trans_desc = math::CreateMatrixDescriptor(dims, 0, true);
auto no_trans_desc = math::CreateMatrixDescriptor(dims, 0, false);
auto blas = math::GetBlas<DeviceContext, T>(context);
blas.MatMul(l, trans_desc, l_grad, no_trans_desc, T(1), &middle, T(0));
/*! phi.tril_().diagonal(0, -2, -1).mul_(0.5) */
platform::ForRange<DeviceContext> for_range(dev_ctx, tensor_size);
MatrixBandPartScaleEndFunctor<T> matrix_band_part_scale_end_functor(
m, m, /* num_lower_diags */ m, /* num_upper_diags */ 0,
/* scale */ 0.5, middle_data, middle_data);
for_range(matrix_band_part_scale_end_functor);
// Compute the inverse by solving the triangular linear system AX = B, where
// B is the identity matrix. The solution X overwrites B.
Tensor identity;
auto* identity_data = identity.mutable_data<T>(dims, context.GetPlace());
EyeFunctor<T> eye_functor(m, m, identity_data);
for_range(eye_functor);
// TODO(guosheng): use trsmBatched for GPU
for (int i = 0; i < batch_count; i++) {
blas.TRSM(/*side*/ CblasLeft, /*uplo*/ CblasLower,
/*trans*/ CblasNoTrans, /*diag*/ CblasNonUnit, /*m*/ m, /*n*/ m,
/*alpha*/ T(1), l_data + i * m * m, /*lda*/ m,
identity_data + i * m * m, /*ldb*/ m);
}
Tensor& l_inverse = identity;
/*! x_grad = matmul(matmul(L_inverse.transpose(-1, -2), phi), L_inverse) */
Tensor middle1;
middle1.mutable_data<T>(dims, context.GetPlace());
blas.MatMul(l_inverse, trans_desc, middle, no_trans_desc, T(1), &middle1,
T(0));
blas.MatMul(middle1, no_trans_desc, l_inverse, no_trans_desc, T(1), x_grad,
T(0));
/*! x_grad.add(x_grad.transpose(-1, -2)).mul_(0.5) */
Tensor x_grad_trans;
auto* x_grad_trans_data =
x_grad_trans.mutable_data<T>(dims, context.GetPlace());
TransCompute<DeviceContext, T>(dims.size(), dev_ctx, *x_grad, &x_grad_trans,
axis);
AddtoScaleFunctor<T> addto_scale_functor(0.5, x_grad_trans_data,
x_grad_data);
for_range(addto_scale_functor);
}
};
} // namespace operators
} // namespace paddle
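For readability, the gradient computed by CholeskyGradKernel above (following Iain Murray, 2016, arXiv:1602.07527, as the code comments note) can be restated as: given A = L·Lᵀ and the incoming gradient of L,

```latex
\bar{A} \;=\; \tfrac{1}{2}\bigl(S + S^{T}\bigr),
\qquad
S \;=\; L^{-T}\,\Phi\!\bigl(L^{T}\bar{L}\bigr)\,L^{-1},
\qquad
\Phi(X)_{ij} \;=\;
\begin{cases}
X_{ij}, & i > j,\\
\tfrac{1}{2}\,X_{ii}, & i = j,\\
0, & i < j.
\end{cases}
```

Here Φ corresponds to MatrixBandPartScaleEndFunctor (keep the lower triangle, halve the diagonal), L⁻¹ comes from the TRSM solve against the identity, and the final AddtoScaleFunctor performs the symmetrization ½(S + Sᵀ). When `upper=True`, the kernel first transposes `Out` and its gradient so the same lower-triangular formula applies.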
......@@ -218,6 +218,11 @@ class Blas {
template <typename T>
void VMERF(int n, const T* a, T* y, int64_t mode) const;
template <typename T>
void TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA,
CBLAS_DIAG diag, int M, int N, T alpha, const T* A, int lda, T* B,
int ldb) const;
private:
const DeviceContext& context_;
};
......@@ -351,6 +356,11 @@ class BlasT : private Blas<DeviceContext> {
Base()->template VMERF<T>(args...);
}
template <typename... ARGS>
void TRSM(ARGS... args) const {
Base()->template TRSM<T>(args...);
}
private:
const Blas<DeviceContext>* Base() const {
return static_cast<const Blas<DeviceContext>*>(this);
......
......@@ -88,6 +88,11 @@ struct CUBlas<float> {
PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
#endif
}
template <typename... ARGS>
static void TRSM(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasStrsm(args...));
}
};
template <>
......@@ -131,6 +136,11 @@ struct CUBlas<double> {
static void GEMM_EX(ARGS... args) {
PADDLE_THROW("Currently there are not cublasDgemmEx.");
}
template <typename... ARGS>
static void TRSM(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDtrsm(args...));
}
};
template <>
......@@ -411,6 +421,31 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
#endif // CUDA_VERSION >= 9010
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
CBLAS_TRANSPOSE transA,
CBLAS_DIAG diag, int M, int N,
T alpha, const T *A, int lda, T *B,
int ldb) const {
// solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'`
// where ' stands for transpose
cublasSideMode_t cuSide =
(side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;
cublasFillMode_t cuUplo =
(uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
// use CUBLAS_OP_C (conjugate transpose) for complex
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasDiagType_t cuDiag =
(diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::TRSM(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A,
lda, B, ldb);
});
}
} // namespace math
} // namespace operators
} // namespace paddle
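The argument swaps in the CUDA TRSM wrapper above (side, uplo, and the M/N order) follow from transposing the system, as the code comment notes:

```latex
\text{op}(A)\,X = \alpha B
\quad\Longleftrightarrow\quad
X^{T}\,\text{op}(A)^{T} = \alpha B^{T}.
```

Since cuBLAS reads the row-major buffers as their transposes, a left-sided solve with a lower-triangular A in the row-major view becomes a right-sided solve with an upper-triangular matrix in the column-major view, on a right-hand side of size N×M; the transpose flag itself is unchanged.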
......@@ -152,6 +152,11 @@ struct CBlas<float> {
platform::dynload::mkl_scsrmm(args...);
}
#endif
template <typename... ARGS>
static void TRSM(ARGS... args) {
platform::dynload::cblas_strsm(args...);
}
};
template <>
......@@ -273,6 +278,11 @@ struct CBlas<double> {
platform::dynload::mkl_dcsrmm(args...);
}
#endif
template <typename... ARGS>
static void TRSM(ARGS... args) {
platform::dynload::cblas_dtrsm(args...);
}
};
#else
......@@ -298,6 +308,11 @@ struct CBlas<float> {
static void GEMV(ARGS... args) {
cblas_sgemv(args...);
}
template <typename... ARGS>
static void TRSM(ARGS... args) {
cblas_strsm(args...);
}
};
template <>
......@@ -321,6 +336,11 @@ struct CBlas<double> {
static void GEMV(ARGS... args) {
cblas_dgemv(args...);
}
template <typename... ARGS>
static void TRSM(ARGS... args) {
cblas_dtrsm(args...);
}
};
#endif
......@@ -899,6 +919,17 @@ void Blas<platform::CPUDeviceContext>::CSRMM(
}
#endif
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
CBLAS_TRANSPOSE transA,
CBLAS_DIAG diag, int M, int N,
T alpha, const T *A, int lda, T *B,
int ldb) const {
CBlas<T>::TRSM(CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda,
B, ldb);
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -9,6 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <functional>
#include <string>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/math.h"
......
......@@ -241,12 +241,14 @@ CUDAContext::CUDAContext(const CUDAPlace& place,
InitEigenContext();
InitCuBlasContext();
InitCuDNNContext();
InitCuSolverContext();
}
CUDAContext::~CUDAContext() {
CUDADeviceGuard guard(place_.device);
DestoryCuDNNContext();
DestoryCuBlasContext();
DestoryCuSolverContext();
}
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
......@@ -340,6 +342,10 @@ CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
}
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
return context()->CusolverDnHandle();
}
cudaStream_t CUDADeviceContext::stream() const {
return context()->RawStream();
}
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
......@@ -105,6 +106,10 @@ class CUDAContext {
const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; }
const cusolverDnHandle_t& CusolverDnHandle() const {
return cusolver_dn_handle_;
}
const std::unique_ptr<CublasHandleHolder>& CublasHandle() const {
return cublas_handle_;
}
......@@ -170,6 +175,13 @@ class CUDAContext {
}
}
void InitCuSolverContext() {
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cusolverDnCreate(&cusolver_dn_handle_));
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cusolverDnSetStream(cusolver_dn_handle_, RawStream()));
}
void DestoryCuDNNContext() {
if (cudnn_handle_) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_));
......@@ -182,6 +194,13 @@ class CUDAContext {
cublas_tensor_core_handle_.reset();
}
void DestoryCuSolverContext() {
if (cusolver_dn_handle_) {
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cusolverDnDestroy(cusolver_dn_handle_));
}
}
CUDAPlace place_;
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
......@@ -189,6 +208,7 @@ class CUDAContext {
cudnnHandle_t cudnn_handle_;
std::unique_ptr<CublasHandleHolder> cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
cusolverDnHandle_t cusolver_dn_handle_;
DISABLE_COPY_AND_ASSIGN(CUDAContext);
};
......@@ -249,6 +269,8 @@ class CUDADeviceContext : public DeviceContext {
* sequential cudnn function calls. */
CudnnWorkspaceHandle cudnn_workspace_handle() const;
cusolverDnHandle_t cusolver_dn_handle() const;
/*! \brief Return cuda stream in the device context. */
cudaStream_t stream() const;
......
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc)
list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc)
# There is no macOS version of NCCL.
# Disable nvrtc and cuda_driver api on MacOS and Windows, and only do a early test on Linux.
......
......@@ -76,6 +76,8 @@ extern void *cublas_dso_handle;
__macro(cublasSgemmEx); \
__macro(cublasSgeam); \
__macro(cublasDgeam); \
__macro(cublasStrsm_v2); \
__macro(cublasDtrsm_v2); \
__macro(cublasCreate_v2); \
__macro(cublasDestroy_v2); \
__macro(cublasSetStream_v2); \
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/cusolver.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cusolver_dso_flag;
void *cusolver_dso_handle;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
CUSOLVER_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cusolverDn.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {
extern std::once_flag cusolver_dso_flag;
extern void *cusolver_dso_handle;
#ifdef PADDLE_USE_DSO
#define DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
cusolverStatus_t operator()(Args... args) { \
using cusolverFunc = decltype(&::__name); \
std::call_once(cusolver_dso_flag, []() { \
cusolver_dso_handle = \
paddle::platform::dynload::GetCusolverDsoHandle(); \
}); \
static void *p_##__name = dlsym(cusolver_dso_handle, #__name); \
return reinterpret_cast<cusolverFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#else
#define DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
cusolverStatus_t operator()(Args... args) { \
return ::__name(args...); \
} \
}; \
extern DynLoad__##__name __name
#endif
#define CUSOLVER_ROUTINE_EACH(__macro) \
__macro(cusolverDnCreate); \
__macro(cusolverDnDestroy); \
__macro(cusolverDnSetStream); \
__macro(cusolverDnSpotrf_bufferSize); \
__macro(cusolverDnDpotrf_bufferSize); \
__macro(cusolverDnSpotrf); \
__macro(cusolverDnDpotrf);
CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP);
#if CUDA_VERSION >= 9020
#define CUSOLVER_ROUTINE_EACH_R1(__macro) \
__macro(cusolverDnSpotrfBatched); \
__macro(cusolverDnDpotrfBatched);
CUSOLVER_ROUTINE_EACH_R1(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP)
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle
......@@ -224,6 +224,16 @@ void* GetCurandDsoHandle() {
#endif
}
void* GetCusolverDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusolver.dylib");
#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, win_cusolver_lib);
#else
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusolver.so");
#endif
}
void* GetNVRTCDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.dylib");
......
......@@ -29,6 +29,7 @@ void* GetCublasDsoHandle();
void* GetCUDNNDsoHandle();
void* GetCUPTIDsoHandle();
void* GetCurandDsoHandle();
void* GetCusolverDsoHandle();
void* GetNVRTCDsoHandle();
void* GetCUDADsoHandle();
void* GetWarpCTCDsoHandle();
......
......@@ -53,6 +53,8 @@ extern void* mklml_dso_handle;
__macro(cblas_dcopy); \
__macro(cblas_sgemv); \
__macro(cblas_dgemv); \
__macro(cblas_strsm); \
__macro(cblas_dtrsm); \
__macro(cblas_sgemm_alloc); \
__macro(cblas_dgemm_alloc); \
__macro(cblas_sgemm_pack); \
......
......@@ -56,6 +56,7 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/dynload/curand.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/dynload/nccl.h"
#endif // __APPLE__
......@@ -680,6 +681,44 @@ inline void throw_on_error(cublasStatus_t stat, const std::string& msg) {
#endif
}
/***** CUSOLVER ERROR *****/
inline bool is_error(cusolverStatus_t stat) {
return stat != CUSOLVER_STATUS_SUCCESS;
}
inline const char* cusolverGetErrorString(cusolverStatus_t stat) {
switch (stat) {
case CUSOLVER_STATUS_NOT_INITIALIZED:
return "CUSOLVER_STATUS_NOT_INITIALIZED";
case CUSOLVER_STATUS_ALLOC_FAILED:
return "CUSOLVER_STATUS_ALLOC_FAILED";
case CUSOLVER_STATUS_INVALID_VALUE:
return "CUSOLVER_STATUS_INVALID_VALUE";
case CUSOLVER_STATUS_ARCH_MISMATCH:
return "CUSOLVER_STATUS_ARCH_MISMATCH";
case CUSOLVER_STATUS_EXECUTION_FAILED:
return "CUSOLVER_STATUS_EXECUTION_FAILED";
case CUSOLVER_STATUS_INTERNAL_ERROR:
return "CUSOLVER_STATUS_INTERNAL_ERROR";
case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
default:
return "Unknown cusolver status";
}
}
inline std::string build_nvidia_error_msg(cusolverStatus_t stat) {
std::string msg(" Cublas error, ");
return msg + cusolverGetErrorString(stat) + " ";
}
inline void throw_on_error(cusolverStatus_t stat, const std::string& msg) {
#ifndef REPLACE_ENFORCE_GLOG
throw std::runtime_error(msg);
#else
LOG(FATAL) << msg;
#endif
}
/****** NCCL ERROR ******/
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
inline bool is_error(ncclResult_t nccl_result) {
......@@ -716,6 +755,7 @@ DEFINE_CUDA_STATUS_TYPE(cudaError_t, cudaSuccess);
DEFINE_CUDA_STATUS_TYPE(curandStatus_t, CURAND_STATUS_SUCCESS);
DEFINE_CUDA_STATUS_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS);
DEFINE_CUDA_STATUS_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS);
DEFINE_CUDA_STATUS_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS);
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess);
......
......@@ -160,7 +160,7 @@ from .tensor.linalg import norm #DEFINE_ALIAS
from .tensor.linalg import dist #DEFINE_ALIAS
from .tensor.linalg import t #DEFINE_ALIAS
from .tensor.linalg import cross #DEFINE_ALIAS
# from .tensor.linalg import cholesky #DEFINE_ALIAS
from .tensor.linalg import cholesky #DEFINE_ALIAS
# from .tensor.linalg import .tensordot #DEFINE_ALIAS
# from .tensor.manipulation import cast #DEFINE_ALIAS
# from .tensor.manipulation import concat #DEFINE_ALIAS
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci
from gradient_checker import grad_check
from decorator_helper import prog_scope
@skip_check_grad_ci(
reason="The input of cholesky_op should always be symmetric positive-definite. "
"However, OpTest calculates the numeric gradient of each element in input "
"via small finite difference, which makes the input no longer symmetric "
"positive-definite thus can not compute the Cholesky decomposition. "
"While we can use the gradient_checker.grad_check to perform gradient "
"check of cholesky_op, since it supports check gradient with a program "
"and we can construct symmetric positive-definite matrices in the program")
class TestCholeskyOp(OpTest):
def setUp(self):
self.op_type = "cholesky"
self._input_shape = (2, 32, 32)
self._upper = True
self.init_config()
self.trans_dims = list(range(len(self._input_shape) - 2)) + [
len(self._input_shape) - 1, len(self._input_shape) - 2
]
self.root_data = np.random.random(self._input_shape).astype("float64")
# construct symmetric positive-definite matrices
input_data = np.matmul(
self.root_data, self.root_data.transpose(self.trans_dims)) + 1e-05
output_data = np.linalg.cholesky(input_data).astype("float64")
if self._upper:
output_data = output_data.transpose(self.trans_dims)
self.inputs = {"X": input_data}
self.attrs = {"upper": self._upper}
self.outputs = {"Out": output_data}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
@prog_scope()
def func(self, place):
# use a small size since computing the Jacobian gradient is time consuming
root_data = self.root_data[..., :3, :3]
prog = fluid.Program()
with fluid.program_guard(prog):
root = layers.create_parameter(
dtype=root_data.dtype, shape=root_data.shape)
root_t = layers.transpose(root, self.trans_dims)
x = layers.matmul(x=root, y=root_t) + 1e-05
out = paddle.cholesky(x, upper=self.attrs["upper"])
grad_check(root, out, x_init=root_data, place=place)
def init_config(self):
self._upper = True
class TestCholeskyOpLower(TestCholeskyOp):
def init_config(self):
self._upper = False
class TestCholeskyOp2D(TestCholeskyOp):
def init_config(self):
self._input_shape = (64, 64)
if __name__ == "__main__":
unittest.main()
......@@ -26,7 +26,7 @@ __all__ = [
'dist',
't',
'cross',
# 'cholesky',
'cholesky',
# 'tensordot',
'bmm'
]
......@@ -627,6 +627,59 @@ def cross(input, other, dim=None):
return out
def cholesky(x, upper=False):
"""
Computes the Cholesky decomposition of one symmetric positive-definite
matrix or batches of symmetric positive-definite matrices.
If `upper` is `True`, the decomposition has the form :math:`A = U^{T}U` ,
and the returned matrix :math:`U` is upper-triangular. Otherwise, the
decomposition has the form :math:`A = LL^{T}` , and the returned matrix
:math:`L` is lower-triangular.
Args:
x (Variable): The input tensor. Its shape should be `[*, M, M]`,
where * is zero or more batch dimensions, and the matrices on the
inner-most 2 dimensions should all be symmetric positive-definite.
Its data type should be float32 or float64.
upper (bool): The flag indicating whether to return upper or lower
triangular matrices. Default: False.
Returns:
Variable: A Tensor with the same shape and data type as `x`. It represents \
triangular matrices generated by Cholesky decomposition.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
with fluid.dygraph.guard():
a = np.random.rand(3, 3)
a_t = np.transpose(a, [1, 0])
x = np.matmul(a, a_t) + 1e-03
x = fluid.dygraph.to_variable(x)
out = paddle.cholesky(x, upper=False)
print(out.numpy())
# [[1.190523 0. 0. ]
# [0.9906703 0.27676893 0. ]
# [1.25450498 0.05600871 0.06400121]]
"""
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'cholesky')
check_type(upper, 'upper', bool, 'cholesky')
helper = LayerHelper('cholesky', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='cholesky',
inputs={'X': [x]},
outputs={'Out': out},
attrs={'upper': upper})
return out
def bmm(x, y, name=None):
"""
Applies batched matrix multiplication to two tensors.
......
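As a quick usage check of the new `paddle.cholesky` API (a sketch following the dygraph example in the docstring above; not part of the patch), the returned factors should reconstruct the input:

```python
import numpy as np
import paddle
import paddle.fluid as fluid

with fluid.dygraph.guard():
    a = np.random.rand(3, 3)
    x_np = np.matmul(a, a.T) + 1e-03      # symmetric positive-definite
    x = fluid.dygraph.to_variable(x_np)
    l = paddle.cholesky(x, upper=False)   # lower-triangular factor
    u = paddle.cholesky(x, upper=True)    # upper-triangular factor
    # both factors should reconstruct the input: L @ L^T == U^T @ U == x
    assert np.allclose(np.matmul(l.numpy(), l.numpy().T), x_np)
    assert np.allclose(np.matmul(u.numpy().T, u.numpy()), x_np)
```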