Unverified Commit be817719 Authored by: Z zyfncg Committed by: GitHub

【PTen】Add dot and matmul grad kernel in pten (#38713)

* refactor matmul directory in pten

* fix merge conflict

* add dot_grad kernel

* add dot_grad kernel in pten

* add matmul_grad kernel

* update the code

* delete useless code in fluid

* fix some bugs when running the matmul grad kernel

* fix merge conflict

* refactor some code

* refactor code
Parent 5b940c44
......@@ -79,6 +79,9 @@ function(kernel_library TARGET)
endif()
list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.h)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h)
list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h)
endif()
list(APPEND all_srcs ${common_srcs})
list(APPEND all_srcs ${cpu_srcs})
list(APPEND all_srcs ${gpu_srcs})
......
......@@ -1880,16 +1880,32 @@ void OperatorWithKernel::BuildPtenKernelContext(
// Otherwise, we will create new storage.
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
if (current_vector_size > start_idx + offset) {
experimental::ReMakePtenDenseTensorFromVar(
outs_vector[offset], out_def,
auto* buffer_tensor =
pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(start_idx +
offset));
offset);
if (buffer_tensor) {
experimental::ReMakePtenDenseTensorFromVar(outs_vector[offset],
out_def, buffer_tensor);
}
} else {
pt_kernel_context_->EmplaceBackOutputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(outs_vector[offset],
out_def));
}
}
// Deal with the case where some outputs are NULL when running the kernel.
// For example, the outputs of matmul_grad are dx and dy, and
// sometimes dx or dy may be NULL.
if (outs_vector.empty()) {
if (current_vector_size > start_idx) {
pt_kernel_context_->SetOutputWithoutSetRange(start_idx, {nullptr});
} else {
pt_kernel_context_->EmplaceBackOutputWithoutSetRange({nullptr});
}
end_idx = start_idx + 1;
}
pt_kernel_context_->AssignOutputRange(std::make_pair(start_idx, end_idx),
i);
}
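To make the nullable-output handling above concrete: an absent output such as matmul_grad's dx or dy still occupies one nullptr slot in the kernel context, so every argument's [start, end) output range stays aligned with the kernel signature. A minimal, self-contained sketch of that bookkeeping (a simplified stand-in, not the real pten::KernelContext API):

#include <cstdio>
#include <utility>
#include <vector>

struct Tensor {};  // stand-in for pten::DenseTensor

struct MiniKernelContext {
  std::vector<Tensor*> outputs;                   // flat output slots
  std::vector<std::pair<size_t, size_t>> ranges;  // [start, end) per argument

  void EmplaceOutput(Tensor* t) { outputs.push_back(t); }
  void AssignOutputRange(std::pair<size_t, size_t> r, size_t arg_idx) {
    if (ranges.size() <= arg_idx) ranges.resize(arg_idx + 1);
    ranges[arg_idx] = r;
  }
};

int main() {
  Tensor dx;  // dY is intentionally absent, e.g. Y does not need a gradient
  MiniKernelContext ctx;

  ctx.EmplaceOutput(&dx);      // argument 0: dX present, one slot
  ctx.AssignOutputRange({0, 1}, 0);

  ctx.EmplaceOutput(nullptr);  // argument 1: dY absent, still reserve one slot
  ctx.AssignOutputRange({1, 2}, 1);

  for (size_t i = 0; i < ctx.ranges.size(); ++i) {
    std::printf("arg %zu -> slots [%zu, %zu), present: %s\n", i,
                ctx.ranges[i].first, ctx.ranges[i].second,
                ctx.outputs[ctx.ranges[i].first] ? "yes" : "no");
  }
  return 0;
}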
......@@ -2002,9 +2018,11 @@ void OperatorWithKernel::WriteBackToOutputs(RuntimeContext* ctx) const {
range_pair.first, range_pair.second);
for (size_t j = 0; j < pten_outs.size(); ++j) {
if (pten_outs[j]) {
experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]);
}
}
}
}
} // namespace framework
......
......@@ -99,7 +99,7 @@ KernelSignatureMap& KernelSignatureMap::Instance() {
const auto& op_type = pair.first;
const auto* op_proto = pair.second.proto_;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
op_proto != nullptr) {
op_proto) {
KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
auto success = kernel_signature_map_->map_
......
......@@ -338,19 +338,41 @@ static void BuildDygraphPtenKernelContext(
for (size_t i = 0; i < output_names.size(); ++i) {
auto& out_def = output_defs.at(i);
auto& outs_vector = outs.at(output_names[i]);
size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second);
size_t end_idx = start_idx + outs_vector.size();
auto current_vector_size = kernel_ctx->OutputsSize();
auto iter = outs.find(output_names[i]);
if (iter == outs.end()) {
if (current_vector_size > start_idx) {
kernel_ctx->SetOutputWithoutSetRange(start_idx, {nullptr});
} else {
kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
}
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1),
i);
continue;
}
auto& outs_vector = iter->second;
size_t end_idx = start_idx + outs_vector.size();
// If the memory needed is less than the current memory allocated, we will
// reuse the current memory by using ReMakePtenDenseTensorFromVar.
// Otherwise, we will create new storage.
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
if (current_vector_size > start_idx + offset) {
auto* buffer_tensor =
kernel_ctx->MutableOutputAt<pten::DenseTensor>(start_idx + offset);
if (buffer_tensor) {
experimental::ReMakePtenDenseTensorFromVar(
outs_vector[offset]->MutableVar(), out_def,
kernel_ctx->MutableOutputAt<pten::DenseTensor>(start_idx + offset));
outs_vector[offset]->MutableVar(), out_def, buffer_tensor);
} else {
kernel_ctx->SetOutputWithoutSetRange(
start_idx + offset,
experimental::MakePtenTensorBaseFromVar(
outs_vector[offset]->MutableVar(), out_def));
}
} else {
kernel_ctx->EmplaceBackOutputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(
......@@ -465,7 +487,9 @@ static void WriteBackToOutputs(
auto& output_names = std::get<2>(pt_kernel_signature.args);
for (size_t i = 0; i < output_names.size(); ++i) {
auto& outs_vector = outs.at(output_names[i]);
auto iter = outs.find(output_names[i]);
if (iter != outs.end()) {
auto& outs_vector = iter->second;
auto& range_pair = kernel_ctx->OutputRangeAt(i);
auto pten_outs = kernel_ctx->MutableOutputBetween<pten::DenseTensor>(
......@@ -476,6 +500,7 @@ static void WriteBackToOutputs(
outs_vector[j]->MutableVar());
}
}
}
}
template <typename VarType>
......@@ -529,6 +554,7 @@ static void PreparedOpRunImpl(
template <typename VarType>
static void PreparedOpRunPtImpl(
const framework::OperatorBase& op,
const framework::OpKernelType& kernel_type,
const framework::KernelSignature& pt_kernel_signature,
const pten::Kernel& pt_kernel, pten::KernelContext* pt_kernel_context,
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
......@@ -558,7 +584,9 @@ static void PreparedOpRunPtImpl(
pt_kernel_context->ClearData();
// TODO(chenweihang): add debug flags later
// TODO(chenweihang): deal with complex cases later
if (framework::IsComplexType(kernel_type.data_type_)) {
HandleComplexGradToRealGrad<VarType>(outs);
}
}
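The guard added above only calls HandleComplexGradToRealGrad when the kernel data type is complex. As a rough illustration of the idea (a hypothetical helper, not Paddle's implementation), converting a complex gradient back to a real one amounts to keeping only the real part of each element:

#include <complex>
#include <cstdio>
#include <vector>

// Hypothetical helper: keep only the real part of each gradient element.
std::vector<float> ComplexGradToReal(const std::vector<std::complex<float>>& g) {
  std::vector<float> out(g.size());
  for (size_t i = 0; i < g.size(); ++i) out[i] = g[i].real();
  return out;
}

int main() {
  std::vector<std::complex<float>> grad = {{1.f, 2.f}, {3.f, -4.f}};
  auto real_grad = ComplexGradToReal(grad);
  std::printf("%.1f %.1f\n", real_grad[0], real_grad[1]);  // prints 1.0 3.0
  return 0;
}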
void PreparedOp::Run(const NameVarMap<VarBase>& ins,
......@@ -566,9 +594,9 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) {
PreparedOpRunPtImpl<VarBase>(op_, pt_kernel_signature_, pt_kernel_,
pt_kernel_context_, dev_ctx_, ins, outs, attrs,
default_attrs);
PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, pt_kernel_signature_,
pt_kernel_, pt_kernel_context_, dev_ctx_, ins,
outs, attrs, default_attrs);
} else {
PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
outs, attrs, default_attrs);
......@@ -580,9 +608,9 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) {
PreparedOpRunPtImpl<VariableWrapper>(op_, pt_kernel_signature_, pt_kernel_,
pt_kernel_context_, dev_ctx_, ins,
outs, attrs, default_attrs);
PreparedOpRunPtImpl<VariableWrapper>(
op_, kernel_type_, pt_kernel_signature_, pt_kernel_, pt_kernel_context_,
dev_ctx_, ins, outs, attrs, default_attrs);
} else {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
ins, outs, attrs, default_attrs);
......
......@@ -39,7 +39,7 @@ class ConjKernel : public framework::OpKernel<T> {
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
// call new kernel
pten::Conj<T>(dev_ctx, *pt_x.get(), pt_out.get());
pten::ConjKernel<T>(dev_ctx, *pt_x.get(), pt_out.get());
}
};
......
......@@ -117,6 +117,13 @@ class DotGradOp : public framework::OperatorWithKernel {
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
return framework::KernelSignature(
"dot_grad", {"X", "Y", framework::GradVarName("Out")}, {},
{framework::GradVarName("X"), framework::GradVarName("Y")});
}
};
template <typename T>
......
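GetExpectedPtenKernelArgs returns a framework::KernelSignature that maps the fluid op's input, attribute, and output names onto the argument order of the new pten kernel. A simplified stand-in (not the real KernelSignature type, and assuming framework::GradVarName simply appends the "@GRAD" suffix):

#include <cstdio>
#include <string>
#include <vector>

// Simplified stand-in: kernel name plus the op's input, attribute, and
// output names, in the order the pten kernel expects them.
struct MiniKernelSignature {
  std::string name;
  std::vector<std::string> inputs;
  std::vector<std::string> attrs;
  std::vector<std::string> outputs;
};

int main() {
  // Assumption: framework::GradVarName("Out") expands to "Out@GRAD".
  MiniKernelSignature dot_grad{
      "dot_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"}};
  std::printf("%s: %zu inputs, %zu attrs, %zu outputs\n", dot_grad.name.c_str(),
              dot_grad.inputs.size(), dot_grad.attrs.size(),
              dot_grad.outputs.size());
  return 0;
}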
......@@ -22,217 +22,14 @@
// only can include the headers in paddle/pten/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/linalg.h"
#include "paddle/pten/kernels/dot_grad_kernel.h"
#include "paddle/pten/kernels/dot_kernel.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, typename R>
struct P {
void operator()(T a, R b);
};
template <typename DeviceContext, typename T, typename Enabel = void>
struct DotGradFunction {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx);
};
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::EnableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
tensor_dx->data<T>());
for_range(functor);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
dx.device(dev) = dx * dout.broadcast(size);
}
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
tensor_dy->data<T>());
for_range(functor);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
dy.device(dev) = dy * dout.broadcast(size);
}
} else {
auto dout = framework::EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = framework::EigenMatrix<T>::From(*tensor_y);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_y->numel());
math::ConjFunctor<T> functor(tensor_y->data<T>(), tensor_y->numel(),
tensor_dx->data<T>());
for_range(functor);
auto dx = framework::EigenMatrix<T>::From(*tensor_dx);
dx.device(dev) = dx * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = framework::EigenMatrix<T>::From(*tensor_x);
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_x->numel());
math::ConjFunctor<T> functor(tensor_x->data<T>(), tensor_x->numel(),
tensor_dy->data<T>());
for_range(functor);
auto dy = framework::EigenMatrix<T>::From(*tensor_dy);
dy.device(dev) = dy * dout.broadcast(size);
}
}
#else
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_y = tensor_y->data<T>();
const framework::DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s];
}
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_x = tensor_x->data<T>();
const framework::DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s];
}
}
#endif
}
};
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext, T, math::DisableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
const Tensor* tensor_dout, Tensor* tensor_dx,
Tensor* tensor_dy,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
dy.device(dev) = x * dout.broadcast(size);
}
} else {
auto dout = framework::EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto y = framework::EigenMatrix<T>::From(*tensor_y);
auto dx = framework::EigenMatrix<T>::From(*tensor_dx);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto x = framework::EigenMatrix<T>::From(*tensor_x);
auto dy = framework::EigenMatrix<T>::From(*tensor_dy);
auto& dev =
*ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size);
}
}
#else
auto const *x = tensor_x->data<T>(), *y = tensor_y->data<T>(),
*dz = tensor_dout->data<T>();
auto&& d = tensor_x->dims();
auto const N = tensor_x->numel();
auto const B = d[d.size() - 1];
if (tensor_dx) {
auto* dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
for (auto j = 0; j < N / B; ++j) {
auto const ss = dz[j];
for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss;
}
}
if (tensor_dy) {
auto* dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
for (auto j = 0; j < N / B; ++j) {
auto const ss = dz[j];
for (auto i = 0; i < B; i++) *dy++ = *x++ * ss;
}
}
#endif
}
};
// See Note [ Why still keep the original kernel implementation? ]
template <typename DeviceContext, typename T>
class DotKernel : public framework::OpKernel<T> {
......@@ -249,7 +46,7 @@ class DotKernel : public framework::OpKernel<T> {
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
// call new kernel
pten::Dot<T>(dev_ctx, *pt_x.get(), *pt_y.get(), pt_out.get());
pten::DotKernel<T>(dev_ctx, *pt_x.get(), *pt_y.get(), pt_out.get());
}
};
......@@ -266,8 +63,17 @@ class DotGradKernel : public framework::OpKernel<T> {
if (tensor_dx) tensor_dx->mutable_data<T>(ctx.GetPlace());
if (tensor_dy) tensor_dy->mutable_data<T>(ctx.GetPlace());
DotGradFunction<DeviceContext, T>()(tensor_x, tensor_y, tensor_dout,
tensor_dx, tensor_dy, ctx);
auto pt_x = paddle::experimental::MakePtenDenseTensor(*tensor_x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*tensor_y);
auto pt_dout = paddle::experimental::MakePtenDenseTensor(*tensor_dout);
auto pt_dx = paddle::experimental::MakePtenDenseTensor(*tensor_dx);
auto pt_dy = paddle::experimental::MakePtenDenseTensor(*tensor_dy);
auto& dev_ctx = ctx.device_context<DeviceContext>();
// call new kernel
pten::DotGradKernel<T>(dev_ctx, *pt_x, *pt_y, *pt_dout, pt_dx.get(),
pt_dy.get());
}
};
......
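With the change above, DotGradKernel only adapts framework::Tensor to pten::DenseTensor and forwards to pten::DotGradKernel. For reference, the gradient the removed DotGradFunction computed for real dtypes is dX[j, :] = dOut[j] * Y[j, :] and dY[j, :] = dOut[j] * X[j, :] per batched row (with an extra conjugate of Y and X for complex dtypes). A standalone numeric check, independent of the Paddle code:

#include <cstdio>
#include <vector>

int main() {
  const int rows = 2, cols = 3;
  std::vector<float> x = {1, 2, 3, 4, 5, 6};
  std::vector<float> y = {6, 5, 4, 3, 2, 1};
  std::vector<float> dout = {0.5f, 2.0f};  // one gradient per row-wise dot
  std::vector<float> dx(rows * cols), dy(rows * cols);

  // dX[j, :] = dOut[j] * Y[j, :]  and  dY[j, :] = dOut[j] * X[j, :]
  for (int j = 0; j < rows; ++j) {
    for (int i = 0; i < cols; ++i) {
      dx[j * cols + i] = dout[j] * y[j * cols + i];
      dy[j * cols + i] = dout[j] * x[j * cols + i];
    }
  }
  std::printf("dx[0] = %.1f, dy[5] = %.1f\n", dx[0], dy[5]);  // 3.0, 12.0
  return 0;
}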
......@@ -225,6 +225,10 @@ class Blas {
const framework::Tensor& mat_b, const MatDescriptor& dim_b,
T alpha, framework::Tensor* mat_out, T beta) const;
template <typename T>
void MatMul(const T* mat_a, const MatDescriptor& dim_a, const T* mat_b,
const MatDescriptor& dim_b, T alpha, T* mat_out, T beta) const;
template <typename T>
void VINV(int n, const T* a, T* y) const;
......
......@@ -1249,6 +1249,15 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
const framework::Tensor &mat_b,
const MatDescriptor &dim_b, T alpha,
framework::Tensor *mat_out, T beta) const {
MatMul(mat_a.data<T>(), dim_a, mat_b.data<T>(), dim_b, alpha,
mat_out->data<T>(), beta);
}
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::MatMul(const T *mat_a, const MatDescriptor &dim_a,
const T *mat_b, const MatDescriptor &dim_b,
T alpha, T *mat_out, T beta) const {
PADDLE_ENFORCE_EQ(
dim_a.width_, dim_b.height_,
platform::errors::InvalidArgument(
......@@ -1261,8 +1270,7 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
dim_a.width_, alpha, mat_a.data<T>(),
mat_b.data<T>(), beta, mat_out->data<T>());
dim_a.width_, alpha, mat_a, mat_b, beta, mat_out);
} else {
PADDLE_ENFORCE_EQ(
dim_a.batch_size_ == dim_b.batch_size_ || dim_a.batch_size_ == 0 ||
......@@ -1273,8 +1281,8 @@ void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
"But got dim_a.batch_size = %d, dim_b.batch_size = %d.",
dim_a.batch_size_, dim_b.batch_size_));
this->template BatchedGEMM<T>(
transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, mat_a,
mat_b, beta, mat_out,
dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
dim_a.stride_, dim_b.stride_);
}
......
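The new raw-pointer Blas::MatMul overload above forwards to GEMM / BatchedGEMM without touching framework::Tensor, presumably so callers that do not hold framework::Tensor objects (such as the new pten matmul kernels) can reuse it. As a reference for what a single row-major GEMM call computes, here is a naive, unoptimized sketch (not the actual BLAS routine):

#include <cstdio>
#include <vector>

// C = alpha * op(A) * op(B) + beta * C, with op(X) = X or X^T depending on
// the transpose flags; op(A) is MxK, op(B) is KxN, C is MxN, all row-major.
void NaiveGemm(bool trans_a, bool trans_b, int M, int N, int K, float alpha,
               const float* A, const float* B, float beta, float* C) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        float a = trans_a ? A[k * M + m] : A[m * K + k];
        float b = trans_b ? B[n * K + k] : B[k * N + n];
        acc += a * b;
      }
      C[m * N + n] = alpha * acc + beta * C[m * N + n];
    }
  }
}

int main() {
  std::vector<float> a = {1, 2, 3, 4};  // 2x2
  std::vector<float> b = {5, 6, 7, 8};  // 2x2
  std::vector<float> c(4, 0.f);
  NaiveGemm(false, false, 2, 2, 2, 1.f, a.data(), b.data(), 0.f, c.data());
  std::printf("%.0f %.0f %.0f %.0f\n", c[0], c[1], c[2], c[3]);  // 19 22 43 50
  return 0;
}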
......@@ -389,6 +389,14 @@ class MatMulV2OpGrad : public framework::OperatorWithKernel {
tensor.place(), tensor.layout());
}
}
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
return framework::KernelSignature(
"matmul_grad", {"X", "Y", framework::GradVarName("Out")},
{"trans_x", "trans_y"},
{framework::GradVarName("X"), framework::GradVarName("Y")});
}
};
template <typename T>
......@@ -431,6 +439,13 @@ class MatMulV2OpDoubleGrad : public framework::OperatorWithKernel {
context->ShareDim("DOut", "DDOut");
}
}
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
return framework::KernelSignature(
"matmul_double_grad", {"X", "Y", "DOut", "DDX", "DDY"},
{"trans_x", "trans_y"}, {"DX", "DY", "DDOut"});
}
};
template <typename T>
......@@ -500,6 +515,15 @@ class MatMulV2OpTripleGrad : public framework::OperatorWithKernel {
context->ShareDim("Y", "D_DDY_out");
}
}
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
return framework::KernelSignature(
"matmul_triple_grad",
{"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"},
{"trans_x", "trans_y"},
{"D_X_out", "D_Y_out", "D_DOut_out", "D_DDX_out", "D_DDY_out"});
}
};
template <typename T>
......
......@@ -28,6 +28,7 @@ limitations under the License. */
// only can include the headers in paddle/pten/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/kernels/matmul_grad_kernel.h"
#include "paddle/pten/kernels/matmul_kernel.h"
#if defined(__NVCC__) || defined(__HIPCC__)
......@@ -39,333 +40,6 @@ namespace operators {
using framework::Tensor;
template <typename DeviceContext, typename T>
void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
const std::vector<int>& reduce_dims,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
auto stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*input, output, kps::IdentityFunctor<T>(), reduce_dims, stream);
#else
ReduceKernelFunctor<DeviceContext, T, ops::SumFunctor>(
input, output, reduce_dims, true, false, ctx)
.template apply<T>();
#endif
}
static void GetBroadcastFromDims(const int x_ndim, const std::int64_t* x_dims,
const int y_ndim, const std::int64_t* y_dims,
std::int64_t* x_bd_dims,
std::int64_t* y_bd_dims,
std::int64_t* out_bd_dims) {
const int ndim = (std::max)(x_ndim, y_ndim);
std::fill(x_bd_dims, x_bd_dims + ndim - x_ndim, 1);
std::fill(y_bd_dims, y_bd_dims + ndim - y_ndim, 1);
std::copy(x_dims, x_dims + x_ndim, x_bd_dims + ndim - x_ndim);
std::copy(y_dims, y_dims + y_ndim, y_bd_dims + ndim - y_ndim);
for (int i = 0; i < ndim; ++i) {
PADDLE_ENFORCE_EQ(
x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] <= 1 || y_bd_dims[i] <= 1,
true,
platform::errors::InvalidArgument(
"Input(X) and Input(Y) has error dim."
"X_broadcast's shape[%s] must be equal to Y_broadcast's shape[%s],"
"or X_broadcast's shape[%s] <= 1, or Y_broadcast's shape[%s] <= 1,"
"But received X_broadcast's shape[%s] = [%s]"
"received Y_broadcast's shape[%s] = [%s]",
i, i, i, i, i, x_bd_dims[i], i, y_bd_dims[i]));
if (x_bd_dims[i] == 0 || y_bd_dims[i] == 0) {
out_bd_dims[i] = 0;
} else {
out_bd_dims[i] = (std::max)(x_bd_dims[i], y_bd_dims[i]);
}
}
}
static int64_t GetIndexMessage(const int n, const int64_t* dims,
const int64_t* index) {
int64_t sum = 0;
for (int i = 0; i < n; ++i) {
if (dims[i] > 1) {
sum = sum * dims[i] + index[i];
}
}
return sum;
}
static void IndexIncreaseFromDims(const int ndim, const int64_t* dims,
int64_t* index) {
for (int i = ndim - 1; i >= 0; --i) {
++index[i];
if (index[i] >= dims[i]) {
index[i] -= dims[i];
} else {
break;
}
}
}
template <typename DeviceContext, typename T>
void MatMulFunction(const Tensor* X, const Tensor* Y,
const std::vector<std::int64_t>& x_dims,
const std::vector<std::int64_t>& y_dims, Tensor* Out,
bool trans_x, bool trans_y,
const paddle::framework::ExecutionContext& ctx,
bool flag = false) {
const int x_ndim = x_dims.size();
const int y_ndim = y_dims.size();
// Get data ptr
const T* x_data = X->data<T>();
const T* y_data = Y->data<T>();
if (x_ndim == 1 && y_ndim == 1) {
PADDLE_ENFORCE_EQ(
X->numel(), Y->numel(),
platform::errors::InvalidArgument(
"X's numbers must be equal to Y's numbers,"
"when X/Y's dims =1. But received X has [%d] elements,"
"received Y has [%d] elements",
X->numel(), Y->numel()));
VLOG(3) << "MatMul's case 1";
Out->Resize({1});
Out->mutable_data<T>(ctx.GetPlace());
auto out_eigen = framework::EigenScalar<T>::From(*Out);
auto x_eigen = framework::EigenVector<T>::Flatten(*X);
auto y_eigen = framework::EigenVector<T>::Flatten(*Y);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
if (flag) {
out_eigen.device(dev) = (x_eigen * y_eigen).sum() + out_eigen;
} else {
out_eigen.device(dev) = (x_eigen * y_eigen).sum();
}
return;
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
if (x_ndim == 1) {
const int N = X->numel();
if (trans_y) {
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], N,
platform::errors::InvalidArgument(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 1, N, y_ndim - 1, y_dims[y_ndim - 1]));
} else {
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], N,
platform::errors::InvalidArgument(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 2, N, y_ndim - 2, y_dims[y_ndim - 2]));
}
std::vector<std::int64_t> out_dims(y_ndim - 1);
if (trans_y) {
std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin());
} else {
std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin());
out_dims.back() = y_dims.back();
}
Out->Resize(framework::make_ddim(out_dims));
Out->mutable_data<T>(ctx.GetPlace());
if (trans_y) {
const int M = Y->numel() / N;
VLOG(3) << "MatMul's case 2";
blas.GEMV(false, M, N, static_cast<T>(1), y_data, x_data,
static_cast<T>(flag), Out->data<T>());
} else {
const int M = y_dims[y_ndim - 1];
const int batch_size = Y->numel() / (M * N);
if (batch_size == 1) {
VLOG(3) << "MatMul's case 3";
blas.GEMV(true, N, M, static_cast<T>(1), y_data, x_data,
static_cast<T>(flag), Out->data<T>());
} else {
VLOG(3) << "MatMul's case 4";
blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast<T>(1),
y_data, x_data, static_cast<T>(flag), Out->data<T>(),
batch_size, M * N, 0);
}
}
return;
}
if (y_ndim == 1) {
const int N = Y->numel();
if (trans_x) {
PADDLE_ENFORCE_EQ(x_dims[x_ndim - 2], N,
platform::errors::InvalidArgument(
"Input(X) has error dim."
"X'dims[%d] must be equal to %d"
"But received X'dims[%d] is %d",
x_ndim - 2, N, x_ndim - 2, x_dims[x_ndim - 2]));
} else {
PADDLE_ENFORCE_EQ(x_dims[x_ndim - 1], N,
platform::errors::InvalidArgument(
"Input(X) has error dim."
"X'dims[%d] must be equal to %d"
"But received X'dims[%d] is %d",
x_ndim - 1, N, x_ndim - 1, x_dims[x_ndim - 1]));
}
std::vector<std::int64_t> out_dims(x_ndim - 1);
if (trans_x) {
std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin());
out_dims.back() = x_dims.back();
} else {
std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin());
}
Out->Resize(framework::make_ddim(out_dims));
Out->mutable_data<T>(ctx.GetPlace());
if (trans_x) {
const int M = x_dims[x_ndim - 1];
const int batch_size = X->numel() / (M * N);
if (batch_size == 1) {
VLOG(3) << "MatMul's case 5";
blas.GEMV(true, N, M, static_cast<T>(1), x_data, y_data,
static_cast<T>(flag), Out->data<T>());
} else {
VLOG(3) << "MatMul's case 6";
blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast<T>(1),
x_data, y_data, static_cast<T>(flag), Out->data<T>(),
batch_size, M * N, 0);
}
} else {
const int M = X->numel() / N;
VLOG(3) << "MatMul's case 7";
blas.GEMV(false, M, N, static_cast<T>(1), x_data, y_data,
static_cast<T>(flag), Out->data<T>());
}
return;
}
const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2];
const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
if (trans_y) {
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K,
platform::errors::InvalidArgument(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 1, K, y_ndim - 1, y_dims[y_ndim - 1]));
} else {
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K,
platform::errors::InvalidArgument(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 2, K, y_ndim - 2, y_dims[y_ndim - 2]));
}
const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1];
const int ndim = (std::max)(x_ndim, y_ndim);
std::vector<std::int64_t> x_broadcast_dims(ndim);
std::vector<std::int64_t> y_broadcast_dims(ndim);
std::vector<std::int64_t> out_broadcast_dims(ndim);
GetBroadcastFromDims(x_ndim - 2, x_dims.data(), y_ndim - 2, y_dims.data(),
x_broadcast_dims.data(), y_broadcast_dims.data(),
out_broadcast_dims.data());
out_broadcast_dims[ndim - 2] = M;
out_broadcast_dims[ndim - 1] = N;
Out->Resize(framework::make_ddim(out_broadcast_dims));
Out->mutable_data<T>(ctx.GetPlace());
const int batch_dim = ndim - 2;
// compute broadcast info for the batch dimensions
const bool is_broadcast_dims = !std::equal(
x_broadcast_dims.cbegin(), x_broadcast_dims.cbegin() + batch_dim,
y_broadcast_dims.cbegin());
const std::int64_t x_batch_size = std::accumulate(
x_broadcast_dims.cbegin(), x_broadcast_dims.cbegin() + batch_dim, 1LL,
std::multiplies<std::int64_t>());
const std::int64_t y_batch_size = std::accumulate(
y_broadcast_dims.cbegin(), y_broadcast_dims.cbegin() + batch_dim, 1LL,
std::multiplies<std::int64_t>());
const std::int64_t out_batch_size = std::accumulate(
out_broadcast_dims.cbegin(), out_broadcast_dims.cbegin() + batch_dim, 1LL,
std::multiplies<std::int64_t>());
if (out_batch_size == 0) return;
if (x_batch_size == 1 && y_batch_size == 1) {
VLOG(3) << "MatMul's case 8";
blas.GEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K, static_cast<T>(1),
x_data, y_data, static_cast<T>(flag), Out->data<T>());
} else if (x_batch_size == 1) {
if (M == 1 && trans_y) {
VLOG(3) << "MatMul's case 9";
blas.GEMV(false, y_batch_size * N, K, static_cast<T>(1), y_data, x_data,
static_cast<T>(flag), Out->data<T>());
} else {
VLOG(3) << "MatMul's case 10";
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K,
static_cast<T>(1), x_data, y_data, static_cast<T>(flag),
Out->data<T>(), out_batch_size, 0, K * N);
}
} else if (y_batch_size == 1) {
if (!trans_x) {
VLOG(3) << "MatMul's case 11";
blas.GEMM(CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans,
x_batch_size * M, N, K, static_cast<T>(1), x_data, y_data,
static_cast<T>(flag), Out->data<T>());
} else {
VLOG(3) << "MatMul's case 12";
blas.BatchedGEMM(CblasTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K,
static_cast<T>(1), x_data, y_data, static_cast<T>(flag),
Out->data<T>(), out_batch_size, M * K, 0);
}
} else if (!is_broadcast_dims) {
VLOG(3) << "MatMul's case 13";
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K,
static_cast<T>(1), x_data, y_data, static_cast<T>(flag),
Out->data<T>(), out_batch_size, M * K, K * N);
} else {
// in this case, strided GEMM can't be used
std::vector<const T*> x_ptr(out_batch_size);
std::vector<const T*> y_ptr(out_batch_size);
std::vector<T*> out_ptr(out_batch_size);
std::vector<std::int64_t> index(batch_dim, 0);
for (std::int64_t i = 0; i < out_batch_size; ++i) {
// using the index to get offset
const std::int64_t x_index =
GetIndexMessage(batch_dim, x_broadcast_dims.data(), index.data());
const std::int64_t y_index =
GetIndexMessage(batch_dim, y_broadcast_dims.data(), index.data());
x_ptr[i] = x_data + x_index * M * K;
y_ptr[i] = y_data + y_index * K * N;
out_ptr[i] = Out->data<T>() + i * M * N;
IndexIncreaseFromDims(batch_dim, out_broadcast_dims.data(), index.data());
}
VLOG(3) << "MatMul's case 14";
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K,
static_cast<T>(1), x_ptr.data(), y_ptr.data(),
static_cast<T>(flag), out_ptr.data(), out_batch_size);
}
}
template <typename DeviceContext, typename T>
void MatMulFunction(const Tensor* X, const Tensor* Y, Tensor* Out, bool trans_x,
bool trans_y,
const paddle::framework::ExecutionContext& ctx,
bool flag = false) {
const std::vector<std::int64_t> x_dims = vectorize(X->dims());
const std::vector<std::int64_t> y_dims = vectorize(Y->dims());
MatMulFunction<DeviceContext, T>(X, Y, x_dims, y_dims, Out, trans_x, trans_y,
ctx, flag);
}
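The broadcast branch of MatMulFunction above uses GetIndexMessage and IndexIncreaseFromDims to walk the output batch dimensions and map each output batch back to the (possibly broadcast) X and Y batches. A standalone sketch of that bookkeeping with simplified helpers (not the original functions):

#include <cstdio>
#include <vector>

// Flatten a multi-index against the given batch dims, skipping dims of size
// 1 so a broadcast input keeps pointing at its single stored batch.
static long Flatten(const std::vector<long>& dims, const std::vector<long>& idx) {
  long off = 0;
  for (size_t i = 0; i < dims.size(); ++i)
    if (dims[i] > 1) off = off * dims[i] + idx[i];
  return off;
}

// Odometer-style increment of the multi-index over the output batch dims.
static void Increase(const std::vector<long>& dims, std::vector<long>* idx) {
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    if (++(*idx)[i] < dims[i]) break;
    (*idx)[i] = 0;
  }
}

int main() {
  std::vector<long> x_bd = {2, 1};  // X is broadcast along batch dim 1
  std::vector<long> y_bd = {1, 3};  // Y is broadcast along batch dim 0
  std::vector<long> out_bd = {2, 3};
  std::vector<long> idx(out_bd.size(), 0);
  for (int i = 0; i < 2 * 3; ++i) {
    std::printf("out batch %d -> x batch %ld, y batch %ld\n", i,
                Flatten(x_bd, idx), Flatten(y_bd, idx));
    Increase(out_bd, &idx);
  }
  return 0;
}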
template <typename DeviceContext, typename T>
class MatMulV2Kernel : public framework::OpKernel<T> {
public:
......@@ -400,26 +74,6 @@ static framework::Tensor FoldInitDims(const framework::Tensor& input) {
return output;
}
// Reshape a rank-3 tensor from P x M x N to M x (P * N).
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template <typename DeviceContext, typename T>
static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context,
const framework::Tensor& input) {
auto in_dims = input.dims();
if (in_dims.size() != 3) {
return input;
}
framework::Tensor output;
output.Resize({in_dims[1], in_dims[0], in_dims[2]});
output.mutable_data<T>(context.GetPlace());
std::vector<int> axis = {1, 0, 2};
math::Transpose<DeviceContext, T, 3> trans;
trans(context, input, &output, axis);
output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
return output;
}
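FoldHeadAndLastDims above turns a P x M x N tensor into an M x (P * N) matrix by transposing with axis order {1, 0, 2} and then reshaping, so a whole batch of matrices can be fed to a single 2-D GEMM. A tiny numeric illustration of that layout change:

#include <cstdio>
#include <vector>

int main() {
  const int P = 2, M = 2, N = 3;
  std::vector<int> in(P * M * N), out(M * P * N);
  for (int i = 0; i < P * M * N; ++i) in[i] = i;  // in[p][m][n] = p*M*N + m*N + n

  // transpose with axis order {1, 0, 2}: out[m][p][n] = in[p][m][n]
  for (int p = 0; p < P; ++p)
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n)
        out[(m * P + p) * N + n] = in[(p * M + m) * N + n];

  // out is now read as an M x (P * N) matrix
  for (int m = 0; m < M; ++m) {
    for (int c = 0; c < P * N; ++c) std::printf("%2d ", out[m * P * N + c]);
    std::printf("\n");
  }
  return 0;
}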
/**
* Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
* original x_dim is returned.
......@@ -483,1868 +137,130 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x,
}
template <typename DeviceContext, typename T>
struct ConjHelper {
explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}
HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
dst.Resize(src.dims());
dst.set_layout(src.layout());
dst.ShareDataWith(src);
return;
}
const framework::ExecutionContext& ctx_;
};
template <typename DeviceContext>
struct ConjHelper<DeviceContext, paddle::platform::complex<float>> {
explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}
HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
dst.Resize(src.dims());
auto* src_data = src.data<paddle::platform::complex<float>>();
auto* dst_data = dst.mutable_data<paddle::platform::complex<float>>(
ctx_.GetPlace(),
size_t(src.numel() * sizeof(paddle::platform::complex<float>)));
platform::ForRange<DeviceContext> for_range(
ctx_.template device_context<DeviceContext>(), src.numel());
math::ConjFunctor<paddle::platform::complex<float>> functor(
src_data, src.numel(), dst_data);
for_range(functor);
return;
}
const framework::ExecutionContext& ctx_;
};
template <typename DeviceContext>
struct ConjHelper<DeviceContext, paddle::platform::complex<double>> {
explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}
HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
dst.Resize(src.dims());
auto* src_data = src.data<paddle::platform::complex<double>>();
auto* dst_data = dst.mutable_data<paddle::platform::complex<double>>(
ctx_.GetPlace(),
size_t(src.numel() * sizeof(paddle::platform::complex<double>)));
platform::ForRange<DeviceContext> for_range(
ctx_.template device_context<DeviceContext>(), src.numel());
math::ConjFunctor<paddle::platform::complex<double>> functor(
src_data, src.numel(), dst_data);
for_range(functor);
return;
}
const framework::ExecutionContext& ctx_;
};
template <typename DeviceContext, typename T, typename Enabel = void>
struct DotDoubleGradFunction {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
Tensor* tensor_dx, Tensor* tensor_dy,
const Tensor* tensor_dout, const Tensor* tensor_ddx,
const Tensor* tensor_ddy, Tensor* tensor_ddout,
const paddle::framework::ExecutionContext& ctx);
};
template <typename DeviceContext, typename T>
struct DotDoubleGradFunction<DeviceContext, T, math::EnableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
Tensor* tensor_dx, Tensor* tensor_dy,
const Tensor* tensor_dout, const Tensor* tensor_ddx,
const Tensor* tensor_ddy, Tensor* tensor_ddout,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
framework::Tensor tensor_dout_help;
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
if (tensor_dx || tensor_dy) {
tensor_dout_help.Resize(tensor_dout->dims());
tensor_dout_help.mutable_data<T>(ctx.GetPlace());
paddle::platform::ForRange<DeviceContext> for_range(
dev_raw, tensor_dout->numel());
math::ConjFunctor<T> functor(tensor_dout->data<T>(),
tensor_dout->numel(),
tensor_dout_help.data<T>());
for_range(functor);
}
if (tensor_dx) {
auto ddy = framework::EigenVector<T>::Flatten(*tensor_ddy);
Eigen::DSizes<int, 1> size(tensor_ddy->numel());
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
auto dout = framework::EigenVector<T>::Flatten(tensor_dout_help);
dx.device(dev) = ddy * dout.broadcast(size);
}
if (tensor_dy) {
auto ddx = framework::EigenVector<T>::Flatten(*tensor_ddx);
Eigen::DSizes<int, 1> size(tensor_ddx->numel());
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
auto dout = framework::EigenVector<T>::Flatten(tensor_dout_help);
dy.device(dev) = ddx * dout.broadcast(size);
}
if (tensor_ddout) {
framework::Tensor tensor_x_help, tensor_y_help;
tensor_x_help.Resize(tensor_x->dims());
tensor_x_help.mutable_data<T>(ctx.GetPlace());
tensor_y_help.Resize(tensor_y->dims());
tensor_y_help.mutable_data<T>(ctx.GetPlace());
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
paddle::platform::ForRange<DeviceContext> for_range(dev_raw,
tensor_x->numel());
math::ConjFunctor<T> functor_x(tensor_x->data<T>(), tensor_x->numel(),
tensor_x_help.data<T>());
for_range(functor_x);
math::ConjFunctor<T> functor_y(tensor_y->data<T>(), tensor_y->numel(),
tensor_y_help.data<T>());
for_range(functor_y);
auto x = framework::EigenVector<T>::Flatten(tensor_x_help);
auto y = framework::EigenVector<T>::Flatten(tensor_y_help);
auto ddx = framework::EigenVector<T>::Flatten(*tensor_ddx);
auto ddy = framework::EigenVector<T>::Flatten(*tensor_ddy);
auto ddout = framework::EigenVector<T>::Flatten(*tensor_ddout);
ddout.device(dev) = (x * ddy + y * ddx).sum();
}
}
#else
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_ddy = tensor_ddy->data<T>();
const framework::DDim& dim = tensor_dx->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddy[i];
}
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_ddx = tensor_ddx->data<T>();
const framework::DDim& dim = tensor_dy->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddx[i];
}
}
if (tensor_ddout) {
auto* data_ddout = tensor_ddout->mutable_data<T>(ctx.GetPlace());
auto* data_x = tensor_x->data<T>();
auto* data_y = tensor_y->data<T>();
auto* data_ddx = tensor_ddx->data<T>();
auto* data_ddy = tensor_ddy->data<T>();
const framework::DDim& dim = tensor_dy->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
bool new_s = false;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) {
++s;
new_s = true;
}
if (new_s) {
data_ddout[s] = T(data_x[i].real, -data_x[i].imag) * data_ddy[i] +
T(data_y[i].real, -data_y[i].imag) * data_ddx[i];
} else {
data_ddout[s] += T(data_x[i].real, -data_x[i].imag) * data_ddy[i] +
T(data_y[i].real, -data_y[i].imag) * data_ddx[i];
}
new_s = false;
}
}
#endif
}
};
template <typename DeviceContext, typename T>
struct DotDoubleGradFunction<DeviceContext, T, math::DisableComplex<T>> {
void operator()(const Tensor* tensor_x, const Tensor* tensor_y,
Tensor* tensor_dx, Tensor* tensor_dy,
const Tensor* tensor_dout, const Tensor* tensor_ddx,
const Tensor* tensor_ddy, Tensor* tensor_ddout,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
auto dout = framework::EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>(ctx.GetPlace());
auto ddy = framework::EigenVector<T>::Flatten(*tensor_ddy);
Eigen::DSizes<int, 1> size(tensor_ddy->numel());
auto dx = framework::EigenVector<T>::Flatten(*tensor_dx);
dx.device(dev) = ddy * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>(ctx.GetPlace());
auto ddx = framework::EigenVector<T>::Flatten(*tensor_ddx);
Eigen::DSizes<int, 1> size(tensor_ddx->numel());
auto dy = framework::EigenVector<T>::Flatten(*tensor_dy);
dy.device(dev) = ddx * dout.broadcast(size);
}
if (tensor_ddout) {
tensor_ddout->mutable_data<T>(ctx.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*tensor_x);
auto y = framework::EigenVector<T>::Flatten(*tensor_y);
auto ddx = framework::EigenVector<T>::Flatten(*tensor_ddx);
auto ddy = framework::EigenVector<T>::Flatten(*tensor_ddy);
auto ddout = framework::EigenVector<T>::Flatten(*tensor_ddout);
ddout.device(dev) = (x * ddy + y * ddx).sum();
}
}
#else
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>(ctx.GetPlace());
const auto* data_ddy = tensor_ddy->data<T>();
const framework::DDim& dim = tensor_dx->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = data_dout[s] * data_ddy[i];
}
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>(ctx.GetPlace());
const auto* data_ddx = tensor_ddx->data<T>();
const framework::DDim& dim = tensor_dy->dims();
size_t N = static_cast<size_t>(framework::product(dim));
class MatMulV2GradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
bool transpose_x = ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.Attr<bool>("trans_y");
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto step = dim[dim.size() - 1];
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = data_dout[s] * data_ddx[i];
}
}
if (dx) dx->mutable_data<T>(ctx.GetPlace());
if (dy) dy->mutable_data<T>(ctx.GetPlace());
if (tensor_ddout) {
auto* data_ddout = tensor_ddout->mutable_data<T>(ctx.GetPlace());
auto* data_x = tensor_x->data<T>();
auto* data_y = tensor_y->data<T>();
auto* data_ddx = tensor_ddx->data<T>();
auto* data_ddy = tensor_ddy->data<T>();
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout);
auto pt_dx = dx ? paddle::experimental::MakePtenDenseTensor(*dx)
: std::unique_ptr<pten::DenseTensor>(nullptr);
auto pt_dy = dy ? paddle::experimental::MakePtenDenseTensor(*dy)
: std::unique_ptr<pten::DenseTensor>(nullptr);
const framework::DDim& dim = tensor_dy->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
bool new_s = false;
auto& dev_ctx = ctx.device_context<DeviceContext>();
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) {
++s;
new_s = true;
}
if (new_s) {
data_ddout[s] = data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i];
} else {
data_ddout[s] += data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i];
}
new_s = false;
}
}
#endif
// call new kernel
pten::MatmulGradKernel<T>(dev_ctx, *pt_x, *pt_y, *pt_dout, transpose_x,
transpose_y, pt_dx.get(), pt_dy.get());
}
};
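MatMulV2GradKernel now wraps its tensors and delegates to pten::MatmulGradKernel. For the non-transposed case, the gradients it must produce are the standard ones, dX = dOut * Y^T and dY = X^T * dOut, with X of shape MxK, Y of shape KxN and dOut of shape MxN. A small self-contained numeric check (reference helper only, not Paddle code):

#include <cstdio>
#include <vector>

// Row-major reference matmul with optional transposes of the stored operands.
static void MatMulRef(const std::vector<float>& A, int ra, int ca, bool ta,
                      const std::vector<float>& B, int rb, int cb, bool tb,
                      std::vector<float>* C) {
  int M = ta ? ca : ra, K = ta ? ra : ca, N = tb ? rb : cb;
  C->assign(M * N, 0.f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k) {
        float a = ta ? A[k * ca + m] : A[m * ca + k];
        float b = tb ? B[n * cb + k] : B[k * cb + n];
        (*C)[m * N + n] += a * b;
      }
}

int main() {
  // X: 2x3, Y: 3x2, dOut: 2x2
  std::vector<float> X = {1, 2, 3, 4, 5, 6};
  std::vector<float> Y = {1, 0, 0, 1, 1, 1};
  std::vector<float> dOut = {1, 1, 1, 1};
  std::vector<float> dX, dY;
  MatMulRef(dOut, 2, 2, false, Y, 3, 2, true, &dX);  // dX = dOut * Y^T (2x3)
  MatMulRef(X, 2, 3, true, dOut, 2, 2, false, &dY);  // dY = X^T * dOut (3x2)
  std::printf("dX row 0: %.0f %.0f %.0f\n", dX[0], dX[1], dX[2]);  // 1 1 2
  std::printf("dY row 0: %.0f %.0f\n", dY[0], dY[1]);              // 5 5
  return 0;
}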
template <typename DeviceContext, typename T, typename Enabel = void>
struct DotTripleGradFunction {
void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y,
const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy,
const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy,
const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout,
Tensor* out_tensor_d_x, Tensor* out_tensor_d_y,
Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx,
Tensor* out_tensor_d_ddy,
const paddle::framework::ExecutionContext& ctx);
};
// TODO(wuweilong): enable this function when the unittest framework for multi
// grad is ok (dtype: complex64 or complex128).
template <typename DeviceContext, typename T>
struct DotTripleGradFunction<DeviceContext, T, math::EnableComplex<T>> {
void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y,
const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy,
const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy,
const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout,
Tensor* out_tensor_d_x, Tensor* out_tensor_d_y,
Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx,
Tensor* out_tensor_d_ddy,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == in_tensor_d_ddout->dims().size()) {
framework::Tensor in_tensor_d_ddout_help;
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
if (out_tensor_d_x || out_tensor_d_y) {
in_tensor_d_ddout_help.Resize(in_tensor_d_ddout->dims());
in_tensor_d_ddout_help.mutable_data<T>(ctx.GetPlace());
paddle::platform::ForRange<DeviceContext> for_range(
dev_raw, in_tensor_d_ddout->numel());
math::ConjFunctor<T> functor(in_tensor_d_ddout->data<T>(),
in_tensor_d_ddout->numel(),
in_tensor_d_ddout_help.data<T>());
for_range(functor);
}
if (out_tensor_d_x) {
auto ddy = framework::EigenVector<T>::Flatten(*in_tensor_ddy);
Eigen::DSizes<int, 1> size(in_tensor_ddy->numel());
auto d_x = framework::EigenVector<T>::Flatten(*out_tensor_d_x);
auto d_ddout =
framework::EigenVector<T>::Flatten(in_tensor_d_ddout_help);
d_x.device(dev) = ddy * d_ddout.broadcast(size);
}
if (out_tensor_d_y) {
auto ddx = framework::EigenVector<T>::Flatten(*in_tensor_ddx);
Eigen::DSizes<int, 1> size(in_tensor_ddx->numel());
auto d_y = framework::EigenVector<T>::Flatten(*out_tensor_d_y);
auto d_ddout =
framework::EigenVector<T>::Flatten(in_tensor_d_ddout_help);
d_y.device(dev) = ddx * d_ddout.broadcast(size);
}
if (out_tensor_d_dout) {
framework::Tensor in_tensor_ddx_help, in_tensor_ddy_help;
in_tensor_ddx_help.Resize(in_tensor_ddx->dims());
in_tensor_ddx_help.mutable_data<T>(ctx.GetPlace());
in_tensor_ddy_help.Resize(in_tensor_ddy->dims());
in_tensor_ddy_help.mutable_data<T>(ctx.GetPlace());
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
paddle::platform::ForRange<DeviceContext> for_range(
dev_raw, in_tensor_ddx->numel());
math::ConjFunctor<T> functor_ddx(in_tensor_ddx->data<T>(),
in_tensor_ddx->numel(),
in_tensor_ddx_help.data<T>());
for_range(functor_ddx);
math::ConjFunctor<T> functor_ddy(in_tensor_ddy->data<T>(),
in_tensor_ddy->numel(),
in_tensor_ddy_help.data<T>());
for_range(functor_ddy);
auto ddx = framework::EigenVector<T>::Flatten(in_tensor_ddx_help);
auto ddy = framework::EigenVector<T>::Flatten(in_tensor_ddy_help);
auto d_dx = framework::EigenVector<T>::Flatten(*in_tensor_d_dx);
auto d_dy = framework::EigenVector<T>::Flatten(*in_tensor_d_dy);
auto d_dout = framework::EigenVector<T>::Flatten(*out_tensor_d_dout);
d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum();
}
if (out_tensor_d_ddx) {
framework::Tensor in_tensor_dout_help, in_tensor_y_help;
in_tensor_dout_help.Resize(in_tensor_dout->dims());
in_tensor_dout_help.mutable_data<T>(ctx.GetPlace());
in_tensor_y_help.Resize(in_tensor_y->dims());
in_tensor_y_help.mutable_data<T>(ctx.GetPlace());
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
paddle::platform::ForRange<DeviceContext> for_range(
dev_raw, in_tensor_dout->numel());
math::ConjFunctor<T> functor_dout(in_tensor_dout->data<T>(),
in_tensor_dout->numel(),
in_tensor_dout_help.data<T>());
for_range(functor_dout);
math::ConjFunctor<T> functor_y(in_tensor_y->data<T>(),
in_tensor_y->numel(),
in_tensor_y_help.data<T>());
for_range(functor_y);
auto dout = framework::EigenVector<T>::Flatten(in_tensor_dout_help);
auto y = framework::EigenVector<T>::Flatten(in_tensor_y_help);
auto d_ddout = framework::EigenVector<T>::Flatten(*in_tensor_d_ddout);
auto d_dy = framework::EigenVector<T>::Flatten(*in_tensor_d_dy);
auto d_ddx = framework::EigenVector<T>::Flatten(*out_tensor_d_ddx);
Eigen::DSizes<int, 1> size(in_tensor_y->numel());
d_ddx.device(dev) =
(dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size));
}
if (out_tensor_d_ddy) {
framework::Tensor in_tensor_dout_help, in_tensor_x_help;
in_tensor_dout_help.Resize(in_tensor_dout->dims());
in_tensor_dout_help.mutable_data<T>(ctx.GetPlace());
in_tensor_x_help.Resize(in_tensor_x->dims());
in_tensor_x_help.mutable_data<T>(ctx.GetPlace());
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
paddle::platform::ForRange<DeviceContext> for_range(
dev_raw, in_tensor_dout->numel());
math::ConjFunctor<T> functor_dout(in_tensor_dout->data<T>(),
in_tensor_dout->numel(),
in_tensor_dout_help.data<T>());
for_range(functor_dout);
math::ConjFunctor<T> functor_x(in_tensor_x->data<T>(),
in_tensor_x->numel(),
in_tensor_x_help.data<T>());
for_range(functor_x);
auto dout = framework::EigenVector<T>::Flatten(in_tensor_dout_help);
auto x = framework::EigenVector<T>::Flatten(in_tensor_x_help);
auto d_ddout = framework::EigenVector<T>::Flatten(*in_tensor_d_ddout);
auto d_dx = framework::EigenVector<T>::Flatten(*in_tensor_d_dx);
auto d_ddy = framework::EigenVector<T>::Flatten(*out_tensor_d_ddy);
Eigen::DSizes<int, 1> size(in_tensor_x->numel());
d_ddy.device(dev) =
(dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size));
}
}
#else
const auto* data_d_ddout = in_tensor_d_ddout->data<T>();
if (out_tensor_d_x) {
auto* data_d_x = out_tensor_d_x->mutable_data<T>(ctx.GetPlace());
const auto* data_ddy = in_tensor_ddy->data<T>();
const framework::DDim& dim = out_tensor_d_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_x[i] = T(data_ddy[i].real, -data_ddy[i].imag) * data_d_ddout[s];
}
}
if (out_tensor_d_y) {
auto* data_d_y = out_tensor_d_y->mutable_data<T>(ctx.GetPlace());
const auto* data_ddx = in_tensor_ddx->data<T>();
const framework::DDim& dim = out_tensor_d_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_y[i] = T(data_ddx[i].real, -data_ddx[i].imag) * data_d_ddout[s];
}
}
if (out_tensor_d_dout) {
auto* data_d_dout = out_tensor_d_dout->mutable_data<T>(ctx.GetPlace());
auto* data_ddx = in_tensor_ddx->data<T>();
auto* data_ddy = in_tensor_ddy->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
auto* data_d_dy = in_tensor_d_dy->data<T>();
const framework::DDim& dim = out_tensor_d_dout->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
bool new_s = false;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) {
++s;
new_s = true;
}
if (new_s) {
data_d_dout[s] =
T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] +
T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i];
} else {
data_d_dout[s] +=
T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] +
T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i];
}
new_s = false;
}
}
class MatMulV2DoubleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* y = context.Input<framework::Tensor>("Y");
auto* dout = context.Input<framework::Tensor>("DOut");
auto* ddx = context.Input<framework::Tensor>("DDX");
auto* ddy = context.Input<framework::Tensor>("DDY");
if (out_tensor_d_ddx) {
auto* data_d_ddx = out_tensor_d_ddx->mutable_data<T>(ctx.GetPlace());
auto* data_dout = in_tensor_dout->data<T>();
auto* data_d_dy = in_tensor_d_dy->data<T>();
auto* data_y = in_tensor_y->data<T>();
auto* data_d_ddout = in_tensor_d_ddout->data<T>();
auto* dx = context.Output<framework::Tensor>("DX");
auto* dy = context.Output<framework::Tensor>("DY");
auto* ddout = context.Output<framework::Tensor>("DDOut");
const framework::DDim& dim = out_tensor_d_ddx->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
bool transpose_x = context.Attr<bool>("trans_x");
bool transpose_y = context.Attr<bool>("trans_y");
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_ddx[i] =
T(data_dout[s].real, -data_dout[s].imag) * data_d_dy[i] +
T(data_y[i].real, -data_y[i].imag) * data_d_ddout[s];
}
}
if (dx) dx->mutable_data<T>(context.GetPlace());
if (dy) dy->mutable_data<T>(context.GetPlace());
if (ddout) ddout->mutable_data<T>(context.GetPlace());
if (out_tensor_d_ddy) {
auto* data_d_ddy = out_tensor_d_ddy->mutable_data<T>(ctx.GetPlace());
auto* data_dout = in_tensor_dout->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
auto* data_x = in_tensor_x->data<T>();
auto* data_d_ddout = in_tensor_d_ddout->data<T>();
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout);
auto pt_ddx = paddle::experimental::MakePtenDenseTensor(*ddx);
auto pt_ddy = paddle::experimental::MakePtenDenseTensor(*ddy);
auto pt_dx = paddle::experimental::MakePtenDenseTensor(*dx);
auto pt_dy = paddle::experimental::MakePtenDenseTensor(*dy);
auto pt_ddout = paddle::experimental::MakePtenDenseTensor(*ddout);
const framework::DDim& dim = out_tensor_d_ddy->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
auto& dev_ctx = context.device_context<DeviceContext>();
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_ddy[i] =
T(data_dout[s].real, -data_dout[s].imag) * data_d_dx[i] +
T(data_x[i].real, -data_x[i].imag) * data_d_ddout[s];
}
}
#endif
// call new kernel
pten::MatmulDoubleGradKernel<T>(dev_ctx, *pt_x, *pt_y, *pt_dout, *pt_ddx,
*pt_ddy, transpose_x, transpose_y,
pt_dx.get(), pt_dy.get(), pt_ddout.get());
}
};
template <typename DeviceContext, typename T>
struct DotTripleGradFunction<DeviceContext, T, math::DisableComplex<T>> {
void operator()(const Tensor* in_tensor_x, const Tensor* in_tensor_y,
const Tensor* in_tensor_ddx, const Tensor* in_tensor_ddy,
const Tensor* in_tensor_d_dx, const Tensor* in_tensor_d_dy,
const Tensor* in_tensor_dout, const Tensor* in_tensor_d_ddout,
Tensor* out_tensor_d_x, Tensor* out_tensor_d_y,
Tensor* out_tensor_d_dout, Tensor* out_tensor_d_ddx,
Tensor* out_tensor_d_ddy,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == in_tensor_d_ddout->dims().size()) {
auto& dev_raw = ctx.template device_context<DeviceContext>();
auto& dev = *dev_raw.eigen_device();
auto d_ddout = framework::EigenVector<T>::Flatten(*in_tensor_d_ddout);
if (out_tensor_d_x) {
out_tensor_d_x->mutable_data<T>(ctx.GetPlace());
auto ddy = framework::EigenVector<T>::Flatten(*in_tensor_ddy);
Eigen::DSizes<int, 1> size(in_tensor_ddy->numel());
auto d_x = framework::EigenVector<T>::Flatten(*out_tensor_d_x);
d_x.device(dev) = ddy * d_ddout.broadcast(size);
}
if (out_tensor_d_y) {
out_tensor_d_y->mutable_data<T>(ctx.GetPlace());
auto ddx = framework::EigenVector<T>::Flatten(*in_tensor_ddx);
Eigen::DSizes<int, 1> size(in_tensor_ddx->numel());
auto d_y = framework::EigenVector<T>::Flatten(*out_tensor_d_y);
d_y.device(dev) = ddx * d_ddout.broadcast(size);
}
if (out_tensor_d_dout) {
out_tensor_d_dout->mutable_data<T>(ctx.GetPlace());
auto ddx = framework::EigenVector<T>::Flatten(*in_tensor_ddx);
auto ddy = framework::EigenVector<T>::Flatten(*in_tensor_ddy);
auto d_dx = framework::EigenVector<T>::Flatten(*in_tensor_d_dx);
auto d_dy = framework::EigenVector<T>::Flatten(*in_tensor_d_dy);
auto d_dout = framework::EigenVector<T>::Flatten(*out_tensor_d_dout);
d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum();
}
if (out_tensor_d_ddx) {
out_tensor_d_ddx->mutable_data<T>(ctx.GetPlace());
auto dout = framework::EigenVector<T>::Flatten(*in_tensor_dout);
auto y = framework::EigenVector<T>::Flatten(*in_tensor_y);
auto d_ddout = framework::EigenVector<T>::Flatten(*in_tensor_d_ddout);
auto d_dy = framework::EigenVector<T>::Flatten(*in_tensor_d_dy);
auto d_ddx = framework::EigenVector<T>::Flatten(*out_tensor_d_ddx);
Eigen::DSizes<int, 1> size(in_tensor_y->numel());
d_ddx.device(dev) =
(dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size));
}
if (out_tensor_d_ddy) {
out_tensor_d_ddy->mutable_data<T>(ctx.GetPlace());
auto dout = framework::EigenVector<T>::Flatten(*in_tensor_dout);
auto x = framework::EigenVector<T>::Flatten(*in_tensor_x);
auto d_ddout = framework::EigenVector<T>::Flatten(*in_tensor_d_ddout);
auto d_dx = framework::EigenVector<T>::Flatten(*in_tensor_d_dx);
auto d_ddy = framework::EigenVector<T>::Flatten(*out_tensor_d_ddy);
Eigen::DSizes<int, 1> size(in_tensor_x->numel());
d_ddy.device(dev) =
(dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size));
}
}
#else
const auto* data_d_ddout = in_tensor_d_ddout->data<T>();
if (out_tensor_d_x) {
auto* data_d_x = out_tensor_d_x->mutable_data<T>(ctx.GetPlace());
const auto* data_ddy = in_tensor_ddy->data<T>();
const framework::DDim& dim = out_tensor_d_x->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_x[i] = data_ddy[i] * data_d_ddout[s];
}
}
if (out_tensor_d_y) {
auto* data_d_y = out_tensor_d_y->mutable_data<T>(ctx.GetPlace());
const auto* data_ddx = in_tensor_ddx->data<T>();
const framework::DDim& dim = out_tensor_d_y->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_y[i] = data_ddx[i] * data_d_ddout[s];
}
}
if (out_tensor_d_dout) {
auto* data_d_dout = out_tensor_d_dout->mutable_data<T>(ctx.GetPlace());
auto* data_ddx = in_tensor_ddx->data<T>();
auto* data_ddy = in_tensor_ddy->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
auto* data_d_dy = in_tensor_d_dy->data<T>();
const framework::DDim& dim = in_tensor_ddx->dims();
size_t N = static_cast<size_t>(framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
bool new_s = false;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) {
++s;
new_s = true;
}
if (new_s) {
data_d_dout[s] =
data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i];
} else {
data_d_dout[s] +=
data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i];
}
new_s = false;
}
}
    if (out_tensor_d_ddx) {
      auto* data_d_ddx = out_tensor_d_ddx->mutable_data<T>(ctx.GetPlace());
      auto* data_dout = in_tensor_dout->data<T>();
      auto* data_d_dy = in_tensor_d_dy->data<T>();
      auto* data_y = in_tensor_y->data<T>();
      auto* data_d_ddout = in_tensor_d_ddout->data<T>();
      const framework::DDim& dim = out_tensor_d_ddx->dims();
      size_t N = static_cast<size_t>(framework::product(dim));
      auto step = dim[dim.size() - 1];
      int s = -1;
      for (size_t i = 0; i < N; ++i) {
        if (0 == i % step) ++s;
        data_d_ddx[i] =
            data_dout[s] * data_d_dy[i] + data_y[i] * data_d_ddout[s];
      }
    }
    if (out_tensor_d_ddy) {
      auto* data_d_ddy = out_tensor_d_ddy->mutable_data<T>(ctx.GetPlace());
      auto* data_dout = in_tensor_dout->data<T>();
      auto* data_d_dx = in_tensor_d_dx->data<T>();
      auto* data_x = in_tensor_x->data<T>();
      auto* data_d_ddout = in_tensor_d_ddout->data<T>();
      const framework::DDim& dim = out_tensor_d_ddy->dims();
      size_t N = static_cast<size_t>(framework::product(dim));
      auto step = dim[dim.size() - 1];
      int s = -1;
      for (size_t i = 0; i < N; ++i) {
        if (0 == i % step) ++s;
        data_d_ddy[i] =
            data_dout[s] * data_d_dx[i] + data_x[i] * data_d_ddout[s];
      }
    }
#endif
  }
};
class MatMulV2TripleGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    // get input
    auto* x = context.Input<framework::Tensor>("X");
    auto* y = context.Input<framework::Tensor>("Y");
    auto* dout = context.Input<framework::Tensor>("DOut");
    auto* ddx = context.Input<framework::Tensor>("DDX");
    auto* ddy = context.Input<framework::Tensor>("DDY");
    auto* d_dx = context.Input<framework::Tensor>("D_DX");
    auto* d_dy = context.Input<framework::Tensor>("D_DY");
    auto* d_ddout = context.Input<framework::Tensor>("D_DDOut");
    // get output
    auto* out_d_x = context.Output<framework::Tensor>("D_X_out");
    auto* out_d_y = context.Output<framework::Tensor>("D_Y_out");
    auto* out_d_dout = context.Output<framework::Tensor>("D_DOut_out");
    auto* out_d_ddx = context.Output<framework::Tensor>("D_DDX_out");
    auto* out_d_ddy = context.Output<framework::Tensor>("D_DDY_out");
    bool transpose_x = context.Attr<bool>("trans_x");
    bool transpose_y = context.Attr<bool>("trans_y");
template <typename DeviceContext, typename T>
class MatMulV2GradKernel : public framework::OpKernel<T> {
public:
void MatMul(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
const framework::Tensor& b, bool trans_b,
framework::Tensor* out) const {
out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
if (a.dims().size() == 3 && b.dims().size() <= 2) {
      // the transpose_X must be false; if it is true, the transpose costs much time
if (!trans_a) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
}
}
blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast<T>(1), out,
static_cast<T>(0));
}
void CalcInputGrad(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
bool is_fold_init_dims_a, const framework::Tensor& b,
bool trans_b, bool is_fold_init_dims_b,
framework::Tensor* out) const {
if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2;
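    // When one operand is batched (3-D) but the requested gradient is a plain
    // 2-D matrix, the batched operand is folded into a 2-D matrix first so a
    // single GEMM can be used.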
if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, out);
} else {
auto& ctx = context.template device_context<DeviceContext>();
MatMul(context, is_fold_init_dims_a
? FoldInitDims(a)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
trans_a, is_fold_init_dims_b
? FoldInitDims(b)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, b),
trans_b, out);
}
}
void Compute(const framework::ExecutionContext& ctx) const override {
bool transpose_x = ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.Attr<bool>("trans_y");
auto x = *ctx.Input<framework::Tensor>("X");
auto y = *ctx.Input<framework::Tensor>("Y");
auto dout = *ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor y_conj(y.type());
    framework::Tensor x_conj(x.type());
// get dims
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
std::vector<std::int64_t> dout_dims = vectorize(dout.dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int ndim = dout_dims.size();
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    // Case1 : x's and y's dim = 1
if (x_ndim == 1 && y_ndim == 1) {
if (dx) dx->mutable_data<T>(ctx.GetPlace());
if (dy) dy->mutable_data<T>(ctx.GetPlace());
if (dout.numel() == 1) {
DotGradFunction<DeviceContext, T>()(&x, &y, &dout, dx, dy, ctx);
return;
}
}
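    // Decide whether batch broadcasting is involved: it only happens when
    // both operands have more than two dims and their leading (batch) dims
    // are not identical.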
bool is_broadcast = true;
if (x_ndim <= 2 || y_ndim <= 2) {
is_broadcast = false;
} else if (x_ndim != y_ndim) {
is_broadcast = true;
} else {
is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2,
y_dims.cbegin());
}
    // Case2: no broadcast or no batch size. This fast path avoids the
    // broadcast handling and matches the matmul behavior of the old version.
if (!is_broadcast) {
ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
framework::DDim dx_dims;
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
// for complex
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(y, y_conj);
}
framework::DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
// for complex
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(x, x_conj);
}
if (transpose_x && transpose_y) {
CalcInputGrad(ctx, y_conj, true, true, dout, true, false, dx);
CalcInputGrad(ctx, dout, true, true, x_conj, true, false, dy);
} else if (transpose_x) {
CalcInputGrad(ctx, y_conj, false, false, dout, true, false, dx);
CalcInputGrad(ctx, x_conj, false, false, dout, false, true, dy);
} else if (transpose_y) {
CalcInputGrad(ctx, dout, false, false, y_conj, false, true, dx);
CalcInputGrad(ctx, dout, true, true, x_conj, false, true, dy);
} else {
CalcInputGrad(ctx, dout, false, false, y_conj, true, false, dx);
CalcInputGrad(ctx, x_conj, true, true, dout, false, true, dy);
}
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
}
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
}
}
} else {
      // Case3: broadcast. Reducing (summing) over the broadcast dims costs
      // much time and wastes memory, so this case should be avoided in
      // practice.
      VLOG(3) << "It needs much time to reduce sum for the broadcast and "
                 "wastes the memory. So we should avoid the case in reality";
Tensor dx_help, dy_help;
ConjHelper<DeviceContext, T> conj_helper(ctx);
conj_helper(x, x_conj);
conj_helper(y, y_conj);
if (transpose_x) {
if (transpose_y) {
// X'Y': dA = Y'G', dB = G'X'
if (dx)
MatMulFunction<DeviceContext, T>(&y_conj, &dout, y_dims, dout_dims,
&dx_help, true, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&dout, &x_conj, dout_dims, x_dims,
&dy_help, true, true, ctx);
} else {
// X'Y: dX = YG', dY = XG
if (dx)
MatMulFunction<DeviceContext, T>(&y_conj, &dout, y_dims, dout_dims,
&dx_help, false, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&x_conj, &dout, x_dims, dout_dims,
&dy_help, false, false, ctx);
}
} else {
if (transpose_y) {
// XY': dX = GY, dY = G'X
if (dx)
MatMulFunction<DeviceContext, T>(&dout, &y_conj, dout_dims, y_dims,
&dx_help, false, false, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&dout, &x_conj, dout_dims, x_dims,
&dy_help, true, false, ctx);
} else {
// XY: dX = GY', dY = X'G
if (dx)
MatMulFunction<DeviceContext, T>(&dout, &y_conj, dout_dims, y_dims,
&dx_help, false, true, ctx);
if (dy)
MatMulFunction<DeviceContext, T>(&x_conj, &dout, x_dims, dout_dims,
&dy_help, true, false, ctx);
}
}
// get help dims
const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
std::vector<std::int64_t> dx_broadcast_dims(ndim);
std::vector<std::int64_t> dy_broadcast_dims(ndim);
std::fill(dx_broadcast_dims.data(),
dx_broadcast_dims.data() + ndim - x_ndim, 1);
std::fill(dy_broadcast_dims.data(),
dy_broadcast_dims.data() + ndim - y_ndim, 1);
std::copy(x_dims.data(), x_dims.data() + x_ndim,
dx_broadcast_dims.data() + ndim - x_ndim);
std::copy(y_dims.data(), y_dims.data() + y_ndim,
dy_broadcast_dims.data() + ndim - y_ndim);
std::vector<int> dx_reduce_dims;
std::vector<int> dy_reduce_dims;
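      // Only the leading batch dimensions (everything except the last two
      // matrix dims) can have been broadcast, so only those are candidates
      // for the reduce_sum below.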
for (int idx = 0; idx <= ndim - 3; idx++) {
if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) {
dx_reduce_dims.push_back(idx);
}
if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) {
dy_reduce_dims.push_back(idx);
}
}
// reduce sum to get grad by ReduceSum
if (dx) {
if (dx_reduce_dims.empty()) {
*dx = std::move(dx_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&dx_help, dx, dx_reduce_dims,
ctx);
}
dx->Resize(x.dims());
}
if (dy) {
if (dy_reduce_dims.empty()) {
*dy = std::move(dy_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&dy_help, dy, dy_reduce_dims,
ctx);
}
dy->Resize(y.dims());
}
// Get the OutputGrad(out)
}
}
};
template <typename DeviceContext, typename T>
class MatMulV2DoubleGradKernel : public framework::OpKernel<T> {
public:
void MatMul(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
const framework::Tensor& b, bool trans_b, framework::Tensor* out,
bool flag) const {
out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
if (a.dims().size() == 3 && b.dims().size() <= 2) {
      // the transpose_X must be false; if it is true, the transpose costs much time
if (!trans_a) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
}
}
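    // `flag` is forwarded as the GEMM beta: 0 overwrites `out`, 1 accumulates
    // the product into the existing contents of `out`.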
blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast<T>(1), out,
static_cast<T>(flag));
}
void CalcInputGrad(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
bool is_fold_init_dims_a, const framework::Tensor& b,
bool trans_b, bool is_fold_init_dims_b,
framework::Tensor* out, bool flag) const {
if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2;
if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, out, flag);
} else {
auto& ctx = context.template device_context<DeviceContext>();
MatMul(context, is_fold_init_dims_a
? FoldInitDims(a)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
trans_a, is_fold_init_dims_b
? FoldInitDims(b)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, b),
trans_b, out, flag);
}
}
void Compute(const framework::ExecutionContext& context) const override {
auto x = *context.Input<framework::Tensor>("X");
auto y = *context.Input<framework::Tensor>("Y");
auto dout = *context.Input<framework::Tensor>("DOut");
auto* ddx = context.Input<framework::Tensor>("DDX");
auto* ddy = context.Input<framework::Tensor>("DDY");
auto* dx = context.Output<framework::Tensor>("DX");
auto* dy = context.Output<framework::Tensor>("DY");
auto* ddout = context.Output<framework::Tensor>("DDOut");
bool transpose_x = context.Attr<bool>("trans_x");
bool transpose_y = context.Attr<bool>("trans_y");
// Get dims from the input x, y, output_grad
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
std::vector<std::int64_t> dout_dims = vectorize(dout.dims());
framework::Tensor x_conj(x.type());
framework::Tensor y_conj(y.type());
framework::Tensor dout_conj(dout.type());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int ndim = dout_dims.size();
    // Case1 : x's and y's dim = 1
if (x_ndim == 1 && y_ndim == 1) {
DotDoubleGradFunction<DeviceContext, T>()(&x, &y, dx, dy, &dout, ddx, ddy,
ddout, context);
return;
}
bool is_broadcast = true;
if (x_ndim <= 2 || y_ndim <= 2) {
is_broadcast = false;
} else if (x_ndim != y_ndim) {
is_broadcast = true;
} else {
is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2,
y_dims.cbegin());
}
if (!is_broadcast) {
// Case2: no broadcast or no batch size
ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
framework::DDim dx_dims;
ConjHelper<DeviceContext, T> conj_helper(context);
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
}
framework::DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
}
framework::DDim ddout_dims;
if (ddout) {
ddout_dims = ddout->dims();
if (ddout_dims != dout.dims()) {
ddout->Resize(dout.dims());
}
}
if (ddx || ddy) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(dout, dout_conj);
}
if (ddout) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(x, x_conj);
conj_helper(y, y_conj);
}
bool ddout_flag = false;
if (ddx) {
auto ddx_mat = *ddx;
if (ddx_mat.dims() != x.dims()) {
ddx_mat.Resize(x.dims());
}
if (dy) {
if (transpose_x && transpose_y) {
// dy = dout' * ddx'
CalcInputGrad(context, dout_conj, true, true, ddx_mat, true, false,
dy, false);
} else if (transpose_x) {
// dy = ddx * dout
CalcInputGrad(context, ddx_mat, false, false, dout_conj, false,
true, dy, false);
} else if (transpose_y) {
// dy = dout' * ddx
CalcInputGrad(context, dout_conj, true, true, ddx_mat, false, true,
dy, false);
} else {
// dy = ddx' * dout
CalcInputGrad(context, ddx_mat, true, true, dout_conj, false, true,
dy, false);
}
}
if (ddout) {
CalcInputGrad(context, ddx_mat, transpose_x, true, y_conj,
transpose_y, false, ddout, ddout_flag);
ddout_flag = true;
}
}
if (ddy) {
auto ddy_mat = *ddy;
if (ddy_mat.dims() != y.dims()) {
ddy_mat.Resize(y.dims());
}
if (dx) {
if (transpose_x && transpose_y) {
// dx = ddy' * dout'
CalcInputGrad(context, ddy_mat, true, true, dout_conj, true, false,
dx, false);
} else if (transpose_x) {
// dx = ddy * dout'
CalcInputGrad(context, ddy_mat, false, false, dout_conj, true,
false, dx, false);
} else if (transpose_y) {
// dx = dout * ddy
CalcInputGrad(context, dout_conj, false, false, ddy_mat, false,
true, dx, false);
} else {
// dx = dout * ddy'
CalcInputGrad(context, dout_conj, false, false, ddy_mat, true,
false, dx, false);
}
}
if (ddout) {
CalcInputGrad(context, x_conj, transpose_x, true, ddy_mat,
transpose_y, false, ddout, ddout_flag);
}
}
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
}
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
}
}
if (ddout) {
if (ddout_dims != dout.dims()) {
ddout->Resize(ddout_dims);
}
}
} else {
      // Case3: broadcast. Reducing (summing) over the broadcast dims costs
      // much time and wastes memory, so this case should be avoided in
      // practice.
      VLOG(3) << "It needs much time to reduce sum for the broadcast and "
                 "wastes the memory. So we should avoid the case in reality";
      framework::Tensor ddy_conj(ddy->type());
      framework::Tensor ddx_conj(ddx->type());
Tensor dx_help, dy_help;
if (dx || dy) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(dout, dout_conj);
}
if (ddout) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(x, x_conj);
conj_helper(y, y_conj);
}
if (transpose_x) {
if (transpose_y) {
if (dx)
MatMulFunction<DeviceContext, T>(ddy, &dout_conj, y_dims, dout_dims,
&dx_help, true, true, context);
if (dy)
MatMulFunction<DeviceContext, T>(&dout_conj, ddx, dout_dims, x_dims,
&dy_help, true, true, context);
} else {
if (dx)
MatMulFunction<DeviceContext, T>(ddy, &dout_conj, y_dims, dout_dims,
&dx_help, false, true, context);
if (dy)
MatMulFunction<DeviceContext, T>(ddx, &dout_conj, x_dims, dout_dims,
&dy_help, false, false, context);
}
} else {
if (transpose_y) {
if (dx)
MatMulFunction<DeviceContext, T>(&dout_conj, ddy, dout_dims, y_dims,
&dx_help, false, false, context);
if (dy)
MatMulFunction<DeviceContext, T>(&dout_conj, ddx, dout_dims, x_dims,
&dy_help, true, false, context);
} else {
if (dx)
MatMulFunction<DeviceContext, T>(&dout_conj, ddy, dout_dims, y_dims,
&dx_help, false, true, context);
if (dy)
MatMulFunction<DeviceContext, T>(ddx, &dout_conj, x_dims, dout_dims,
&dy_help, true, false, context);
}
}
// get help dims
const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
std::vector<std::int64_t> dx_broadcast_dims(ndim);
std::vector<std::int64_t> dy_broadcast_dims(ndim);
std::fill(dx_broadcast_dims.data(),
dx_broadcast_dims.data() + ndim - x_ndim, 1);
std::fill(dy_broadcast_dims.data(),
dy_broadcast_dims.data() + ndim - y_ndim, 1);
std::copy(x_dims.data(), x_dims.data() + x_ndim,
dx_broadcast_dims.data() + ndim - x_ndim);
std::copy(y_dims.data(), y_dims.data() + y_ndim,
dy_broadcast_dims.data() + ndim - y_ndim);
std::vector<int> dx_reduce_dims;
std::vector<int> dy_reduce_dims;
for (int idx = 0; idx <= ndim - 3; idx++) {
if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) {
dx_reduce_dims.push_back(idx);
}
if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) {
dy_reduce_dims.push_back(idx);
}
}
// Reduce sum to get grad by ReduceSum
if (dx) {
if (dx_reduce_dims.empty()) {
*dx = std::move(dx_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&dx_help, dx, dx_reduce_dims,
context);
}
dx->Resize(x.dims());
}
if (dy) {
if (dy_reduce_dims.empty()) {
*dy = std::move(dy_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&dy_help, dy, dy_reduce_dims,
context);
}
dy->Resize(y.dims());
}
if (ddout) {
// Calculate the gradient of OutputGrad(Out)
MatMulFunction<DeviceContext, T>(ddx, &y_conj, x_dims, y_dims, ddout,
transpose_x, transpose_y, context);
MatMulFunction<DeviceContext, T>(&x_conj, ddy, x_dims, y_dims, ddout,
transpose_x, transpose_y, context,
true);
}
}
}
};
template <typename DeviceContext, typename T>
class MatMulV2TripleGradKernel : public framework::OpKernel<T> {
public:
void MatMul(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
const framework::Tensor& b, bool trans_b, framework::Tensor* out,
bool flag) const {
out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
if (a.dims().size() == 3 && b.dims().size() <= 2) {
      // the transpose_X must be false; if it is true, the transpose costs much time
if (!trans_a) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
}
}
blas.MatMul(a, mat_dim_a, b, mat_dim_b, static_cast<T>(1), out,
static_cast<T>(flag));
}
void CalcInputGrad(const framework::ExecutionContext& context,
const framework::Tensor& a, bool trans_a,
bool is_fold_init_dims_a, const framework::Tensor& b,
bool trans_b, bool is_fold_init_dims_b,
framework::Tensor* out, bool flag) const {
if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2;
if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, out, flag);
} else {
auto& ctx = context.template device_context<DeviceContext>();
MatMul(context, is_fold_init_dims_a
? FoldInitDims(a)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
trans_a, is_fold_init_dims_b
? FoldInitDims(b)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, b),
trans_b, out, flag);
}
}
void Compute(const framework::ExecutionContext& context) const override {
// get input
auto x = *context.Input<framework::Tensor>("X");
auto y = *context.Input<framework::Tensor>("Y");
auto dout = *context.Input<framework::Tensor>("DOut");
auto ddx = *context.Input<framework::Tensor>("DDX");
auto ddy = *context.Input<framework::Tensor>("DDY");
auto* d_dx = context.Input<framework::Tensor>("D_DX");
auto* d_dy = context.Input<framework::Tensor>("D_DY");
auto* d_ddout = context.Input<framework::Tensor>("D_DDOut");
// get output
auto* out_d_x = context.Output<framework::Tensor>("D_X_out");
auto* out_d_y = context.Output<framework::Tensor>("D_Y_out");
auto* out_d_dout = context.Output<framework::Tensor>("D_DOut_out");
auto* out_d_ddx = context.Output<framework::Tensor>("D_DDX_out");
auto* out_d_ddy = context.Output<framework::Tensor>("D_DDY_out");
bool transpose_x = context.Attr<bool>("trans_x");
bool transpose_y = context.Attr<bool>("trans_y");
// Get dims from the input x, y, output_grad
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
std::vector<std::int64_t> dout_dims = vectorize(dout.dims());
framework::Tensor x_conj(x.type());
framework::Tensor y_conj(y.type());
framework::Tensor dout_conj(dout.type());
framework::Tensor ddx_conj(ddx.type());
framework::Tensor ddy_conj(ddy.type());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int ndim = dout_dims.size();
// Case1 : x's and y's dim = 1
if (x_ndim == 1 && y_ndim == 1) {
VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 1";
DotTripleGradFunction<DeviceContext, T>()(
&x, &y, &ddx, &ddy, d_dx, d_dy, &dout, d_ddout, out_d_x, out_d_y,
out_d_dout, out_d_ddx, out_d_ddy, context);
return;
}
bool is_broadcast = true;
if (x_ndim <= 2 || y_ndim <= 2) {
is_broadcast = false;
} else if (x_ndim != y_ndim) {
is_broadcast = true;
} else {
is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2,
y_dims.cbegin());
}
if (!is_broadcast) {
// Case2: no broadcast or no batch size
VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 2";
ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
if (ddx.dims() != x.dims()) {
ddx.Resize(x.dims());
}
if (ddy.dims() != y.dims()) {
ddy.Resize(y.dims());
}
ConjHelper<DeviceContext, T> conj_helper(context);
framework::DDim out_dx_dims;
if (out_d_x) {
out_dx_dims = out_d_x->dims();
if (out_dx_dims != x.dims()) {
out_d_x->Resize(x.dims());
}
}
framework::DDim out_dy_dims;
if (out_d_y) {
out_dy_dims = out_d_y->dims();
if (out_dy_dims != y.dims()) {
out_d_y->Resize(y.dims());
}
}
framework::DDim out_d_dout_dims;
if (out_d_dout) {
out_d_dout_dims = out_d_dout->dims();
if (out_d_dout_dims != dout.dims()) {
out_d_dout->Resize(dout.dims());
}
}
framework::DDim out_d_ddx_dims;
if (out_d_ddx) {
out_d_ddx_dims = out_d_ddx->dims();
if (out_d_ddx_dims != x.dims()) {
out_d_ddx->Resize(x.dims());
}
}
framework::DDim out_d_ddy_dims;
if (out_d_ddy) {
out_d_ddy_dims = out_d_ddy->dims();
if (out_d_ddy_dims != y.dims()) {
out_d_ddy->Resize(y.dims());
}
}
if (out_d_dout) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(ddx, ddx_conj);
conj_helper(ddy, ddy_conj);
}
if (out_d_ddx || out_d_ddy) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(x, x_conj);
conj_helper(y, y_conj);
conj_helper(dout, dout_conj);
}
bool d_dout_flag = false;
bool d_ddx_flag = false;
bool d_ddy_flag = false;
if (d_ddout) {
auto d_ddout_mat = *d_ddout;
if (d_ddout_mat.dims() != dout.dims()) {
d_ddout_mat.Resize(dout.dims());
}
if (out_d_y) {
if (transpose_x && transpose_y) {
// out_d_y = d_ddout' * ddx'
CalcInputGrad(context, d_ddout_mat, true, true, ddx_conj, true,
false, out_d_y, false);
} else if (transpose_x) {
// out_d_y = ddx * d_ddout
CalcInputGrad(context, ddx_conj, false, false, d_ddout_mat, false,
true, out_d_y, false);
} else if (transpose_y) {
// out_d_y = d_ddout' * ddx
CalcInputGrad(context, d_ddout_mat, true, true, ddx_conj, false,
true, out_d_y, false);
} else {
// out_d_y = ddx' * d_ddout
CalcInputGrad(context, ddx_conj, true, true, d_ddout_mat, false,
true, out_d_y, false);
}
}
if (out_d_x) {
if (transpose_x && transpose_y) {
// out_d_x = ddy' * d_ddout'
CalcInputGrad(context, ddy_conj, true, true, d_ddout_mat, true,
false, out_d_x, false);
} else if (transpose_x) {
// out_d_x = ddy * d_ddout'
CalcInputGrad(context, ddy_conj, false, false, d_ddout_mat, true,
false, out_d_x, false);
} else if (transpose_y) {
// out_d_x = d_ddout * ddy
CalcInputGrad(context, d_ddout_mat, false, false, ddy_conj, false,
true, out_d_x, false);
} else {
// out_d_x = d_ddout * ddy'
CalcInputGrad(context, d_ddout_mat, false, false, ddy_conj, true,
false, out_d_x, false);
}
}
// equations:
// d_ddx = DOut * D_DY + Y * D_DDOut
// Let: d_ddx1 = Y * D_DDOut
// Let: d_ddx2 = DOut * D_DY
// d_ddy = DOut * D_DX + X * D_DDOut
// Let: d_ddy1 = X * D_DDOut
// Let: d_ddy2 = DOut * D_DX
// d_dout = DDY * D_DX + DDX * D_DY
// Let: d_dout1 = DDX * D_DY
// Let: d_dout2 = DDY * D_DX
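      // The d_dout_flag / d_ddx_flag / d_ddy_flag booleans below make the
      // second contribution of each output accumulate onto the first one
      // (CalcInputGrad forwards the flag as the GEMM beta), e.g.
      // out_d_ddx = out_d_ddx1 + out_d_ddx2.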
// compute d_ddx1
if (out_d_ddx) {
if (transpose_x && transpose_y) {
// out_d_ddx1 = y' * d_ddout'
CalcInputGrad(context, y_conj, true, true, d_ddout_mat, true, false,
out_d_ddx, d_ddx_flag);
} else if (transpose_x) {
// out_d_ddx1 = y * d_ddout'
CalcInputGrad(context, y_conj, false, false, d_ddout_mat, true,
false, out_d_ddx, d_ddx_flag);
} else if (transpose_y) {
// out_d_ddx1 = d_ddout * y
CalcInputGrad(context, d_ddout_mat, false, false, y_conj, false,
true, out_d_ddx, d_ddx_flag);
} else {
// out_d_ddx1 = d_ddout * y'
CalcInputGrad(context, d_ddout_mat, false, false, y_conj, true,
false, out_d_ddx, d_ddx_flag);
}
d_ddx_flag = true;
}
// compute d_ddy1
if (out_d_ddy) {
if (transpose_x && transpose_y) {
// out_d_ddy1 = d_ddout' * x'
CalcInputGrad(context, d_ddout_mat, true, true, x_conj, true, false,
out_d_ddy, false);
} else if (transpose_x) {
// out_d_ddy1 = x * d_ddout
CalcInputGrad(context, x_conj, false, false, d_ddout_mat, false,
true, out_d_ddy, false);
} else if (transpose_y) {
// out_d_ddy1 = d_ddout' * x
CalcInputGrad(context, d_ddout_mat, true, true, x_conj, false, true,
out_d_ddy, false);
} else {
// out_d_ddy1 = x' * d_ddout
CalcInputGrad(context, x_conj, true, true, d_ddout_mat, false, true,
out_d_ddy, false);
}
d_ddy_flag = true;
}
}
if (d_dy) {
auto d_dy_mat = *d_dy;
if (d_dy_mat.dims() != y.dims()) {
d_dy_mat.Resize(y.dims());
}
// compute d_dout1
if (out_d_dout) {
CalcInputGrad(context, ddx_conj, transpose_x, true, d_dy_mat,
transpose_y, false, out_d_dout, d_dout_flag);
d_dout_flag = true;
}
// compute d_ddx2
if (out_d_ddx) {
if (transpose_x && transpose_y) {
// out_d_ddx2 = D_DY' * DOut'
CalcInputGrad(context, d_dy_mat, true, true, dout_conj, true, false,
out_d_ddx, d_ddx_flag);
} else if (transpose_x) {
// out_d_ddx2 = D_DY * Dout'
CalcInputGrad(context, d_dy_mat, false, false, dout_conj, true,
false, out_d_ddx, d_ddx_flag);
} else if (transpose_y) {
// out_d_ddx2 = Dout * D_DY
CalcInputGrad(context, dout_conj, false, false, d_dy_mat, false,
true, out_d_ddx, d_ddx_flag);
} else {
// out_d_ddx2 = Dout * D_DY'
CalcInputGrad(context, dout_conj, false, false, d_dy_mat, true,
false, out_d_ddx, d_ddx_flag);
}
}
}
if (d_dx) {
auto d_dx_mat = *d_dx;
if (d_dx_mat.dims() != x.dims()) {
d_dx_mat.Resize(x.dims());
}
// compute d_dout2
if (out_d_dout) {
CalcInputGrad(context, d_dx_mat, transpose_x, true, ddy_conj,
transpose_y, false, out_d_dout, d_dout_flag);
}
// compute d_ddy2
if (out_d_ddy) {
if (transpose_x && transpose_y) {
// out_d_ddy2 = dout' * d_dx'
CalcInputGrad(context, dout_conj, true, true, d_dx_mat, true, false,
out_d_ddy, d_ddy_flag);
} else if (transpose_x) {
// out_d_ddy2 = d_dx * dout
CalcInputGrad(context, d_dx_mat, false, false, dout_conj, false,
true, out_d_ddy, d_ddy_flag);
} else if (transpose_y) {
// out_d_ddy2 = dout' * d_dx
CalcInputGrad(context, dout_conj, true, true, d_dx_mat, false, true,
out_d_ddy, d_ddy_flag);
} else {
// out_d_ddy2 = d_dx' * dout
CalcInputGrad(context, d_dx_mat, true, true, dout_conj, false, true,
out_d_ddy, d_ddy_flag);
}
}
}
if (out_d_x) {
if (out_dx_dims != x.dims()) {
out_d_x->Resize(out_dx_dims);
}
}
if (out_d_y) {
if (out_dy_dims != y.dims()) {
out_d_y->Resize(out_dy_dims);
}
}
if (out_d_dout) {
if (out_d_dout_dims != dout.dims()) {
out_d_dout->Resize(out_d_dout_dims);
}
}
if (out_d_ddx) {
if (out_d_ddx_dims != x.dims()) {
out_d_ddx->Resize(out_d_ddx_dims);
}
}
if (out_d_ddy) {
        if (out_d_ddy_dims != y.dims()) {
out_d_ddy->Resize(out_d_ddy_dims);
}
}
} else {
      // Case3: broadcast. Reducing (summing) over the broadcast dims costs
      // much time and wastes memory, so this case should be avoided in
      // practice.
      VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 3";
      VLOG(3) << "It needs much time to reduce sum for the broadcast and "
                 "wastes the memory. So we should avoid the case in reality";
Tensor out_dx_help, out_dy_help;
Tensor out_d_ddx_help, out_d_ddy_help;
if (out_d_dout) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(ddx, ddx_conj);
conj_helper(ddy, ddy_conj);
}
if (out_d_ddx || out_d_ddy) {
ConjHelper<DeviceContext, T> conj_helper(context);
conj_helper(x, x_conj);
conj_helper(y, y_conj);
conj_helper(dout, dout_conj);
}
if (transpose_x) {
if (transpose_y) {
          // dX = ddY' d_ddout', dY = d_ddout' ddX'
if (out_d_x)
MatMulFunction<DeviceContext, T>(&ddy_conj, d_ddout, y_dims,
dout_dims, &out_dx_help, true,
true, context);
if (out_d_y)
MatMulFunction<DeviceContext, T>(d_ddout, &ddx_conj, dout_dims,
x_dims, &out_dy_help, true, true,
context);
} else {
// dX = ddY d_ddout', dY = ddX d_ddout
if (out_d_x)
MatMulFunction<DeviceContext, T>(&ddy_conj, d_ddout, y_dims,
dout_dims, &out_dx_help, false,
true, context);
if (out_d_y)
MatMulFunction<DeviceContext, T>(&ddx_conj, d_ddout, x_dims,
dout_dims, &out_dy_help, false,
false, context);
}
} else {
if (transpose_y) {
          // dX = d_ddout ddY, dY = d_ddout' ddX
if (out_d_x)
MatMulFunction<DeviceContext, T>(d_ddout, &ddy_conj, dout_dims,
y_dims, &out_dx_help, false, false,
context);
if (out_d_y)
MatMulFunction<DeviceContext, T>(d_ddout, &ddx_conj, dout_dims,
x_dims, &out_dy_help, true, false,
context);
} else {
// dX = d_ddout ddY', dY = ddX' d_ddout
if (out_d_x)
MatMulFunction<DeviceContext, T>(d_ddout, &ddy_conj, dout_dims,
y_dims, &out_dx_help, false, true,
context);
if (out_d_y)
MatMulFunction<DeviceContext, T>(&ddx_conj, d_ddout, x_dims,
dout_dims, &out_dy_help, true,
false, context);
}
}
// get help dims
const std::vector<std::int64_t> dx_help_dims =
vectorize(out_dx_help.dims());
      const std::vector<std::int64_t> dy_help_dims =
          vectorize(out_dy_help.dims());
std::vector<std::int64_t> dx_broadcast_dims(ndim);
std::vector<std::int64_t> dy_broadcast_dims(ndim);
std::fill(dx_broadcast_dims.data(),
dx_broadcast_dims.data() + ndim - x_ndim, 1);
std::fill(dy_broadcast_dims.data(),
dy_broadcast_dims.data() + ndim - y_ndim, 1);
std::copy(x_dims.data(), x_dims.data() + x_ndim,
dx_broadcast_dims.data() + ndim - x_ndim);
std::copy(y_dims.data(), y_dims.data() + y_ndim,
dy_broadcast_dims.data() + ndim - y_ndim);
std::vector<int> dx_reduce_dims;
std::vector<int> dy_reduce_dims;
for (int idx = 0; idx <= ndim - 3; idx++) {
if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) {
dx_reduce_dims.push_back(idx);
}
if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) {
dy_reduce_dims.push_back(idx);
}
}
// Reduce sum to get grad by ReduceSum
if (out_d_x) {
if (dx_reduce_dims.empty()) {
*out_d_x = std::move(out_dx_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&out_dx_help, out_d_x,
dx_reduce_dims, context);
}
out_d_x->Resize(x.dims());
}
if (out_d_y) {
if (dy_reduce_dims.empty()) {
*out_d_y = std::move(out_dy_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&out_dy_help, out_d_y,
dy_reduce_dims, context);
}
out_d_y->Resize(y.dims());
}
// compute d_dout
if (out_d_dout) {
MatMulFunction<DeviceContext, T>(d_dx, &ddy_conj, x_dims, y_dims,
out_d_dout, transpose_x, transpose_y,
context);
MatMulFunction<DeviceContext, T>(&ddx_conj, d_dy, x_dims, y_dims,
out_d_dout, transpose_x, transpose_y,
context, true);
}
// compute d_ddx
if (out_d_ddx) {
if (transpose_x && transpose_y) {
// out_d_ddx1 = y' * d_ddout'
MatMulFunction<DeviceContext, T>(&y_conj, d_ddout, y_dims, dout_dims,
&out_d_ddx_help, true, true,
context);
// out_d_ddx2 = D_DY' * DOut'
MatMulFunction<DeviceContext, T>(d_dy, &dout_conj, y_dims, dout_dims,
&out_d_ddx_help, true, true, context,
true);
} else if (transpose_x) {
// out_d_ddx1 = y * d_ddout'
MatMulFunction<DeviceContext, T>(&y_conj, d_ddout, y_dims, dout_dims,
&out_d_ddx_help, false, true,
context);
// out_d_ddx2 = D_DY * Dout'
MatMulFunction<DeviceContext, T>(d_dy, &dout_conj, y_dims, dout_dims,
&out_d_ddx_help, false, true,
context, true);
} else if (transpose_y) {
// out_d_ddx1 = d_ddout * y
MatMulFunction<DeviceContext, T>(d_ddout, &y_conj, dout_dims, y_dims,
&out_d_ddx_help, false, false,
context);
// out_d_ddx2 = Dout * D_DY
MatMulFunction<DeviceContext, T>(&dout_conj, d_dy, dout_dims, y_dims,
&out_d_ddx_help, false, false,
context, true);
} else {
// out_d_ddx1 = d_ddout * y'
MatMulFunction<DeviceContext, T>(d_ddout, &y_conj, dout_dims, y_dims,
&out_d_ddx_help, false, true,
context);
// out_d_ddx2 = Dout * D_DY'
MatMulFunction<DeviceContext, T>(&dout_conj, d_dy, dout_dims, y_dims,
&out_d_ddx_help, false, true,
context, true);
}
if (dx_reduce_dims.empty()) {
*out_d_ddx = std::move(out_d_ddx_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&out_d_ddx_help, out_d_ddx,
dx_reduce_dims, context);
}
out_d_ddx->Resize(x.dims());
}
// compute d_ddy
if (out_d_ddy) {
if (transpose_x && transpose_y) {
// out_d_ddy1 = d_ddout' * x'
MatMulFunction<DeviceContext, T>(d_ddout, &x_conj, dout_dims, x_dims,
&out_d_ddy_help, true, true,
context);
// out_d_ddy2 = dout' * d_dx'
MatMulFunction<DeviceContext, T>(&dout_conj, d_dx, dout_dims, x_dims,
&out_d_ddy_help, true, true, context,
true);
} else if (transpose_x) {
// out_d_ddy1 = x * d_ddout
MatMulFunction<DeviceContext, T>(&x_conj, d_ddout, x_dims, dout_dims,
&out_d_ddy_help, false, false,
context);
// out_d_ddy2 = d_dx * dout
MatMulFunction<DeviceContext, T>(d_dx, &dout_conj, x_dims, dout_dims,
&out_d_ddy_help, false, false,
context, true);
} else if (transpose_y) {
// out_d_ddy1 = d_ddout' * x
MatMulFunction<DeviceContext, T>(d_ddout, &x_conj, dout_dims, x_dims,
&out_d_ddy_help, true, false,
context);
// out_d_ddy2 = dout' * d_dx
MatMulFunction<DeviceContext, T>(&dout_conj, d_dx, dout_dims, x_dims,
&out_d_ddy_help, true, false,
context, true);
} else {
// out_d_ddy1 = x' * d_ddout
MatMulFunction<DeviceContext, T>(&x_conj, d_ddout, x_dims, dout_dims,
&out_d_ddy_help, true, false,
context);
// out_d_ddy2 = d_dx' * dout
MatMulFunction<DeviceContext, T>(d_dx, &dout_conj, x_dims, dout_dims,
&out_d_ddy_help, true, false,
context, true);
}
if (dy_reduce_dims.empty()) {
*out_d_ddy = std::move(out_d_ddy_help);
} else {
ReduceSumForMatmulGrad<DeviceContext, T>(&out_d_ddy_help, out_d_ddy,
dy_reduce_dims, context);
}
out_d_ddy->Resize(y.dims());
}
}
if (out_d_x) out_d_x->mutable_data<T>(context.GetPlace());
if (out_d_y) out_d_y->mutable_data<T>(context.GetPlace());
if (out_d_dout) out_d_dout->mutable_data<T>(context.GetPlace());
if (out_d_ddx) out_d_ddx->mutable_data<T>(context.GetPlace());
if (out_d_ddy) out_d_ddy->mutable_data<T>(context.GetPlace());
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_dout = paddle::experimental::MakePtenDenseTensor(*dout);
auto pt_ddx = paddle::experimental::MakePtenDenseTensor(*ddx);
auto pt_ddy = paddle::experimental::MakePtenDenseTensor(*ddy);
auto pt_d_dx = paddle::experimental::MakePtenDenseTensor(*d_dx);
auto pt_d_dy = paddle::experimental::MakePtenDenseTensor(*d_dy);
auto pt_d_ddout = paddle::experimental::MakePtenDenseTensor(*d_ddout);
auto pt_out_d_x = paddle::experimental::MakePtenDenseTensor(*out_d_x);
auto pt_out_d_y = paddle::experimental::MakePtenDenseTensor(*out_d_y);
auto pt_out_d_dout = paddle::experimental::MakePtenDenseTensor(*out_d_dout);
auto pt_out_d_ddx = paddle::experimental::MakePtenDenseTensor(*out_d_ddx);
auto pt_out_d_ddy = paddle::experimental::MakePtenDenseTensor(*out_d_ddy);
auto& dev_ctx = context.device_context<DeviceContext>();
// call new kernel
pten::MatmulTripleGradKernel<T>(dev_ctx, *pt_x, *pt_y, *pt_dout, *pt_ddx,
*pt_ddy, *pt_d_dx, *pt_d_dy, *pt_d_ddout,
transpose_x, transpose_y, pt_out_d_x.get(),
pt_out_d_y.get(), pt_out_d_dout.get(),
pt_out_d_ddx.get(), pt_out_d_ddy.get());
}
};
......
......@@ -70,6 +70,12 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) {
return *this;
}
DenseTensor& DenseTensor::operator=(DenseTensor&& other) {
meta_ = std::move(other.meta_);
storage_.swap(other.storage_);
return *this;
}
int64_t DenseTensor::numel() const {
if (meta_.is_scalar) {
return 1;
......
......@@ -97,6 +97,8 @@ class DenseTensor : public TensorBase,
/// \brief DenseTensor shallow copy assignment.
DenseTensor& operator=(const DenseTensor& other);
DenseTensor& operator=(DenseTensor&& other);
/// \brief Destroy the tensor object and release exclusive resources.
virtual ~DenseTensor() = default;
......
......@@ -29,6 +29,9 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
{"flatten_contiguous_range", "flatten"},
{"flatten_contiguous_range_grad", "flatten_grad"},
{"matmul_v2", "matmul"},
{"matmul_v2_grad", "matmul_grad"},
{"matmul_v2_grad_grad", "matmul_double_grad"},
{"matmul_v2_triple_grad", "matmul_triple_grad"},
{"reduce_mean", "mean"},
{"reduce_sum", "sum"},
{"reshape2", "reshape"},
......@@ -36,6 +39,8 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
{"flatten", "deprecated"},
{"flatten_grad", "deprecated"},
{"matmul", "deprecated"},
{"matmul_grad", "deprecated"},
{"matmul_grad_grad", "deprecated"},
{"mean", "deprecated"},
{"reshape", "deprecated"},
{"sum", "deprecated"}};
......
......@@ -50,6 +50,11 @@ void KernelContext::EmplaceBackOutputWithoutSetRange(
outputs_.emplace_back(std::move(output));
}
void KernelContext::SetOutputWithoutSetRange(
int index, std::shared_ptr<TensorBase> output) {
outputs_.at(index) = std::move(output);
}
void KernelContext::EmplaceBackOutputs(
paddle::SmallVector<std::shared_ptr<TensorBase>> outputs) {
int index = outputs_.size();
......@@ -119,9 +124,11 @@ void KernelContext::ClearData() {
}
}
for (auto& out : outputs_) {
if (out) {
CompatibleDenseTensorUtils::ClearStorage(
static_cast<DenseTensor*>(out.get()));
}
}
attrs_.clear();
}
} // namespace pten
......@@ -62,6 +62,8 @@ class KernelContext {
void EmplaceBackOutputWithoutSetRange(std::shared_ptr<TensorBase> output);
void SetOutputWithoutSetRange(int index, std::shared_ptr<TensorBase> output);
void EmplaceBackOutputs(
paddle::SmallVector<std::shared_ptr<TensorBase>> outputs);
......@@ -80,6 +82,14 @@ class KernelContext {
return static_cast<const TensorType&>(*(inputs_.at(idx)));
}
template <typename TensorType>
paddle::optional<const TensorType&> OptionalInputAt(size_t idx) const {
const auto& input = inputs_.at(idx);
return input ? paddle::optional<const TensorType&>{static_cast<
const TensorType&>(*input)}
: paddle::optional<const TensorType&>{paddle::none};
}
std::shared_ptr<TensorBase>& MutableInputPtrAt(size_t idx) {
return inputs_.at(idx);
}
......
......@@ -65,6 +65,10 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
} else if (arg_type == std::type_index(typeid(const DenseTensor&))) {
args_def->AppendInput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type == std::type_index(typeid(
paddle::optional<const DenseTensor&>))) {
args_def->AppendInput(
default_key.backend(), default_tensor_layout, default_key.dtype());
} else if (arg_type ==
std::type_index(typeid(const std::vector<DenseTensor>&))) {
args_def->AppendInput(
......
......@@ -77,6 +77,27 @@ namespace pten {
} \
}
#define PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(tensor_type) \
template <typename... Tail> \
struct KernelCallHelper<paddle::optional<const tensor_type&>, Tail...> { \
template <int dev_ctx_idx, \
int in_idx, \
int attr_idx, \
int out_idx, \
typename... PreviousArgs> \
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \
static_assert(attr_idx == 0, \
"Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \
auto arg = ctx->OptionalInputAt<tensor_type>(range.first); \
KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
ctx, pargs..., arg); \
} \
}
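// With this specialization a kernel can declare a parameter of type
// paddle::optional<const DenseTensor&>; when the bound input slot holds a
// nullptr, the kernel receives paddle::none instead of a tensor.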
#define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type) \
template <typename... Tail> \
struct KernelCallHelper<const std::vector<tensor_type>&, Tail...> { \
......@@ -190,6 +211,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
/* Input Helpers */
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
// TODO(chenweihang): adapt SelectedRows
// PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRowsTensor);
......
......@@ -30,7 +30,7 @@ DenseTensor Dot(const ContextT& dev_ctx,
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(out_meta));
Dot<T, ContextT>(dev_ctx, x, y, &dense_out);
DotKernel<T, ContextT>(dev_ctx, x, y, &dense_out);
return dense_out;
}
......
......@@ -48,15 +48,4 @@ DenseTensor Scale(const ContextT& dev_ctx,
return dense_out;
}
template <typename T, typename ContextT>
DenseTensor Conj(const ContextT& dev_ctx, const DenseTensor& x) {
auto out_meta = UnchangedInferMeta(x.meta());
pten::DenseTensor dense_out(
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(out_meta));
Conj<T>(dev_ctx, x, &dense_out);
return dense_out;
}
} // namespace pten
......@@ -16,9 +16,20 @@ limitations under the License. */
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
template <typename T, typename Context>
void Conj(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
void ConjKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
template <typename T, typename Context>
DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
auto out_meta = UnchangedInferMeta(x.meta());
auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
ConjKernel<T>(dev_ctx, x, &dense_out);
return dense_out;
}
} // namespace pten
......@@ -24,7 +24,7 @@
PT_REGISTER_CTX_KERNEL(conj,
CPU,
ALL_LAYOUT,
pten::Conj,
pten::ConjKernel,
paddle::platform::complex<float>,
paddle::platform::complex<double>,
float,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/dot_grad_kernel.h"
#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_CTX_KERNEL(dot_grad,
CPU,
ALL_LAYOUT,
pten::DotGradKernel,
float,
double,
int,
int64_t,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
......@@ -23,7 +23,7 @@
namespace pten {
template <typename T, typename Context>
void Dot(const Context& dev_ctx,
void DotKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
......@@ -52,7 +52,7 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_CTX_KERNEL(dot,
CPU,
ALL_LAYOUT,
pten::Dot,
pten::DotKernel,
float,
double,
int,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/matmul_grad_kernel.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(matmul_grad,
CPU,
ALL_LAYOUT,
pten::MatmulGradKernel,
float,
double,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_double_grad,
CPU,
ALL_LAYOUT,
pten::MatmulDoubleGradKernel,
float,
double,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_triple_grad,
CPU,
ALL_LAYOUT,
pten::MatmulTripleGradKernel,
float,
double,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
template <typename T, typename Context>
void DotGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy);
template <typename T, typename Context>
void DotDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& ddx,
const DenseTensor& ddy,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy,
DenseTensor* ddout);
template <typename T, typename Context>
void DotTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& ddx,
const DenseTensor& ddy,
const DenseTensor& d_dx,
const DenseTensor& d_dy,
const DenseTensor& dout,
const DenseTensor& d_ddout,
DenseTensor* d_x,
DenseTensor* d_y,
DenseTensor* d_ddx,
DenseTensor* d_ddy,
DenseTensor* d_dout);
} // namespace pten
......@@ -19,7 +19,7 @@
namespace pten {
template <typename T, typename Context>
void Dot(const Context& dev_ctx,
void DotKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/empty_kernel.h"
#include "paddle/pten/backends/all_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/fluid/platform/complex.h"
namespace pten {
template <typename T, typename ContextT>
void EmptyKernel(const ContextT& dev_ctx,
template <typename T, typename Context>
void EmptyKernel(const Context& dev_ctx,
const ScalarArray& shape,
DenseTensor* out) {
out->Resize(paddle::framework::make_ddim(shape.GetData()));
}
template <typename T, typename ContextT>
void EmptyLikeKernel(const ContextT& dev_ctx, DenseTensor* out) {
template <typename T, typename Context>
void EmptyLikeKernel(const Context& dev_ctx, DenseTensor* out) {
out->mutable_data<T>();
}
......@@ -37,44 +38,62 @@ PT_REGISTER_CTX_KERNEL(empty,
CPU,
ALL_LAYOUT,
pten::EmptyKernel,
bool,
int,
int64_t,
float,
double,
paddle::platform::float16) {}
uint8_t,
int16_t,
int,
int64_t,
bool,
paddle::platform::float16,
paddle::platform::bfloat16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(empty_like,
CPU,
ALL_LAYOUT,
pten::EmptyLikeKernel,
bool,
int,
int64_t,
float,
double,
paddle::platform::float16) {}
uint8_t,
int16_t,
int,
int64_t,
bool,
paddle::platform::float16,
paddle::platform::bfloat16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_CTX_KERNEL(empty,
GPU,
ALL_LAYOUT,
pten::EmptyKernel,
bool,
int,
int64_t,
float,
double,
paddle::platform::float16) {}
uint8_t,
int16_t,
int,
int64_t,
bool,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(empty_like,
GPU,
ALL_LAYOUT,
pten::EmptyLikeKernel,
bool,
int,
int64_t,
float,
double,
paddle::platform::float16) {}
uint8_t,
int16_t,
int,
int64_t,
bool,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
#endif
......@@ -41,6 +41,14 @@ DenseTensor Empty(const Context& dev_ctx, DenseTensorMeta&& meta) {
return dense_out;
}
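// Overload without an explicit shape: the dtype is deduced from T and the
// placeholder shape {-1} is expected to be resized later by the caller.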
template <typename T, typename Context>
DenseTensor Empty(const Context& dev_ctx) {
return Empty<T, Context>(dev_ctx,
{paddle::experimental::CppTypeToDataType<T>::Type(),
{-1},
DataLayout::NCHW});
}
template <typename T, typename Context>
DenseTensor Empty(const Context& dev_ctx,
const ScalarArray& shape,
......
......@@ -24,7 +24,8 @@
PT_REGISTER_CTX_KERNEL(conj,
GPU,
ALL_LAYOUT,
pten::Conj,
pten::ConjKernel,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>,
float,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/dot_grad_kernel.h"
#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/fluid/platform/complex.h"
PT_REGISTER_CTX_KERNEL(dot_grad,
GPU,
ALL_LAYOUT,
pten::DotGradKernel,
float,
double,
int,
int64_t,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
......@@ -25,7 +25,7 @@
namespace pten {
template <typename T, typename Context>
void Dot(const Context& dev_ctx,
void DotKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
......@@ -55,7 +55,7 @@ using complex128 = ::paddle::platform::complex<double>;
PT_REGISTER_CTX_KERNEL(dot,
GPU,
ALL_LAYOUT,
pten::Dot,
pten::DotKernel,
float,
double,
int,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/kernels/matmul_grad_kernel.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/impl/matmul_grad_kernel_impl.h"
PT_REGISTER_CTX_KERNEL(matmul_grad,
GPU,
ALL_LAYOUT,
pten::MatmulGradKernel,
float,
double,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_double_grad,
GPU,
ALL_LAYOUT,
pten::MatmulDoubleGradKernel,
float,
double,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_CTX_KERNEL(matmul_triple_grad,
GPU,
ALL_LAYOUT,
pten::MatmulTripleGradKernel,
float,
double,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
......@@ -17,6 +17,9 @@
#include "paddle/fluid/framework/ddim.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/pten/kernels/hybird/eigen/common.h"
namespace pten {
namespace math {
......@@ -30,5 +33,30 @@ struct TransposeNormal {
const std::vector<int64_t>& axis);
};
template <typename DeviceContext, typename T, int Rank>
struct Transpose {
void operator()(const DeviceContext& dev_ctx,
const DenseTensor& in,
DenseTensor* out,
const std::vector<int>& axis) {
Eigen::array<int, Rank> permute;
for (int i = 0; i < Rank; i++) {
permute[i] = axis[i];
}
auto eigen_in = pten::EigenTensor<T, Rank>::From(in);
auto eigen_out = pten::EigenTensor<T, Rank>::From(*out);
auto* dev = dev_ctx.eigen_device();
// use 32bit index to speed up computation
bool use_32bit_index = eigen_out.size() < Eigen::NumTraits<int>::highest();
bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace());
if (use_32bit_index && is_gpu_place) {
To32BitIndex(eigen_out).device(*dev) =
To32BitIndex(eigen_in).shuffle(permute);
} else {
eigen_out.device(*dev) = eigen_in.shuffle(permute);
}
}
};
} // namespace math
} // namespace pten
......@@ -21,12 +21,14 @@
namespace pten {
template <typename T, typename Context>
void Conj(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
void ConjKernel(const Context& context,
const DenseTensor& x,
DenseTensor* out) {
auto numel = x.numel();
auto* x_data = x.data<T>();
auto* out_data = out->mutable_data<T>();
paddle::platform::ForRange<Context> for_range(dev_ctx, numel);
paddle::platform::ForRange<Context> for_range(context, numel);
paddle::operators::math::ConjFunctor<T> functor(x_data, numel, out_data);
for_range(functor);
}
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/hybird/eigen/common.h"
#include "paddle/pten/kernels/complex_kernel.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/math/complex_functors.h"
namespace pten {
template <typename DeviceContext, typename T, typename Enable = void>
struct DotGradFunction {
void operator()(const DeviceContext& ctx,
const DenseTensor* tensor_x,
const DenseTensor* tensor_y,
const DenseTensor* tensor_dout,
DenseTensor* tensor_dx,
DenseTensor* tensor_dy);
};
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext,
T,
paddle::operators::math::EnableComplex<T>> {
void operator()(const DeviceContext& ctx,
const DenseTensor* tensor_x,
const DenseTensor* tensor_y,
const DenseTensor* tensor_dout,
DenseTensor* tensor_dx,
DenseTensor* tensor_dy) {
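    // For complex T the dot gradients are dx = conj(y) * dout and
    // dy = conj(x) * dout; on GPU the conjugate is produced with ConjKernel,
    // on CPU by negating the imaginary part directly.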
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
auto dout = EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = EigenVector<T>::Flatten(*tensor_y);
auto& dev = *ctx.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
ConjKernel<T, DeviceContext>(ctx, *tensor_y, tensor_dx);
auto dx = EigenVector<T>::Flatten(*tensor_dx);
dx.device(dev) = dx * dout.broadcast(size);
}
if (tensor_dy) {
auto x = EigenVector<T>::Flatten(*tensor_x);
auto& dev = *ctx.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
ConjKernel<T, DeviceContext>(ctx, *tensor_x, tensor_dy);
auto dy = EigenVector<T>::Flatten(*tensor_dy);
dy.device(dev) = dy * dout.broadcast(size);
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>();
auto y = EigenMatrix<T>::From(*tensor_y);
auto& dev = *ctx.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
ConjKernel<T, DeviceContext>(ctx, *tensor_y, tensor_dx);
auto dx = EigenMatrix<T>::From(*tensor_dx);
dx.device(dev) = dx * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>();
auto x = EigenMatrix<T>::From(*tensor_x);
auto& dev = *ctx.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
ConjKernel<T, DeviceContext>(ctx, *tensor_x, tensor_dy);
auto dy = EigenMatrix<T>::From(*tensor_dy);
dy.device(dev) = dy * dout.broadcast(size);
}
}
#else
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>();
const auto* data_y = tensor_y->data<T>();
const DDim& dim = tensor_x->dims();
size_t N = static_cast<size_t>(paddle::framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = T(data_y[i].real, -data_y[i].imag) * data_dout[s];
}
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>();
const auto* data_x = tensor_x->data<T>();
const DDim& dim = tensor_y->dims();
size_t N = static_cast<size_t>(paddle::framework::product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = T(data_x[i].real, -data_x[i].imag) * data_dout[s];
}
}
#endif
}
};
template <typename DeviceContext, typename T>
struct DotGradFunction<DeviceContext,
T,
paddle::operators::math::DisableComplex<T>> {
void operator()(const DeviceContext& ctx,
const DenseTensor* tensor_x,
const DenseTensor* tensor_y,
const DenseTensor* tensor_dout,
DenseTensor* tensor_dx,
DenseTensor* tensor_dy) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
auto dout = EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
auto y = EigenVector<T>::Flatten(*tensor_y);
auto dx = EigenVector<T>::Flatten(*tensor_dx);
auto& dev = *ctx.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dx->numel());
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
auto x = EigenVector<T>::Flatten(*tensor_x);
auto dy = EigenVector<T>::Flatten(*tensor_dy);
auto& dev = *ctx.eigen_device();
Eigen::DSizes<int, 1> size(tensor_dy->numel());
dy.device(dev) = x * dout.broadcast(size);
}
} else {
auto dout = EigenMatrix<T>::From(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>();
auto y = EigenMatrix<T>::From(*tensor_y);
auto dx = EigenMatrix<T>::From(*tensor_dx);
auto& dev = *ctx.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dx->dims()[1]);
dx.device(dev) = y * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>();
auto x = EigenMatrix<T>::From(*tensor_x);
auto dy = EigenMatrix<T>::From(*tensor_dy);
auto& dev = *ctx.eigen_device();
Eigen::DSizes<int, 2> size(1, tensor_dy->dims()[1]);
dy.device(dev) = x * dout.broadcast(size);
}
}
#else
auto const *x = tensor_x->data<T>(), *y = tensor_y->data<T>(),
*dz = tensor_dout->data<T>();
auto&& d = tensor_x->dims();
auto const N = tensor_x->numel();
auto const B = d[d.size() - 1];
if (tensor_dx) {
auto* dx = tensor_dx->mutable_data<T>();
for (auto j = 0; j < N / B; ++j) {
auto const ss = dz[j];
for (auto i = 0; i < B; ++i) *dx++ = *y++ * ss;
}
}
if (tensor_dy) {
auto* dy = tensor_dy->mutable_data<T>();
for (auto j = 0; j < N / B; ++j) {
auto const ss = dz[j];
for (auto i = 0; i < B; i++) *dy++ = *x++ * ss;
}
}
#endif
}
};
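// DotDoubleGradFunction computes the second-order gradients of dot(x, y).
// Given the incoming grads ddx / ddy it produces
//   dx = ddy * dout,  dy = ddx * dout,  ddout = sum(x * ddy + y * ddx),
// where, for complex dtypes, dout is conjugated in the dx / dy terms and
// x, y are conjugated in the ddout term.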
template <typename DeviceContext, typename T, typename Enable = void>
struct DotDoubleGradFunction {
void operator()(const DeviceContext& ctx,
const DenseTensor* tensor_x,
const DenseTensor* tensor_y,
const DenseTensor* tensor_dout,
const DenseTensor* tensor_ddx,
const DenseTensor* tensor_ddy,
DenseTensor* tensor_dx,
DenseTensor* tensor_dy,
DenseTensor* tensor_ddout);
};
template <typename DeviceContext, typename T>
struct DotDoubleGradFunction<DeviceContext,
T,
paddle::operators::math::EnableComplex<T>> {
void operator()(const DeviceContext& ctx,
const DenseTensor* tensor_x,
const DenseTensor* tensor_y,
const DenseTensor* tensor_dout,
const DenseTensor* tensor_ddx,
const DenseTensor* tensor_ddy,
DenseTensor* tensor_dx,
DenseTensor* tensor_dy,
DenseTensor* tensor_ddout) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
DenseTensor tensor_dout_help;
auto& dev = *ctx.eigen_device();
if (tensor_dx || tensor_dy) {
tensor_dout_help = Conj<T, DeviceContext>(ctx, *tensor_dout);
}
if (tensor_dx) {
auto ddy = EigenVector<T>::Flatten(*tensor_ddy);
Eigen::DSizes<int, 1> size(tensor_ddy->numel());
auto dx = EigenVector<T>::Flatten(*tensor_dx);
auto dout = EigenVector<T>::Flatten(tensor_dout_help);
dx.device(dev) = ddy * dout.broadcast(size);
}
if (tensor_dy) {
auto ddx = EigenVector<T>::Flatten(*tensor_ddx);
Eigen::DSizes<int, 1> size(tensor_ddx->numel());
auto dy = EigenVector<T>::Flatten(*tensor_dy);
auto dout = EigenVector<T>::Flatten(tensor_dout_help);
dy.device(dev) = ddx * dout.broadcast(size);
}
if (tensor_ddout) {
DenseTensor tensor_x_help = Conj<T, DeviceContext>(ctx, *tensor_x);
DenseTensor tensor_y_help = Conj<T, DeviceContext>(ctx, *tensor_y);
auto x = EigenVector<T>::Flatten(tensor_x_help);
auto y = EigenVector<T>::Flatten(tensor_y_help);
auto ddx = EigenVector<T>::Flatten(*tensor_ddx);
auto ddy = EigenVector<T>::Flatten(*tensor_ddy);
auto ddout = EigenVector<T>::Flatten(*tensor_ddout);
ddout.device(dev) = (x * ddy + y * ddx).sum();
}
}
#else
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>();
const auto* data_ddy = tensor_ddy->data<T>();
const DDim& dim = tensor_dx->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddy[i];
}
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>();
const auto* data_ddx = tensor_ddx->data<T>();
const DDim& dim = tensor_dy->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = T(data_dout[s].real, -data_dout[s].imag) * data_ddx[i];
}
}
if (tensor_ddout) {
auto* data_ddout = tensor_ddout->mutable_data<T>();
auto* data_x = tensor_x->data<T>();
auto* data_y = tensor_y->data<T>();
auto* data_ddx = tensor_ddx->data<T>();
auto* data_ddy = tensor_ddy->data<T>();
const DDim& dim = tensor_dy->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
bool new_s = false;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) {
++s;
new_s = true;
}
if (new_s) {
data_ddout[s] = T(data_x[i].real, -data_x[i].imag) * data_ddy[i] +
T(data_y[i].real, -data_y[i].imag) * data_ddx[i];
} else {
data_ddout[s] += T(data_x[i].real, -data_x[i].imag) * data_ddy[i] +
T(data_y[i].real, -data_y[i].imag) * data_ddx[i];
}
new_s = false;
}
}
#endif
}
};
template <typename DeviceContext, typename T>
struct DotDoubleGradFunction<DeviceContext,
T,
paddle::operators::math::DisableComplex<T>> {
void operator()(const DeviceContext& ctx,
const DenseTensor* tensor_x,
const DenseTensor* tensor_y,
const DenseTensor* tensor_dout,
const DenseTensor* tensor_ddx,
const DenseTensor* tensor_ddy,
DenseTensor* tensor_dx,
DenseTensor* tensor_dy,
DenseTensor* tensor_ddout) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == tensor_dout->dims().size()) {
auto& dev = *ctx.eigen_device();
auto dout = EigenVector<T>::Flatten(*tensor_dout);
if (tensor_dx) {
tensor_dx->mutable_data<T>();
auto ddy = EigenVector<T>::Flatten(*tensor_ddy);
Eigen::DSizes<int, 1> size(tensor_ddy->numel());
auto dx = EigenVector<T>::Flatten(*tensor_dx);
dx.device(dev) = ddy * dout.broadcast(size);
}
if (tensor_dy) {
tensor_dy->mutable_data<T>();
auto ddx = EigenVector<T>::Flatten(*tensor_ddx);
Eigen::DSizes<int, 1> size(tensor_ddx->numel());
auto dy = EigenVector<T>::Flatten(*tensor_dy);
dy.device(dev) = ddx * dout.broadcast(size);
}
if (tensor_ddout) {
tensor_ddout->mutable_data<T>();
auto x = EigenVector<T>::Flatten(*tensor_x);
auto y = EigenVector<T>::Flatten(*tensor_y);
auto ddx = EigenVector<T>::Flatten(*tensor_ddx);
auto ddy = EigenVector<T>::Flatten(*tensor_ddy);
auto ddout = EigenVector<T>::Flatten(*tensor_ddout);
ddout.device(dev) = (x * ddy + y * ddx).sum();
}
}
#else
const auto* data_dout = tensor_dout->data<T>();
if (tensor_dx) {
auto* data_dx = tensor_dx->mutable_data<T>();
const auto* data_ddy = tensor_ddy->data<T>();
const DDim& dim = tensor_dx->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dx[i] = data_dout[s] * data_ddy[i];
}
}
if (tensor_dy) {
auto* data_dy = tensor_dy->mutable_data<T>();
const auto* data_ddx = tensor_ddx->data<T>();
const DDim& dim = tensor_dy->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_dy[i] = data_dout[s] * data_ddx[i];
}
}
if (tensor_ddout) {
auto* data_ddout = tensor_ddout->mutable_data<T>();
auto* data_x = tensor_x->data<T>();
auto* data_y = tensor_y->data<T>();
auto* data_ddx = tensor_ddx->data<T>();
auto* data_ddy = tensor_ddy->data<T>();
const DDim& dim = tensor_dy->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
bool new_s = false;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) {
++s;
new_s = true;
}
if (new_s) {
data_ddout[s] = data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i];
} else {
data_ddout[s] += data_x[i] * data_ddy[i] + data_y[i] * data_ddx[i];
}
new_s = false;
}
}
#endif
}
};
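// DotTripleGradFunction computes the third-order gradients of dot(x, y).
// From (x, y, ddx, ddy, d_dx, d_dy, dout, d_ddout) it produces
//   d_x    = ddy * d_ddout,            d_y   = ddx * d_ddout,
//   d_dout = sum(ddy * d_dx + ddx * d_dy),
//   d_ddx  = dout * d_dy + y * d_ddout,
//   d_ddy  = dout * d_dx + x * d_ddout,
// with the corresponding conjugations applied for complex dtypes.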
template <typename DeviceContext, typename T, typename Enable = void>
struct DotTripleGradFunction {
void operator()(const DeviceContext& ctx,
const DenseTensor* in_tensor_x,
const DenseTensor* in_tensor_y,
const DenseTensor* in_tensor_ddx,
const DenseTensor* in_tensor_ddy,
const DenseTensor* in_tensor_d_dx,
const DenseTensor* in_tensor_d_dy,
const DenseTensor* in_tensor_dout,
const DenseTensor* in_tensor_d_ddout,
DenseTensor* out_tensor_d_x,
DenseTensor* out_tensor_d_y,
DenseTensor* out_tensor_d_dout,
DenseTensor* out_tensor_d_ddx,
DenseTensor* out_tensor_d_ddy);
};
// TODO(wuweilong): enable this function when the unit test framework for
// multi-grad is ready (dtype: complex64 or complex128).
template <typename DeviceContext, typename T>
struct DotTripleGradFunction<DeviceContext,
T,
paddle::operators::math::EnableComplex<T>> {
void operator()(const DeviceContext& ctx,
const DenseTensor* in_tensor_x,
const DenseTensor* in_tensor_y,
const DenseTensor* in_tensor_ddx,
const DenseTensor* in_tensor_ddy,
const DenseTensor* in_tensor_d_dx,
const DenseTensor* in_tensor_d_dy,
const DenseTensor* in_tensor_dout,
const DenseTensor* in_tensor_d_ddout,
DenseTensor* out_tensor_d_x,
DenseTensor* out_tensor_d_y,
DenseTensor* out_tensor_d_dout,
DenseTensor* out_tensor_d_ddx,
DenseTensor* out_tensor_d_ddy) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == in_tensor_d_ddout->dims().size()) {
DenseTensor in_tensor_d_ddout_help;
auto& dev = *ctx.eigen_device();
if (out_tensor_d_x || out_tensor_d_y) {
in_tensor_d_ddout_help =
Conj<T, DeviceContext>(ctx, *in_tensor_d_ddout);
}
if (out_tensor_d_x) {
auto ddy = EigenVector<T>::Flatten(*in_tensor_ddy);
Eigen::DSizes<int, 1> size(in_tensor_ddy->numel());
auto d_x = EigenVector<T>::Flatten(*out_tensor_d_x);
auto d_ddout = EigenVector<T>::Flatten(in_tensor_d_ddout_help);
d_x.device(dev) = ddy * d_ddout.broadcast(size);
}
if (out_tensor_d_y) {
auto ddx = EigenVector<T>::Flatten(*in_tensor_ddx);
Eigen::DSizes<int, 1> size(in_tensor_ddx->numel());
auto d_y = EigenVector<T>::Flatten(*out_tensor_d_y);
auto d_ddout = EigenVector<T>::Flatten(in_tensor_d_ddout_help);
d_y.device(dev) = ddx * d_ddout.broadcast(size);
}
if (out_tensor_d_dout) {
DenseTensor in_tensor_ddx_help =
Conj<T, DeviceContext>(ctx, *in_tensor_ddx);
DenseTensor in_tensor_ddy_help =
Conj<T, DeviceContext>(ctx, *in_tensor_ddy);
auto ddx = EigenVector<T>::Flatten(in_tensor_ddx_help);
auto ddy = EigenVector<T>::Flatten(in_tensor_ddy_help);
auto d_dx = EigenVector<T>::Flatten(*in_tensor_d_dx);
auto d_dy = EigenVector<T>::Flatten(*in_tensor_d_dy);
auto d_dout = EigenVector<T>::Flatten(*out_tensor_d_dout);
d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum();
}
if (out_tensor_d_ddx) {
DenseTensor in_tensor_dout_help =
Conj<T, DeviceContext>(ctx, *in_tensor_dout);
DenseTensor in_tensor_y_help =
Conj<T, DeviceContext>(ctx, *in_tensor_y);
auto dout = EigenVector<T>::Flatten(in_tensor_dout_help);
auto y = EigenVector<T>::Flatten(in_tensor_y_help);
auto d_ddout = EigenVector<T>::Flatten(*in_tensor_d_ddout);
auto d_dy = EigenVector<T>::Flatten(*in_tensor_d_dy);
auto d_ddx = EigenVector<T>::Flatten(*out_tensor_d_ddx);
Eigen::DSizes<int, 1> size(in_tensor_y->numel());
d_ddx.device(dev) =
(dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size));
}
if (out_tensor_d_ddy) {
DenseTensor in_tensor_dout_help =
Conj<T, DeviceContext>(ctx, *in_tensor_dout);
DenseTensor in_tensor_x_help =
Conj<T, DeviceContext>(ctx, *in_tensor_x);
auto dout = EigenVector<T>::Flatten(in_tensor_dout_help);
auto x = EigenVector<T>::Flatten(in_tensor_x_help);
auto d_ddout = EigenVector<T>::Flatten(*in_tensor_d_ddout);
auto d_dx = EigenVector<T>::Flatten(*in_tensor_d_dx);
auto d_ddy = EigenVector<T>::Flatten(*out_tensor_d_ddy);
Eigen::DSizes<int, 1> size(in_tensor_x->numel());
d_ddy.device(dev) =
(dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size));
}
}
#else
const auto* data_d_ddout = in_tensor_d_ddout->data<T>();
if (out_tensor_d_x) {
auto* data_d_x = out_tensor_d_x->mutable_data<T>();
const auto* data_ddy = in_tensor_ddy->data<T>();
const DDim& dim = out_tensor_d_x->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_x[i] = T(data_ddy[i].real, -data_ddy[i].imag) * data_d_ddout[s];
}
}
if (out_tensor_d_y) {
auto* data_d_y = out_tensor_d_y->mutable_data<T>();
const auto* data_ddx = in_tensor_ddx->data<T>();
const DDim& dim = out_tensor_d_y->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_y[i] = T(data_ddx[i].real, -data_ddx[i].imag) * data_d_ddout[s];
}
}
if (out_tensor_d_dout) {
auto* data_d_dout = out_tensor_d_dout->mutable_data<T>();
auto* data_ddx = in_tensor_ddx->data<T>();
auto* data_ddy = in_tensor_ddy->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
auto* data_d_dy = in_tensor_d_dy->data<T>();
const DDim& dim = out_tensor_d_dout->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
bool new_s = false;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) {
++s;
new_s = true;
}
if (new_s) {
data_d_dout[s] =
T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] +
T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i];
} else {
data_d_dout[s] +=
T(data_ddy[i].real, -data_ddy[i].imag) * data_d_dx[i] +
T(data_ddx[i].real, -data_ddx[i].imag) * data_d_dy[i];
}
new_s = false;
}
}
if (out_tensor_d_ddx) {
auto* data_d_ddx = out_tensor_d_ddx->mutable_data<T>();
auto* data_dout = in_tensor_dout->data<T>();
auto* data_d_dy = in_tensor_d_dy->data<T>();
auto* data_y = in_tensor_y->data<T>();
auto* data_d_ddout = in_tensor_d_ddout->data<T>();
const DDim& dim = out_tensor_d_ddx->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_ddx[i] =
T(data_dout[s].real, -data_dout[s].imag) * data_d_dy[i] +
T(data_y[i].real, -data_y[i].imag) * data_d_ddout[s];
}
}
if (out_tensor_d_ddy) {
auto* data_d_ddy = out_tensor_d_ddy->mutable_data<T>();
auto* data_dout = in_tensor_dout->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
auto* data_x = in_tensor_x->data<T>();
auto* data_d_ddout = in_tensor_d_ddout->data<T>();
const DDim& dim = out_tensor_d_ddy->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_ddy[i] =
T(data_dout[s].real, -data_dout[s].imag) * data_d_dx[i] +
T(data_x[i].real, -data_x[i].imag) * data_d_ddout[s];
}
}
#endif
}
};
template <typename DeviceContext, typename T>
struct DotTripleGradFunction<DeviceContext,
T,
paddle::operators::math::DisableComplex<T>> {
void operator()(const DeviceContext& ctx,
const DenseTensor* in_tensor_x,
const DenseTensor* in_tensor_y,
const DenseTensor* in_tensor_ddx,
const DenseTensor* in_tensor_ddy,
const DenseTensor* in_tensor_d_dx,
const DenseTensor* in_tensor_d_dy,
const DenseTensor* in_tensor_dout,
const DenseTensor* in_tensor_d_ddout,
DenseTensor* out_tensor_d_x,
DenseTensor* out_tensor_d_y,
DenseTensor* out_tensor_d_dout,
DenseTensor* out_tensor_d_ddx,
DenseTensor* out_tensor_d_ddy) {
#if defined(__NVCC__) || defined(__HIPCC__)
if (1 == in_tensor_d_ddout->dims().size()) {
auto& dev = *ctx.eigen_device();
auto d_ddout = EigenVector<T>::Flatten(*in_tensor_d_ddout);
if (out_tensor_d_x) {
out_tensor_d_x->mutable_data<T>();
auto ddy = EigenVector<T>::Flatten(*in_tensor_ddy);
Eigen::DSizes<int, 1> size(in_tensor_ddy->numel());
auto d_x = EigenVector<T>::Flatten(*out_tensor_d_x);
d_x.device(dev) = ddy * d_ddout.broadcast(size);
}
if (out_tensor_d_y) {
out_tensor_d_y->mutable_data<T>();
auto ddx = EigenVector<T>::Flatten(*in_tensor_ddx);
Eigen::DSizes<int, 1> size(in_tensor_ddx->numel());
auto d_y = EigenVector<T>::Flatten(*out_tensor_d_y);
d_y.device(dev) = ddx * d_ddout.broadcast(size);
}
if (out_tensor_d_dout) {
out_tensor_d_dout->mutable_data<T>();
auto ddx = EigenVector<T>::Flatten(*in_tensor_ddx);
auto ddy = EigenVector<T>::Flatten(*in_tensor_ddy);
auto d_dx = EigenVector<T>::Flatten(*in_tensor_d_dx);
auto d_dy = EigenVector<T>::Flatten(*in_tensor_d_dy);
auto d_dout = EigenVector<T>::Flatten(*out_tensor_d_dout);
d_dout.device(dev) = (ddx * d_dy + ddy * d_dx).sum();
}
if (out_tensor_d_ddx) {
out_tensor_d_ddx->mutable_data<T>();
auto dout = EigenVector<T>::Flatten(*in_tensor_dout);
auto y = EigenVector<T>::Flatten(*in_tensor_y);
auto d_ddout = EigenVector<T>::Flatten(*in_tensor_d_ddout);
auto d_dy = EigenVector<T>::Flatten(*in_tensor_d_dy);
auto d_ddx = EigenVector<T>::Flatten(*out_tensor_d_ddx);
Eigen::DSizes<int, 1> size(in_tensor_y->numel());
d_ddx.device(dev) =
(dout.broadcast(size) * d_dy + y * d_ddout.broadcast(size));
}
if (out_tensor_d_ddy) {
out_tensor_d_ddy->mutable_data<T>();
auto dout = EigenVector<T>::Flatten(*in_tensor_dout);
auto x = EigenVector<T>::Flatten(*in_tensor_x);
auto d_ddout = EigenVector<T>::Flatten(*in_tensor_d_ddout);
auto d_dx = EigenVector<T>::Flatten(*in_tensor_d_dx);
auto d_ddy = EigenVector<T>::Flatten(*out_tensor_d_ddy);
Eigen::DSizes<int, 1> size(in_tensor_x->numel());
d_ddy.device(dev) =
(dout.broadcast(size) * d_dx + x * d_ddout.broadcast(size));
}
}
#else
const auto* data_d_ddout = in_tensor_d_ddout->data<T>();
if (out_tensor_d_x) {
auto* data_d_x = out_tensor_d_x->mutable_data<T>();
const auto* data_ddy = in_tensor_ddy->data<T>();
const DDim& dim = out_tensor_d_x->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_x[i] = data_ddy[i] * data_d_ddout[s];
}
}
if (out_tensor_d_y) {
auto* data_d_y = out_tensor_d_y->mutable_data<T>();
const auto* data_ddx = in_tensor_ddx->data<T>();
const DDim& dim = out_tensor_d_y->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_y[i] = data_ddx[i] * data_d_ddout[s];
}
}
if (out_tensor_d_dout) {
auto* data_d_dout = out_tensor_d_dout->mutable_data<T>();
auto* data_ddx = in_tensor_ddx->data<T>();
auto* data_ddy = in_tensor_ddy->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
auto* data_d_dy = in_tensor_d_dy->data<T>();
const DDim& dim = in_tensor_ddx->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
bool new_s = false;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) {
++s;
new_s = true;
}
if (new_s) {
data_d_dout[s] =
data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i];
} else {
data_d_dout[s] +=
data_ddy[i] * data_d_dx[i] + data_ddx[i] * data_d_dy[i];
}
new_s = false;
}
}
if (out_tensor_d_ddx) {
auto* data_d_ddx = out_tensor_d_ddx->mutable_data<T>();
auto* data_dout = in_tensor_dout->data<T>();
auto* data_d_dy = in_tensor_d_dy->data<T>();
auto* data_y = in_tensor_y->data<T>();
auto* data_d_ddout = in_tensor_d_ddout->data<T>();
const DDim& dim = out_tensor_d_ddx->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_ddx[i] =
data_dout[s] * data_d_dy[i] + data_y[i] * data_d_ddout[s];
}
}
if (out_tensor_d_ddy) {
auto* data_d_ddy = out_tensor_d_ddy->mutable_data<T>();
auto* data_dout = in_tensor_dout->data<T>();
auto* data_d_dx = in_tensor_d_dx->data<T>();
auto* data_x = in_tensor_x->data<T>();
auto* data_d_ddout = in_tensor_d_ddout->data<T>();
const DDim& dim = out_tensor_d_ddy->dims();
size_t N = static_cast<size_t>(product(dim));
auto step = dim[dim.size() - 1];
int s = -1;
for (size_t i = 0; i < N; ++i) {
if (0 == i % step) ++s;
data_d_ddy[i] =
data_dout[s] * data_d_dx[i] + data_x[i] * data_d_ddout[s];
}
}
#endif
}
};
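// Kernel entry points. They allocate the requested outputs (a null pointer
// means the corresponding gradient is not needed) and then dispatch to the
// functors above, so the same templates serve both CPU and GPU contexts.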
template <typename T, typename Context>
void DotGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy) {
if (dx) {
dx->mutable_data<T>();
}
if (dy) {
dy->mutable_data<T>();
}
DotGradFunction<Context, T>()(dev_ctx, &x, &y, &dout, dx, dy);
}
template <typename T, typename Context>
void DotDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& ddx,
const DenseTensor& ddy,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy,
DenseTensor* ddout) {
if (dx) {
dx->mutable_data<T>();
}
if (dy) {
dy->mutable_data<T>();
}
if (ddout) {
ddout->mutable_data<T>();
}
  DotDoubleGradFunction<Context, T>()(
      dev_ctx, &x, &y, &dout, &ddx, &ddy, dx, dy, ddout);
}
template <typename T, typename Context>
void DotTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& ddx,
const DenseTensor& ddy,
const DenseTensor& d_dx,
const DenseTensor& d_dy,
const DenseTensor& dout,
const DenseTensor& d_ddout,
DenseTensor* d_x,
DenseTensor* d_y,
DenseTensor* d_ddx,
DenseTensor* d_ddy,
DenseTensor* d_dout) {
if (d_x) {
d_x->mutable_data<T>();
}
if (d_y) {
d_y->mutable_data<T>();
}
if (d_ddx) {
d_ddx->mutable_data<T>();
}
if (d_ddy) {
d_ddy->mutable_data<T>();
}
if (d_dout) {
d_dout->mutable_data<T>();
}
  DotTripleGradFunction<Context, T>()(dev_ctx,
                                      &x,
                                      &y,
                                      &ddx,
                                      &ddy,
                                      &d_dx,
                                      &d_dy,
                                      &dout,
                                      &d_ddout,
                                      d_x,
                                      d_y,
                                      d_dout,
                                      d_ddx,
                                      d_ddy);
}
} // namespace pten
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
// #include "paddle/pten/kernels/complex_kernel.h"
#include "paddle/pten/include/math.h"
#include "paddle/pten/kernels/empty_kernel.h"
#include "paddle/pten/kernels/impl/dot_grad_kernel_impl.h"
#include "paddle/pten/kernels/impl/matmul_kernel_impl.h"
#include "paddle/pten/kernels/cpu/reduce.h"
#include "paddle/pten/kernels/funcs/reduce_functor.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/pten/kernels/gpu/reduce.h"
#endif
namespace pten {
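// ReduceSumForMatmulGrad sums an intermediate gradient over the broadcast
// batch dimensions so that it matches the shape of the original input. The
// CPU specialization reuses ReduceKernelImpl with SumFunctor; the GPU one
// uses TensorReduceFunctorImpl on the device stream.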
template <typename Context, typename T>
struct ReduceSumForMatmulGrad {
void operator()(const Context& dev_ctx,
const DenseTensor& input,
DenseTensor* output,
const std::vector<int>& reduce_dims);
};
template <typename T>
struct ReduceSumForMatmulGrad<CPUContext, T> {
void operator()(const CPUContext& dev_ctx,
const DenseTensor& input,
DenseTensor* output,
const std::vector<int>& reduce_dims) {
std::vector<int64_t> reduce_dims_tmp(reduce_dims.begin(),
reduce_dims.end());
ReduceKernelImpl<CPUContext, T, T, pten::funcs::SumFunctor>(
dev_ctx, input, output, reduce_dims_tmp, true, false);
}
};
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
struct ReduceSumForMatmulGrad<GPUContext, T> {
void operator()(const GPUContext& dev_ctx,
const DenseTensor& input,
DenseTensor* output,
const std::vector<int>& reduce_dims) {
auto stream = dev_ctx.stream();
kernels::
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
input, output, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
};
#endif
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
static DenseTensor FoldInitDims(const DenseTensor& input) {
DenseTensor output = input;
auto in_dims = input.dims();
if (in_dims.size() == 3) {
output.Resize({in_dims[0] * in_dims[1], in_dims[2]});
}
return output;
}
// Reshape a rank-3 tensor from P x M x N to M x (P * N).
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template <typename Context, typename T>
static DenseTensor FoldHeadAndLastDims(const Context& dev_ctx,
const DenseTensor& input) {
auto in_dims = input.dims();
if (in_dims.size() != 3) {
return input;
}
DenseTensor output = EmptyLike<T, Context>(dev_ctx, input);
output.Resize({in_dims[1], in_dims[0], in_dims[2]});
std::vector<int> axis = {1, 0, 2};
math::Transpose<Context, T, 3> trans;
trans(dev_ctx, input, &output, axis);
output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
return output;
}
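// MatMul is a thin wrapper over the Blas MatMul routine. When `flag` is true
// the result is accumulated into `out` (beta = 1) instead of overwriting it,
// which is how the second contribution to outputs such as ddout is added.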
template <typename Context, typename T>
void MatMul(const Context& dev_ctx,
const DenseTensor& a,
bool trans_a,
const DenseTensor& b,
bool trans_b,
DenseTensor* out,
bool flag = false) {
out->mutable_data<T>();
auto blas = paddle::operators::math::GetBlas<Context, T>(dev_ctx);
auto mat_dim_a =
paddle::operators::math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b =
paddle::operators::math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
if (a.dims().size() == 3 && b.dims().size() <= 2) {
    // transpose_X must be false here; if it is true, the transpose costs much time
if (!trans_a) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
}
}
blas.MatMul(a.data<T>(),
mat_dim_a,
b.data<T>(),
mat_dim_b,
static_cast<T>(1),
out->mutable_data<T>(),
static_cast<T>(flag));
}
/**
* Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
* original x_dim is returned.
*/
static DDim RowMatrixFromVector(const DDim& x_dim) {
if (x_dim.size() > 1) {
return x_dim;
}
return paddle::framework::make_ddim({1, x_dim[0]});
}
/**
 * Get column matrix shape from a vector shape. If the rank of y_dim > 1, the
* original y_dim is returned.
*/
static DDim ColumnMatrixFromVector(const DDim& y_dim) {
if (y_dim.size() > 1) {
return y_dim;
}
return paddle::framework::make_ddim({y_dim[0], 1});
}
/**
* Reshape a tensor to 3-D or 2-D tensor by matrix descriptor.
*
* The shape would be [BatchSize, H, W] or [H, W].
* If transposed, `H,W` will be swapped.
*/
static void ReshapeTensorIntoMatrixSequence(
DenseTensor* x, const paddle::operators::math::MatDescriptor& descriptor) {
int64_t h, w;
h = descriptor.height_;
w = descriptor.width_;
if (descriptor.trans_) {
std::swap(w, h);
}
if (descriptor.batch_size_) {
x->Resize({descriptor.batch_size_, h, w});
} else {
x->Resize({h, w});
}
}
static void ReshapeXYOutIntoMatrixSequence(DenseTensor* x,
DenseTensor* y,
DenseTensor* out,
bool trans_x,
bool trans_y) {
auto x_dim = RowMatrixFromVector(x->dims());
auto y_dim = ColumnMatrixFromVector(y->dims());
auto mat_dim_x =
paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x);
auto mat_dim_y =
paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y);
if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
out->Resize({mat_dim_x.height_, mat_dim_y.width_});
} else {
out->Resize({(std::max)(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
mat_dim_x.height_,
mat_dim_y.width_});
}
ReshapeTensorIntoMatrixSequence(x, mat_dim_x);
ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
}
template <typename T, typename Context>
void CalcInputGrad(const Context& dev_ctx,
const DenseTensor& a,
bool trans_a,
bool is_fold_init_dims_a,
const DenseTensor& b,
bool trans_b,
bool is_fold_init_dims_b,
DenseTensor* out,
bool flag = false) {
if (out == nullptr) return;
bool need_combine =
(a.dims().size() == 3 || b.dims().size() == 3) && out->dims().size() == 2;
if (!need_combine) {
MatMul<Context, T>(dev_ctx, a, trans_a, b, trans_b, out, flag);
} else {
MatMul<Context, T>(
dev_ctx,
is_fold_init_dims_a ? FoldInitDims(a)
: FoldHeadAndLastDims<Context, T>(dev_ctx, a),
trans_a,
is_fold_init_dims_b ? FoldInitDims(b)
: FoldHeadAndLastDims<Context, T>(dev_ctx, b),
trans_b,
out,
flag);
}
}
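// MatmulGradKernel computes the first-order gradients of Out = matmul(X, Y).
// In the no-transpose case dX = dOut * conj(Y)^T and dY = conj(X)^T * dOut;
// the transpose_x / transpose_y flags select the corresponding variants. When
// the inputs broadcast over batch dimensions (Case3), the grads are computed
// with MatMulFunction and then reduced back to the input shapes with
// ReduceSumForMatmulGrad.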
template <typename T, typename Context>
void MatmulGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
bool transpose_x,
bool transpose_y,
DenseTensor* dx,
DenseTensor* dy) {
// get dims
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
std::vector<std::int64_t> dout_dims = vectorize(out_grad.dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int ndim = dout_dims.size();
  // Case1 : x's and y's dim = 1
if (x_ndim == 1 && y_ndim == 1) {
if (dx) dx->mutable_data<T>();
if (dy) dy->mutable_data<T>();
if (out_grad.numel() == 1) {
DotGradFunction<Context, T>()(dev_ctx, &x, &y, &out_grad, dx, dy);
return;
}
}
bool is_broadcast = true;
if (x_ndim <= 2 || y_ndim <= 2) {
is_broadcast = false;
} else if (x_ndim != y_ndim) {
is_broadcast = true;
} else {
is_broadcast = !std::equal(
x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin());
}
// for complex
DenseTensor x_conj;
DenseTensor y_conj;
  // Case2: no broadcast or no batch size. This path aims for speed and matches
  // the matmul in the old version.
if (!is_broadcast) {
DenseTensor x_help = x;
DenseTensor y_help = y;
DenseTensor out_grad_help = out_grad;
ReshapeXYOutIntoMatrixSequence(
&x_help, &y_help, &out_grad_help, transpose_x, transpose_y);
DDim dx_dims;
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x_help.dims()) {
dx->Resize(x_help.dims());
}
y_conj = Conj<T>(dev_ctx, y_help);
}
DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y_help.dims()) {
dy->Resize(y_help.dims());
}
x_conj = Conj<T>(dev_ctx, x_help);
}
if (transpose_x && transpose_y) {
CalcInputGrad<T>(
dev_ctx, y_conj, true, true, out_grad_help, true, false, dx);
CalcInputGrad<T>(
dev_ctx, out_grad_help, true, true, x_conj, true, false, dy);
} else if (transpose_x) {
CalcInputGrad<T>(
dev_ctx, y_conj, false, false, out_grad_help, true, false, dx);
CalcInputGrad<T>(
dev_ctx, x_conj, false, false, out_grad_help, false, true, dy);
} else if (transpose_y) {
CalcInputGrad<T>(
dev_ctx, out_grad_help, false, false, y_conj, false, true, dx);
CalcInputGrad<T>(
dev_ctx, out_grad_help, true, true, x_conj, false, true, dy);
} else {
CalcInputGrad<T>(
dev_ctx, out_grad_help, false, false, y_conj, true, false, dx);
CalcInputGrad<T>(
dev_ctx, x_conj, true, true, out_grad_help, false, true, dy);
}
if (dx) {
if (dx_dims != x_help.dims()) {
dx->Resize(dx_dims);
}
}
if (dy) {
if (dy_dims != y_help.dims()) {
dy->Resize(dy_dims);
}
}
} else {
    // Case3: broadcast. Reducing the sum over the broadcast dimensions costs
    // much time and wastes memory, so this case should be avoided in practice.
    VLOG(3) << "Reduce sum over the broadcast dimensions costs much time and "
               "wastes memory, so this case should be avoided in practice";
x_conj = Conj<T>(dev_ctx, x);
y_conj = Conj<T>(dev_ctx, y);
DenseTensor dx_help = Empty<T, Context>(dev_ctx);
DenseTensor dy_help = Empty<T, Context>(dev_ctx);
if (transpose_x) {
if (transpose_y) {
// X'Y': dA = Y'G', dB = G'X'
if (dx)
MatMulFunction<Context, T>(dev_ctx,
y_conj,
out_grad,
y_dims,
dout_dims,
&dx_help,
true,
true);
if (dy)
MatMulFunction<Context, T>(dev_ctx,
out_grad,
x_conj,
dout_dims,
x_dims,
&dy_help,
true,
true);
} else {
// X'Y: dX = YG', dY = XG
if (dx)
MatMulFunction<Context, T>(dev_ctx,
y_conj,
out_grad,
y_dims,
dout_dims,
&dx_help,
false,
true);
if (dy)
MatMulFunction<Context, T>(dev_ctx,
x_conj,
out_grad,
x_dims,
dout_dims,
&dy_help,
false,
false);
}
} else {
if (transpose_y) {
// XY': dX = GY, dY = G'X
if (dx)
MatMulFunction<Context, T>(dev_ctx,
out_grad,
y_conj,
dout_dims,
y_dims,
&dx_help,
false,
false);
if (dy)
MatMulFunction<Context, T>(dev_ctx,
out_grad,
x_conj,
dout_dims,
x_dims,
&dy_help,
true,
false);
} else {
// XY: dX = GY', dY = X'G
if (dx)
MatMulFunction<Context, T>(dev_ctx,
out_grad,
y_conj,
dout_dims,
y_dims,
&dx_help,
false,
true);
if (dy)
MatMulFunction<Context, T>(dev_ctx,
x_conj,
out_grad,
x_dims,
dout_dims,
&dy_help,
true,
false);
}
}
// get help dims
const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
std::vector<std::int64_t> dx_broadcast_dims(ndim);
std::vector<std::int64_t> dy_broadcast_dims(ndim);
std::fill(
dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1);
std::fill(
dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1);
std::copy(x_dims.data(),
x_dims.data() + x_ndim,
dx_broadcast_dims.data() + ndim - x_ndim);
std::copy(y_dims.data(),
y_dims.data() + y_ndim,
dy_broadcast_dims.data() + ndim - y_ndim);
std::vector<int> dx_reduce_dims;
std::vector<int> dy_reduce_dims;
for (int idx = 0; idx <= ndim - 3; idx++) {
if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) {
dx_reduce_dims.push_back(idx);
}
if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) {
dy_reduce_dims.push_back(idx);
}
}
// reduce sum to get grad by ReduceSum
if (dx) {
if (dx_reduce_dims.empty()) {
*dx = std::move(dx_help);
} else {
ReduceSumForMatmulGrad<Context, T>()(
dev_ctx, dx_help, dx, dx_reduce_dims);
}
dx->Resize(x.dims());
}
if (dy) {
if (dy_reduce_dims.empty()) {
*dy = std::move(dy_help);
} else {
ReduceSumForMatmulGrad<Context, T>()(
dev_ctx, dy_help, dy, dy_reduce_dims);
}
dy->Resize(y.dims());
}
// Get the OutputGrad(out)
}
}
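// MatmulDoubleGradKernel computes the second-order gradients. The ddX branch
// contributes ddX * conj(Y) to ddOut and a dY term; the ddY branch contributes
// conj(X) * ddY to ddOut and a dX term; `ddout_flag` makes the second
// contribution accumulate onto the first.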
template <typename T, typename Context>
void MatmulDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
paddle::optional<const DenseTensor&> ddx,
paddle::optional<const DenseTensor&> ddy,
bool transpose_x,
bool transpose_y,
DenseTensor* dx,
DenseTensor* dy,
DenseTensor* ddout) {
// Get dims from the input x, y, output_grad
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
std::vector<std::int64_t> dout_dims = vectorize(dout.dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int ndim = dout_dims.size();
  // Case1 : x's and y's dim = 1
if (x_ndim == 1 && y_ndim == 1) {
DotDoubleGradFunction<Context, T>()(
dev_ctx, &x, &y, &dout, ddx.get_ptr(), ddy.get_ptr(), dx, dy, ddout);
return;
}
DenseTensor x_conj;
DenseTensor y_conj;
DenseTensor dout_conj;
bool is_broadcast = true;
if (x_ndim <= 2 || y_ndim <= 2) {
is_broadcast = false;
} else if (x_ndim != y_ndim) {
is_broadcast = true;
} else {
is_broadcast = !std::equal(
x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin());
}
if (!is_broadcast) {
// Case2: no broadcast or no batch size
DenseTensor x_help = x;
DenseTensor y_help = y;
DenseTensor dout_help = dout;
ReshapeXYOutIntoMatrixSequence(
&x_help, &y_help, &dout_help, transpose_x, transpose_y);
DDim dx_dims;
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x_help.dims()) {
dx->Resize(x_help.dims());
}
}
DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y_help.dims()) {
dy->Resize(y_help.dims());
}
}
DDim ddout_dims;
if (ddout) {
ddout_dims = ddout->dims();
if (ddout_dims != dout_help.dims()) {
ddout->Resize(dout_help.dims());
}
x_conj = Conj<T>(dev_ctx, x_help);
y_conj = Conj<T>(dev_ctx, y_help);
}
if (dx || dy) {
dout_conj = Conj<T>(dev_ctx, dout_help);
}
bool ddout_flag = false;
if (ddx) {
auto ddx_mat = ddx.get();
if (ddx_mat.dims() != x_help.dims()) {
ddx_mat.Resize(x_help.dims());
}
if (dy) {
if (transpose_x && transpose_y) {
// dy = dout' * ddx'
CalcInputGrad<T>(
dev_ctx, dout_conj, true, true, ddx_mat, true, false, dy, false);
} else if (transpose_x) {
// dy = ddx * dout
CalcInputGrad<T>(dev_ctx,
ddx_mat,
false,
false,
dout_conj,
false,
true,
dy,
false);
} else if (transpose_y) {
// dy = dout' * ddx
CalcInputGrad<T>(
dev_ctx, dout_conj, true, true, ddx_mat, false, true, dy, false);
} else {
// dy = ddx' * dout
CalcInputGrad<T>(
dev_ctx, ddx_mat, true, true, dout_conj, false, true, dy, false);
}
}
if (ddout) {
CalcInputGrad<T>(dev_ctx,
ddx_mat,
transpose_x,
true,
y_conj,
transpose_y,
false,
ddout,
ddout_flag);
ddout_flag = true;
}
}
if (ddy) {
auto ddy_mat = ddy.get();
if (ddy_mat.dims() != y_help.dims()) {
ddy_mat.Resize(y_help.dims());
}
if (dx) {
if (transpose_x && transpose_y) {
// dx = ddy' * dout'
CalcInputGrad<T>(
dev_ctx, ddy_mat, true, true, dout_conj, true, false, dx, false);
} else if (transpose_x) {
// dx = ddy * dout'
CalcInputGrad<T>(dev_ctx,
ddy_mat,
false,
false,
dout_conj,
true,
false,
dx,
false);
} else if (transpose_y) {
// dx = dout * ddy
CalcInputGrad<T>(dev_ctx,
dout_conj,
false,
false,
ddy_mat,
false,
true,
dx,
false);
} else {
// dx = dout * ddy'
CalcInputGrad<T>(dev_ctx,
dout_conj,
false,
false,
ddy_mat,
true,
false,
dx,
false);
}
}
if (ddout) {
CalcInputGrad<T>(dev_ctx,
x_conj,
transpose_x,
true,
ddy_mat,
transpose_y,
false,
ddout,
ddout_flag);
}
}
if (dx) {
if (dx_dims != x_help.dims()) {
dx->Resize(dx_dims);
}
}
if (dy) {
if (dy_dims != y_help.dims()) {
dy->Resize(dy_dims);
}
}
if (ddout) {
if (ddout_dims != dout_help.dims()) {
ddout->Resize(ddout_dims);
}
}
} else {
    // Case3: broadcast. Reducing the sum over the broadcast dimensions costs
    // much time and wastes memory, so this case should be avoided in practice.
    VLOG(3) << "Reduce sum over the broadcast dimensions costs much time and "
               "wastes memory, so this case should be avoided in practice";
if (dx || dy) {
dout_conj = Conj<T>(dev_ctx, dout);
}
if (ddout) {
x_conj = Conj<T>(dev_ctx, x);
y_conj = Conj<T>(dev_ctx, y);
}
DenseTensor dx_help = Empty<T>(dev_ctx);
DenseTensor dy_help = Empty<T>(dev_ctx);
if (transpose_x) {
if (transpose_y) {
if (dx) {
MatMulFunction<Context, T>(dev_ctx,
ddy.get(),
dout_conj,
y_dims,
dout_dims,
&dx_help,
true,
true);
}
if (dy) {
MatMulFunction<Context, T>(dev_ctx,
dout_conj,
ddx.get(),
dout_dims,
x_dims,
&dy_help,
true,
true);
}
} else {
if (dx)
MatMulFunction<Context, T>(dev_ctx,
ddy.get(),
dout_conj,
y_dims,
dout_dims,
&dx_help,
false,
true);
if (dy)
MatMulFunction<Context, T>(dev_ctx,
ddx.get(),
dout_conj,
x_dims,
dout_dims,
&dy_help,
false,
false);
}
} else {
if (transpose_y) {
if (dx) {
MatMulFunction<Context, T>(dev_ctx,
dout_conj,
ddy.get(),
dout_dims,
y_dims,
&dx_help,
false,
false);
}
if (dy) {
MatMulFunction<Context, T>(dev_ctx,
dout_conj,
ddx.get(),
dout_dims,
x_dims,
&dy_help,
true,
false);
}
} else {
if (dx) {
MatMulFunction<Context, T>(dev_ctx,
dout_conj,
ddy.get(),
dout_dims,
y_dims,
&dx_help,
false,
true);
}
if (dy) {
MatMulFunction<Context, T>(dev_ctx,
ddx.get(),
dout_conj,
x_dims,
dout_dims,
&dy_help,
true,
false);
}
}
}
// get help dims
const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
std::vector<std::int64_t> dx_broadcast_dims(ndim);
std::vector<std::int64_t> dy_broadcast_dims(ndim);
std::fill(
dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1);
std::fill(
dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1);
std::copy(x_dims.data(),
x_dims.data() + x_ndim,
dx_broadcast_dims.data() + ndim - x_ndim);
std::copy(y_dims.data(),
y_dims.data() + y_ndim,
dy_broadcast_dims.data() + ndim - y_ndim);
std::vector<int> dx_reduce_dims;
std::vector<int> dy_reduce_dims;
for (int idx = 0; idx <= ndim - 3; idx++) {
if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) {
dx_reduce_dims.push_back(idx);
}
if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) {
dy_reduce_dims.push_back(idx);
}
}
// Reduce sum to get grad by ReduceSum
if (dx) {
if (dx_reduce_dims.empty()) {
*dx = std::move(dx_help);
} else {
ReduceSumForMatmulGrad<Context, T>()(
dev_ctx, dx_help, dx, dx_reduce_dims);
}
dx->Resize(x.dims());
}
if (dy) {
if (dy_reduce_dims.empty()) {
*dy = std::move(dy_help);
} else {
ReduceSumForMatmulGrad<Context, T>()(
dev_ctx, dy_help, dy, dy_reduce_dims);
}
dy->Resize(y.dims());
}
if (ddout) {
// Calculate the gradient of OutputGrad(Out)
MatMulFunction<Context, T>(dev_ctx,
ddx.get(),
y_conj,
x_dims,
y_dims,
ddout,
transpose_x,
transpose_y);
MatMulFunction<Context, T>(dev_ctx,
x_conj,
ddy.get(),
x_dims,
y_dims,
ddout,
transpose_x,
transpose_y,
true);
}
}
}
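// MatmulTripleGradKernel computes the third-order gradients following the
// equations listed inside the kernel (d_dout, d_ddx and d_ddy each consist of
// two terms that are accumulated via the flag arguments). The broadcast case
// again falls back to MatMulFunction plus ReduceSumForMatmulGrad.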
template <typename T, typename Context>
void MatmulTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
const DenseTensor& ddx,
const DenseTensor& ddy,
paddle::optional<const DenseTensor&> d_dx,
paddle::optional<const DenseTensor&> d_dy,
paddle::optional<const DenseTensor&> d_ddout,
bool transpose_x,
bool transpose_y,
DenseTensor* out_d_x,
DenseTensor* out_d_y,
DenseTensor* out_d_dout,
DenseTensor* out_d_ddx,
DenseTensor* out_d_ddy) {
// Get dims from the input x, y, output_grad
std::vector<std::int64_t> x_dims = vectorize(x.dims());
std::vector<std::int64_t> y_dims = vectorize(y.dims());
std::vector<std::int64_t> dout_dims = vectorize(dout.dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int ndim = dout_dims.size();
// Case1 : x's and y's dim = 1
if (x_ndim == 1 && y_ndim == 1) {
VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 1";
DotTripleGradFunction<Context, T>()(dev_ctx,
&x,
&y,
&ddx,
&ddy,
d_dx.get_ptr(),
d_dy.get_ptr(),
&dout,
d_ddout.get_ptr(),
out_d_x,
out_d_y,
out_d_dout,
out_d_ddx,
out_d_ddy);
return;
}
DenseTensor x_conj;
DenseTensor y_conj;
DenseTensor dout_conj;
DenseTensor ddx_conj;
DenseTensor ddy_conj;
bool is_broadcast = true;
if (x_ndim <= 2 || y_ndim <= 2) {
is_broadcast = false;
} else if (x_ndim != y_ndim) {
is_broadcast = true;
} else {
is_broadcast = !std::equal(
x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2, y_dims.cbegin());
}
if (!is_broadcast) {
// Case2: no broadcast or no batch size
VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 2";
DenseTensor x_help = x;
DenseTensor y_help = y;
DenseTensor dout_help = dout;
DenseTensor ddx_help = ddx;
DenseTensor ddy_help = ddy;
ReshapeXYOutIntoMatrixSequence(
&x_help, &y_help, &dout_help, transpose_x, transpose_y);
if (ddx_help.dims() != x_help.dims()) {
ddx_help.Resize(x_help.dims());
}
if (ddy_help.dims() != y_help.dims()) {
ddy_help.Resize(y_help.dims());
}
DDim out_dx_dims;
if (out_d_x) {
out_dx_dims = out_d_x->dims();
if (out_dx_dims != x_help.dims()) {
out_d_x->Resize(x_help.dims());
}
}
DDim out_dy_dims;
if (out_d_y) {
out_dy_dims = out_d_y->dims();
if (out_dy_dims != y_help.dims()) {
out_d_y->Resize(y_help.dims());
}
}
DDim out_d_dout_dims;
if (out_d_dout) {
out_d_dout_dims = out_d_dout->dims();
if (out_d_dout_dims != dout_help.dims()) {
out_d_dout->Resize(dout_help.dims());
}
ddx_conj = Conj<T>(dev_ctx, ddx_help);
ddy_conj = Conj<T>(dev_ctx, ddy_help);
}
DDim out_d_ddx_dims;
if (out_d_ddx) {
out_d_ddx_dims = out_d_ddx->dims();
if (out_d_ddx_dims != x_help.dims()) {
out_d_ddx->Resize(x_help.dims());
}
}
DDim out_d_ddy_dims;
if (out_d_ddy) {
out_d_ddy_dims = out_d_ddy->dims();
if (out_d_ddy_dims != y_help.dims()) {
out_d_ddy->Resize(y_help.dims());
}
}
if (out_d_ddx || out_d_ddy) {
x_conj = Conj<T>(dev_ctx, x_help);
y_conj = Conj<T>(dev_ctx, y_help);
dout_conj = Conj<T>(dev_ctx, dout_help);
}
bool d_dout_flag = false;
bool d_ddx_flag = false;
bool d_ddy_flag = false;
if (d_ddout) {
auto d_ddout_mat = d_ddout.get();
if (d_ddout_mat.dims() != dout_help.dims()) {
d_ddout_mat.Resize(dout_help.dims());
}
if (out_d_y) {
if (transpose_x && transpose_y) {
// out_d_y = d_ddout' * ddx'
CalcInputGrad<T>(dev_ctx,
d_ddout_mat,
true,
true,
ddx_conj,
true,
false,
out_d_y,
false);
} else if (transpose_x) {
// out_d_y = ddx * d_ddout
CalcInputGrad<T>(dev_ctx,
ddx_conj,
false,
false,
d_ddout_mat,
false,
true,
out_d_y,
false);
} else if (transpose_y) {
// out_d_y = d_ddout' * ddx
CalcInputGrad<T>(dev_ctx,
d_ddout_mat,
true,
true,
ddx_conj,
false,
true,
out_d_y,
false);
} else {
// out_d_y = ddx' * d_ddout
CalcInputGrad<T>(dev_ctx,
ddx_conj,
true,
true,
d_ddout_mat,
false,
true,
out_d_y,
false);
}
}
if (out_d_x) {
if (transpose_x && transpose_y) {
// out_d_x = ddy' * d_ddout'
CalcInputGrad<T>(dev_ctx,
ddy_conj,
true,
true,
d_ddout_mat,
true,
false,
out_d_x,
false);
} else if (transpose_x) {
// out_d_x = ddy * d_ddout'
CalcInputGrad<T>(dev_ctx,
ddy_conj,
false,
false,
d_ddout_mat,
true,
false,
out_d_x,
false);
} else if (transpose_y) {
// out_d_x = d_ddout * ddy
CalcInputGrad<T>(dev_ctx,
d_ddout_mat,
false,
false,
ddy_conj,
false,
true,
out_d_x,
false);
} else {
// out_d_x = d_ddout * ddy'
CalcInputGrad<T>(dev_ctx,
d_ddout_mat,
false,
false,
ddy_conj,
true,
false,
out_d_x,
false);
}
}
// equations:
// d_ddx = DOut * D_DY + Y * D_DDOut
// Let: d_ddx1 = Y * D_DDOut
// Let: d_ddx2 = DOut * D_DY
// d_ddy = DOut * D_DX + X * D_DDOut
// Let: d_ddy1 = X * D_DDOut
// Let: d_ddy2 = DOut * D_DX
// d_dout = DDY * D_DX + DDX * D_DY
// Let: d_dout1 = DDX * D_DY
// Let: d_dout2 = DDY * D_DX
// compute d_ddx1
if (out_d_ddx) {
if (transpose_x && transpose_y) {
// out_d_ddx1 = y' * d_ddout'
CalcInputGrad<T>(dev_ctx,
y_conj,
true,
true,
d_ddout_mat,
true,
false,
out_d_ddx,
d_ddx_flag);
} else if (transpose_x) {
// out_d_ddx1 = y * d_ddout'
CalcInputGrad<T>(dev_ctx,
y_conj,
false,
false,
d_ddout_mat,
true,
false,
out_d_ddx,
d_ddx_flag);
} else if (transpose_y) {
// out_d_ddx1 = d_ddout * y
CalcInputGrad<T>(dev_ctx,
d_ddout_mat,
false,
false,
y_conj,
false,
true,
out_d_ddx,
d_ddx_flag);
} else {
// out_d_ddx1 = d_ddout * y'
CalcInputGrad<T>(dev_ctx,
d_ddout_mat,
false,
false,
y_conj,
true,
false,
out_d_ddx,
d_ddx_flag);
}
d_ddx_flag = true;
}
// compute d_ddy1
if (out_d_ddy) {
if (transpose_x && transpose_y) {
// out_d_ddy1 = d_ddout' * x'
CalcInputGrad<T>(dev_ctx,
d_ddout_mat,
true,
true,
x_conj,
true,
false,
out_d_ddy,
false);
} else if (transpose_x) {
// out_d_ddy1 = x * d_ddout
CalcInputGrad<T>(dev_ctx,
x_conj,
false,
false,
d_ddout_mat,
false,
true,
out_d_ddy,
false);
} else if (transpose_y) {
// out_d_ddy1 = d_ddout' * x
CalcInputGrad<T>(dev_ctx,
d_ddout_mat,
true,
true,
x_conj,
false,
true,
out_d_ddy,
false);
} else {
// out_d_ddy1 = x' * d_ddout
CalcInputGrad<T>(dev_ctx,
x_conj,
true,
true,
d_ddout_mat,
false,
true,
out_d_ddy,
false);
}
d_ddy_flag = true;
}
}
if (d_dy) {
auto d_dy_mat = d_dy.get();
if (d_dy_mat.dims() != y_help.dims()) {
d_dy_mat.Resize(y_help.dims());
}
// compute d_dout1
if (out_d_dout) {
CalcInputGrad<T>(dev_ctx,
ddx_conj,
transpose_x,
true,
d_dy_mat,
transpose_y,
false,
out_d_dout,
d_dout_flag);
d_dout_flag = true;
}
// compute d_ddx2
if (out_d_ddx) {
if (transpose_x && transpose_y) {
// out_d_ddx2 = D_DY' * DOut'
CalcInputGrad<T>(dev_ctx,
d_dy_mat,
true,
true,
dout_conj,
true,
false,
out_d_ddx,
d_ddx_flag);
} else if (transpose_x) {
// out_d_ddx2 = D_DY * Dout'
CalcInputGrad<T>(dev_ctx,
d_dy_mat,
false,
false,
dout_conj,
true,
false,
out_d_ddx,
d_ddx_flag);
} else if (transpose_y) {
// out_d_ddx2 = Dout * D_DY
CalcInputGrad<T>(dev_ctx,
dout_conj,
false,
false,
d_dy_mat,
false,
true,
out_d_ddx,
d_ddx_flag);
} else {
// out_d_ddx2 = Dout * D_DY'
CalcInputGrad<T>(dev_ctx,
dout_conj,
false,
false,
d_dy_mat,
true,
false,
out_d_ddx,
d_ddx_flag);
}
}
}
if (d_dx) {
auto d_dx_mat = d_dx.get();
if (d_dx_mat.dims() != x_help.dims()) {
d_dx_mat.Resize(x_help.dims());
}
// compute d_dout2
if (out_d_dout) {
CalcInputGrad<T>(dev_ctx,
d_dx_mat,
transpose_x,
true,
ddy_conj,
transpose_y,
false,
out_d_dout,
d_dout_flag);
}
// compute d_ddy2
if (out_d_ddy) {
if (transpose_x && transpose_y) {
// out_d_ddy2 = dout' * d_dx'
CalcInputGrad<T>(dev_ctx,
dout_conj,
true,
true,
d_dx_mat,
true,
false,
out_d_ddy,
d_ddy_flag);
} else if (transpose_x) {
// out_d_ddy2 = d_dx * dout
CalcInputGrad<T>(dev_ctx,
d_dx_mat,
false,
false,
dout_conj,
false,
true,
out_d_ddy,
d_ddy_flag);
} else if (transpose_y) {
// out_d_ddy2 = dout' * d_dx
CalcInputGrad<T>(dev_ctx,
dout_conj,
true,
true,
d_dx_mat,
false,
true,
out_d_ddy,
d_ddy_flag);
} else {
// out_d_ddy2 = d_dx' * dout
CalcInputGrad<T>(dev_ctx,
d_dx_mat,
true,
true,
dout_conj,
false,
true,
out_d_ddy,
d_ddy_flag);
}
}
}
if (out_d_x) {
if (out_dx_dims != x_help.dims()) {
out_d_x->Resize(out_dx_dims);
}
}
if (out_d_y) {
if (out_dy_dims != y_help.dims()) {
out_d_y->Resize(out_dy_dims);
}
}
if (out_d_dout) {
if (out_d_dout_dims != dout_help.dims()) {
out_d_dout->Resize(out_d_dout_dims);
}
}
if (out_d_ddx) {
if (out_d_ddx_dims != x_help.dims()) {
out_d_ddx->Resize(out_d_ddx_dims);
}
}
if (out_d_ddy) {
if (out_d_ddy_dims != y_help.dims()) {
out_d_ddy->Resize(out_d_ddy_dims);
}
}
} else {
    // Case3: broadcast. Reducing the sum over the broadcast dimensions costs
    // much time and wastes memory, so this case should be avoided in practice.
    VLOG(3) << "======== MatMulV2TripleGradKernel, Compute ====== Case 3";
    VLOG(3) << "Reduce sum over the broadcast dimensions costs much time and "
               "wastes memory, so this case should be avoided in practice";
DenseTensor out_dx_help = Empty<T>(dev_ctx);
DenseTensor out_dy_help = Empty<T>(dev_ctx);
DenseTensor out_d_ddx_help = Empty<T>(dev_ctx);
DenseTensor out_d_ddy_help = Empty<T>(dev_ctx);
if (out_d_dout) {
ddx_conj = Conj<T>(dev_ctx, ddx);
ddy_conj = Conj<T>(dev_ctx, ddy);
}
if (out_d_ddx || out_d_ddy) {
x_conj = Conj<T>(dev_ctx, x);
y_conj = Conj<T>(dev_ctx, y);
dout_conj = Conj<T>(dev_ctx, dout);
}
if (transpose_x) {
if (transpose_y) {
        // dX = ddY' d_ddout', dY = d_ddout' ddX'
if (out_d_x)
MatMulFunction<Context, T>(dev_ctx,
ddy_conj,
d_ddout.get(),
y_dims,
dout_dims,
&out_dx_help,
true,
true);
if (out_d_y)
MatMulFunction<Context, T>(dev_ctx,
d_ddout.get(),
ddx_conj,
dout_dims,
x_dims,
&out_dy_help,
true,
true);
} else {
// dX = ddY d_ddout', dY = ddX d_ddout
if (out_d_x)
MatMulFunction<Context, T>(dev_ctx,
ddy_conj,
d_ddout.get(),
y_dims,
dout_dims,
&out_dx_help,
false,
true);
if (out_d_y)
MatMulFunction<Context, T>(dev_ctx,
ddx_conj,
d_ddout.get(),
x_dims,
dout_dims,
&out_dy_help,
false,
false);
}
} else {
if (transpose_y) {
        // dX = d_ddout ddY, dY = d_ddout' ddX
if (out_d_x)
MatMulFunction<Context, T>(dev_ctx,
d_ddout.get(),
ddy_conj,
dout_dims,
y_dims,
&out_dx_help,
false,
false);
if (out_d_y)
MatMulFunction<Context, T>(dev_ctx,
d_ddout.get(),
ddx_conj,
dout_dims,
x_dims,
&out_dy_help,
true,
false);
} else {
// dX = d_ddout ddY', dY = ddX' d_ddout
if (out_d_x)
MatMulFunction<Context, T>(dev_ctx,
d_ddout.get(),
ddy_conj,
dout_dims,
y_dims,
&out_dx_help,
false,
true);
if (out_d_y)
MatMulFunction<Context, T>(dev_ctx,
ddx_conj,
d_ddout.get(),
x_dims,
dout_dims,
&out_dy_help,
true,
false);
}
}
// get help dims
const std::vector<std::int64_t> dx_help_dims =
vectorize(out_dx_help.dims());
    const std::vector<std::int64_t> dy_help_dims =
        vectorize(out_dy_help.dims());
std::vector<std::int64_t> dx_broadcast_dims(ndim);
std::vector<std::int64_t> dy_broadcast_dims(ndim);
std::fill(
dx_broadcast_dims.data(), dx_broadcast_dims.data() + ndim - x_ndim, 1);
std::fill(
dy_broadcast_dims.data(), dy_broadcast_dims.data() + ndim - y_ndim, 1);
std::copy(x_dims.data(),
x_dims.data() + x_ndim,
dx_broadcast_dims.data() + ndim - x_ndim);
std::copy(y_dims.data(),
y_dims.data() + y_ndim,
dy_broadcast_dims.data() + ndim - y_ndim);
std::vector<int> dx_reduce_dims;
std::vector<int> dy_reduce_dims;
for (int idx = 0; idx <= ndim - 3; idx++) {
if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) {
dx_reduce_dims.push_back(idx);
}
if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) {
dy_reduce_dims.push_back(idx);
}
}
// Reduce sum to get grad by ReduceSum
if (out_d_x) {
if (dx_reduce_dims.empty()) {
*out_d_x = std::move(out_dx_help);
} else {
ReduceSumForMatmulGrad<Context, T>()(
dev_ctx, out_dx_help, out_d_x, dx_reduce_dims);
}
out_d_x->Resize(x.dims());
}
if (out_d_y) {
if (dy_reduce_dims.empty()) {
*out_d_y = std::move(out_dy_help);
} else {
ReduceSumForMatmulGrad<Context, T>()(
dev_ctx, out_dy_help, out_d_y, dy_reduce_dims);
}
out_d_y->Resize(y.dims());
}
// compute d_dout
if (out_d_dout) {
MatMulFunction<Context, T>(dev_ctx,
d_dx.get(),
ddy_conj,
x_dims,
y_dims,
out_d_dout,
transpose_x,
transpose_y);
MatMulFunction<Context, T>(dev_ctx,
ddx_conj,
d_dy.get(),
x_dims,
y_dims,
out_d_dout,
transpose_x,
transpose_y,
true);
}
// compute d_ddx
if (out_d_ddx) {
if (transpose_x && transpose_y) {
// out_d_ddx1 = y' * d_ddout'
MatMulFunction<Context, T>(dev_ctx,
y_conj,
d_ddout.get(),
y_dims,
dout_dims,
&out_d_ddx_help,
true,
true);
// out_d_ddx2 = D_DY' * DOut'
MatMulFunction<Context, T>(dev_ctx,
d_dy.get(),
dout_conj,
y_dims,
dout_dims,
&out_d_ddx_help,
true,
true,
true);
} else if (transpose_x) {
// out_d_ddx1 = y * d_ddout'
MatMulFunction<Context, T>(dev_ctx,
y_conj,
d_ddout.get(),
y_dims,
dout_dims,
&out_d_ddx_help,
false,
true);
// out_d_ddx2 = D_DY * Dout'
MatMulFunction<Context, T>(dev_ctx,
d_dy.get(),
dout_conj,
y_dims,
dout_dims,
&out_d_ddx_help,
false,
true,
true);
} else if (transpose_y) {
// out_d_ddx1 = d_ddout * y
MatMulFunction<Context, T>(dev_ctx,
d_ddout.get(),
y_conj,
dout_dims,
y_dims,
&out_d_ddx_help,
false,
false);
// out_d_ddx2 = Dout * D_DY
MatMulFunction<Context, T>(dev_ctx,
dout_conj,
d_dy.get(),
dout_dims,
y_dims,
&out_d_ddx_help,
false,
false,
true);
} else {
// out_d_ddx1 = d_ddout * y'
MatMulFunction<Context, T>(dev_ctx,
d_ddout.get(),
y_conj,
dout_dims,
y_dims,
&out_d_ddx_help,
false,
true);
// out_d_ddx2 = Dout * D_DY'
MatMulFunction<Context, T>(dev_ctx,
dout_conj,
d_dy.get(),
dout_dims,
y_dims,
&out_d_ddx_help,
false,
true,
true);
}
if (dx_reduce_dims.empty()) {
*out_d_ddx = std::move(out_d_ddx_help);
} else {
ReduceSumForMatmulGrad<Context, T>()(
dev_ctx, out_d_ddx_help, out_d_ddx, dx_reduce_dims);
}
out_d_ddx->Resize(x.dims());
}
// compute d_ddy
if (out_d_ddy) {
if (transpose_x && transpose_y) {
// out_d_ddy1 = d_ddout' * x'
MatMulFunction<Context, T>(dev_ctx,
d_ddout.get(),
x_conj,
dout_dims,
x_dims,
&out_d_ddy_help,
true,
true);
// out_d_ddy2 = dout' * d_dx'
MatMulFunction<Context, T>(dev_ctx,
dout_conj,
d_dx.get(),
dout_dims,
x_dims,
&out_d_ddy_help,
true,
true,
true);
} else if (transpose_x) {
// out_d_ddy1 = x * d_ddout
MatMulFunction<Context, T>(dev_ctx,
x_conj,
d_ddout.get(),
x_dims,
dout_dims,
&out_d_ddy_help,
false,
false);
// out_d_ddy2 = d_dx * dout
MatMulFunction<Context, T>(dev_ctx,
d_dx.get(),
dout_conj,
x_dims,
dout_dims,
&out_d_ddy_help,
false,
false,
true);
} else if (transpose_y) {
// out_d_ddy1 = d_ddout' * x
MatMulFunction<Context, T>(dev_ctx,
d_ddout.get(),
x_conj,
dout_dims,
x_dims,
&out_d_ddy_help,
true,
false);
// out_d_ddy2 = dout' * d_dx
MatMulFunction<Context, T>(dev_ctx,
dout_conj,
d_dx.get(),
dout_dims,
x_dims,
&out_d_ddy_help,
true,
false,
true);
} else {
// out_d_ddy1 = x' * d_ddout
MatMulFunction<Context, T>(dev_ctx,
x_conj,
d_ddout.get(),
x_dims,
dout_dims,
&out_d_ddy_help,
true,
false);
// out_d_ddy2 = d_dx' * dout
MatMulFunction<Context, T>(dev_ctx,
d_dx.get(),
dout_conj,
x_dims,
dout_dims,
&out_d_ddy_help,
true,
false,
true);
}
if (dy_reduce_dims.empty()) {
*out_d_ddy = std::move(out_d_ddy_help);
} else {
ReduceSumForMatmulGrad<Context, T>()(
dev_ctx, out_d_ddy_help, out_d_ddy, dy_reduce_dims);
}
out_d_ddy->Resize(y.dims());
}
}
}
} // namespace pten
......@@ -86,7 +86,7 @@ static void IndexIncreaseFromDims(const int ndim,
}
template <typename Context, typename T>
void MatMulFunction(const Context& context,
void MatMulFunction(const Context& dev_ctx,
const DenseTensor& X,
const DenseTensor& Y,
const std::vector<std::int64_t>& x_dims,
......@@ -102,7 +102,7 @@ void MatMulFunction(const Context& context,
const T* x_data = X.data<T>();
const T* y_data = Y.data<T>();
auto blas = paddle::operators::math::GetBlas<Context, T>(context);
auto blas = paddle::operators::math::GetBlas<Context, T>(dev_ctx);
if (x_ndim == 1 && y_ndim == 1) {
const int M = X.numel();
......@@ -117,6 +117,8 @@ void MatMulFunction(const Context& context,
M,
N));
VLOG(3) << "MatMul's case 1";
Out->Resize({1});
Out->mutable_data<T>();
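// A 1-D x 1-D matmul is a dot product: the output is resized to a single
// element and allocated here, then filled by the 1x1 GEMM below.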
blas.GEMM(CblasNoTrans,
CblasTrans,
1,
......@@ -471,7 +473,7 @@ void MatMulFunction(const Context& context,
}
template <typename Context, typename T>
void MatMulFunction(const Context& context,
void MatMulFunction(const Context& dev_ctx,
const DenseTensor& X,
const DenseTensor& Y,
DenseTensor* Out,
......@@ -481,11 +483,11 @@ void MatMulFunction(const Context& context,
const std::vector<std::int64_t> x_dims = vectorize(X.dims());
const std::vector<std::int64_t> y_dims = vectorize(Y.dims());
MatMulFunction<Context, T>(
context, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag);
dev_ctx, X, Y, x_dims, y_dims, Out, trans_x, trans_y, flag);
}
template <typename T, typename Context>
void MatmulKernel(const Context& context,
void MatmulKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
......@@ -501,7 +503,7 @@ void MatmulKernel(const Context& context,
paddle::platform::errors::InvalidArgument(
"The Input(Y) dims size must not be equal 0,"
" but reviced dims size is 0. "));
MatMulFunction<Context, T>(context, x, y, out, transpose_x, transpose_y);
MatMulFunction<Context, T>(dev_ctx, x, y, out, transpose_x, transpose_y);
}
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace pten {
template <typename T, typename Context>
void MatmulGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
bool transpose_x,
bool transpose_y,
DenseTensor* dx,
DenseTensor* dy);
template <typename T, typename Context>
void MatmulDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
paddle::optional<const DenseTensor&> ddx,
paddle::optional<const DenseTensor&> ddy,
bool transpose_x,
bool transpose_y,
DenseTensor* dx,
DenseTensor* dy,
DenseTensor* ddout);
template <typename T, typename Context>
void MatmulTripleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
const DenseTensor& ddx,
const DenseTensor& ddy,
paddle::optional<const DenseTensor&> d_dx,
paddle::optional<const DenseTensor&> d_dy,
paddle::optional<const DenseTensor&> d_ddout,
bool transpose_x,
bool transpose_y,
DenseTensor* out_d_x,
DenseTensor* out_d_y,
DenseTensor* out_d_dout,
DenseTensor* out_d_ddx,
DenseTensor* out_d_ddy);
} // namespace pten
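For reference, a minimal sketch of a call site for the first-order kernel declared above. It is not part of this change; dev_ctx, x, y, dout and the pre-constructed output tensors dx and dy are assumed to be prepared by the caller.
// Hypothetical call site (a sketch, not part of this diff); all tensors and
// the device context are assumed to be set up elsewhere.
pten::MatmulGradKernel<float>(dev_ctx, x, y, dout,
                              /*transpose_x=*/false,
                              /*transpose_y=*/false,
                              &dx, &dy);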
......@@ -14,14 +14,15 @@
#pragma once
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/infermeta/binary.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
template <typename T, typename Context>
void MatmulKernel(const Context& context,
void MatmulKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
......@@ -29,17 +30,14 @@ void MatmulKernel(const Context& context,
DenseTensor* out);
template <typename T, typename Context>
DenseTensor Matmul(const Context& context,
DenseTensor Matmul(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
bool transpose_x,
bool transpose_y) {
auto out_meta = MatmulInferMeta(x.meta(), y.meta(), transpose_x, transpose_y);
DenseTensor dense_out(
pten::make_intrusive<paddle::experimental::SharedStorage>(
context.GetPlace()),
std::move(out_meta));
MatmulKernel<T, Context>(context, x, y, transpose_x, transpose_y, &dense_out);
auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
MatmulKernel<T, Context>(dev_ctx, x, y, transpose_x, transpose_y, &dense_out);
return dense_out;
}
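The helper above now obtains its output through the Empty kernel (matching the empty_kernel.h include added to this file) instead of constructing a DenseTensor around a SharedStorage by hand. A minimal, hypothetical call site, assuming dev_ctx, x and y already exist:
// Usage sketch only; dev_ctx, x and y are assumptions, not part of this diff.
auto out = pten::Matmul<float>(dev_ctx, x, y,
                               /*transpose_x=*/false,
                               /*transpose_y=*/false);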