Unverified commit 1b491818, authored by Chen Weihang, committed by GitHub

[Phi] Move mul op kernel into phi (#40833)

* add mul phi kernel

* remove mul op kernel

* remove original mul grad op

* fix cinn test

* fix failed dygraph tests
Parent: cf8be325
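The recurring edit in the test files below is mechanical: `USE_OP(mul)` used to pull in both the operator definition and its fluid kernel registrations, but with the kernel moved into phi only the operator itself remains in fluid. Call sites therefore switch to `USE_OP_ITSELF(mul)` and, where they actually execute the op, declare the phi kernels explicitly. A minimal sketch of the resulting test-file preamble, mirroring the hunks below:

USE_OP_ITSELF(mul);       // operator definition only; no fluid kernel anymore
USE_OP_ITSELF(mul_grad);
PD_DECLARE_KERNEL(matmul_with_flatten, CPU, ALL_LAYOUT);       // phi forward kernel
PD_DECLARE_KERNEL(matmul_with_flatten_grad, CPU, ALL_LAYOUT);  // phi grad kernel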
@@ -24,7 +24,7 @@
 #include "paddle/fluid/framework/parallel_executor.h"
 #include "paddle/fluid/framework/program_desc.h"
-USE_OP(mul);
+USE_OP_ITSELF(mul);
 USE_OP(cinn_launch);
 USE_OP_ITSELF(elementwise_add);
 namespace paddle::framework {
...
@@ -674,7 +674,7 @@ TEST(BuildCinnPassTest, NoNeedBufferInput) {
 } // namespace paddle
 USE_PASS(build_cinn_pass);
-USE_OP(mul);
+USE_OP_ITSELF(mul);
 USE_OP_ITSELF(relu);
 USE_OP_ITSELF(elementwise_add);
 USE_OP_ITSELF(relu_grad);
...
@@ -300,6 +300,6 @@ TEST(CinnCompilerTest, Compile) {
 USE_PASS(build_cinn_pass);
 USE_PASS(graph_viz_pass);
-USE_OP(mul);
+USE_OP_ITSELF(mul);
 USE_OP_ITSELF(relu);
 USE_OP_ITSELF(elementwise_add);
@@ -98,4 +98,4 @@ TEST(test_var_helper, eager_var_helper) {
 } // namespace imperative
 } // namespace paddle
-USE_OP(mul);
+USE_OP_ITSELF(mul);
@@ -28,6 +28,8 @@
 PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(add_grad, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(matmul_with_flatten, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(matmul_with_flatten_grad, CPU, ALL_LAYOUT);
 namespace platform = paddle::platform;
 namespace framework = paddle::framework;
@@ -267,7 +269,7 @@ TEST(TestHooks, TestGradVarLeafBackwardHookWithSortedGradAccmulated) {
 } // namespace imperative
 } // namespace paddle
-USE_OP(mul);
-USE_OP(mul_grad);
+USE_OP_ITSELF(mul);
+USE_OP_ITSELF(mul_grad);
 USE_OP_ITSELF(elementwise_add);
 USE_OP_ITSELF(elementwise_add_grad);
@@ -416,4 +416,4 @@ TEST(test_layer, test_eager) {
 } // namespace imperative
 } // namespace paddle
-USE_OP(mul);
+USE_OP_ITSELF(mul);
@@ -34,9 +34,13 @@ PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(add_grad, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(sum, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(sum_grad, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(matmul_with_flatten, CPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(matmul_with_flatten_grad, CPU, ALL_LAYOUT);
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PD_DECLARE_KERNEL(add_grad, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(sum_grad, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(matmul_with_flatten, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(matmul_with_flatten_grad, GPU, ALL_LAYOUT);
 #endif
 namespace imperative = paddle::imperative;
@@ -598,8 +602,8 @@ TEST(test_tracer, eager_tracer) {
 } // namespace imperative
 } // namespace paddle
-USE_OP(mul);
-USE_OP(mul_grad);
+USE_OP_ITSELF(mul);
+USE_OP_ITSELF(mul_grad);
 USE_OP_ITSELF(reduce_sum);
 USE_OP_ITSELF(reduce_sum_grad);
 USE_OP_ITSELF(elementwise_add);
@@ -43,4 +43,4 @@ TEST(fc_op, test) {
 } // namespace tensorrt
 } // namespace inference
 } // namespace paddle
-USE_OP(mul);
+USE_OP_ITSELF(mul);
@@ -46,4 +46,4 @@ TEST(MulOpConverter, main) {
 } // namespace inference
 } // namespace paddle
-USE_OP(mul);
+USE_OP_ITSELF(mul);
@@ -14,7 +14,7 @@ limitations under the License. */
 #include <string>
-#include "paddle/fluid/operators/mul_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/mkldnn_reuse.h"
 namespace phi {
@@ -46,6 +46,9 @@ using dnnl::memory;
 using dnnl::prop_kind;
 using dnnl::stream;
+constexpr int kMULMKLDNNINT8 = 1;
+constexpr int kMULMKLDNNFP32 = 2;
 template <typename XT, typename YT, typename OT>
 class MulPrimitiveFactory {
  public:
...
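The kMULMKLDNN* constants added here, and again in the next file section, previously lived in the now-deleted mul_op.h header; with that header gone, each remaining consumer defines its own local copy.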
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/mul_op.h"
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include "paddle/fluid/framework/op_registry.h"
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
@@ -27,6 +27,9 @@ namespace operators {
 using framework::OpKernelType;
 using framework::Tensor;
+constexpr int kMULMKLDNNINT8 = 1;
+constexpr int kMULMKLDNNFP32 = 2;
 class MulOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -354,16 +357,3 @@ REGISTER_OPERATOR(mul_grad, ops::MulGradOp,
                   ops::MulDoubleGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(mul_grad_grad, ops::MulDoubleGradOp);
-REGISTER_OP_CPU_KERNEL(
-    mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::MulKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    mul_grad, ops::MulGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::MulGradKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    mul_grad_grad,
-    ops::MulDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::MulDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
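Only the kernel registrations disappear here; the operator definitions (REGISTER_OPERATOR) stay in fluid. The compute kernels are re-registered in phi as matmul_with_flatten, matmul_with_flatten_grad, and matmul_with_flatten_double_grad further down, and the compat mapping file at the end of the diff translates the old op names onto the new kernel names. The next two sections are whole-file deletions: the CUDA registration file and the mul_op.h header that carried the kernel implementations.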
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/mul_op.h"
-#include "paddle/fluid/platform/float16.h"
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel<plat::CUDADeviceContext, float>,
-                        ops::MulKernel<plat::CUDADeviceContext, double>,
-                        ops::MulKernel<plat::CUDADeviceContext, plat::float16>);
-REGISTER_OP_CUDA_KERNEL(
-    mul_grad, ops::MulGradKernel<plat::CUDADeviceContext, float>,
-    ops::MulGradKernel<plat::CUDADeviceContext, double>,
-    ops::MulGradKernel<plat::CUDADeviceContext, plat::float16>);
-REGISTER_OP_CUDA_KERNEL(
-    mul_grad_grad,
-    ops::MulDoubleGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::MulDoubleGradKernel<paddle::platform::CUDADeviceContext, double>);
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/phi/kernels/funcs/blas/blas.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-constexpr int kMULMKLDNNINT8 = 1;
-constexpr int kMULMKLDNNFP32 = 2;
-
-template <typename DeviceContext, typename T>
-class MulKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* x = context.Input<Tensor>("X");
-    const Tensor* y = context.Input<Tensor>("Y");
-    Tensor* z = context.Output<Tensor>("Out");
-    const Tensor x_matrix =
-        x->dims().size() > 2
-            ? framework::ReshapeToMatrix(
-                  *x, context.template Attr<int>("x_num_col_dims"))
-            : *x;
-    const Tensor y_matrix =
-        y->dims().size() > 2
-            ? framework::ReshapeToMatrix(
-                  *y, context.template Attr<int>("y_num_col_dims"))
-            : *y;
-
-    z->mutable_data<T>(context.GetPlace());
-    auto z_dim = z->dims();
-    if (z_dim.size() != 2) {
-      z->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
-    }
-
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
-
-    blas.MatMul(x_matrix, y_matrix, z);
-    if (z_dim.size() != 2) {
-      z->Resize(z_dim);
-    }
-  }
-};
-
-template <typename DeviceContext, typename T>
-class MulGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
-    int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
-    auto* x = ctx.Input<framework::LoDTensor>("X");
-    auto* y = ctx.Input<framework::LoDTensor>("Y");
-    auto x_matrix = x->dims().size() > 2
-                        ? framework::ReshapeToMatrix(*x, x_num_col_dims)
-                        : static_cast<const Tensor&>(*x);
-    auto y_matrix = y->dims().size() > 2
-                        ? framework::ReshapeToMatrix(*y, y_num_col_dims)
-                        : static_cast<const Tensor&>(*y);
-    auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
-
-    Tensor dout_mat;
-    dout_mat.ShareDataWith(*dout);
-    dout_mat.Resize({phi::flatten_to_2d(x->dims(), x_num_col_dims)[0],
-                     phi::flatten_to_2d(y->dims(), y_num_col_dims)[1]});
-
-    auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
-    auto* dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
-
-    if (dx != nullptr) {
-      dx->set_lod(x->lod());
-    }
-    if (dy != nullptr) {
-      dy->set_lod(y->lod());
-    }
-
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
-    if (dx) {
-      dx->mutable_data<T>(ctx.GetPlace());
-      Tensor dx_matrix = dx->dims().size() > 2
-                             ? framework::ReshapeToMatrix(*dx, x_num_col_dims)
-                             : *dx;
-
-      // dx = dout * y'. dx: M x K, dout : M x N, y : K x N
-      blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix);
-    }
-    if (dy) {
-      dy->mutable_data<T>(ctx.GetPlace());
-      Tensor dy_matrix = dy->dims().size() > 2
-                             ? framework::ReshapeToMatrix(*dy, y_num_col_dims)
-                             : *dy;
-      // dy = x' * dout. dy K x N, dout : M x N, x : M x K
-      blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix);
-    }
-  }
-};
-
-template <typename DeviceContext, typename T>
-class MulDoubleGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
-    int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
-    auto* x = ctx.Input<framework::LoDTensor>("X");
-    auto* y = ctx.Input<framework::LoDTensor>("Y");
-    auto x_mat = x->dims().size() > 2
-                     ? framework::ReshapeToMatrix(*x, x_num_col_dims)
-                     : static_cast<const Tensor&>(*x);
-    auto y_mat = y->dims().size() > 2
-                     ? framework::ReshapeToMatrix(*y, y_num_col_dims)
-                     : static_cast<const Tensor&>(*y);
-
-    const int m = phi::flatten_to_2d(x->dims(), x_num_col_dims)[0];
-    const int n = phi::flatten_to_2d(y->dims(), y_num_col_dims)[1];
-
-    auto* dout = ctx.Input<framework::LoDTensor>("DOut");
-    Tensor dout_mat;
-    dout_mat.ShareDataWith(*dout);
-    dout_mat.Resize({m, n});
-
-    auto* ddx = ctx.Input<framework::LoDTensor>("DDX");
-    auto* ddy = ctx.Input<framework::LoDTensor>("DDY");
-
-    auto* dx = ctx.Output<framework::LoDTensor>("DX");
-    auto* dy = ctx.Output<framework::LoDTensor>("DY");
-    auto* ddout = ctx.Output<framework::LoDTensor>("DDOut");
-
-    Tensor ddout_mat;
-    if (ddout) {
-      ddout->set_lod(dout->lod());
-      // allocate and reshape ddout
-      ddout->mutable_data<T>(ctx.GetPlace());
-      ddout_mat.ShareDataWith(*ddout);
-      ddout_mat.Resize({m, n});
-    }
-
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    auto blas = phi::funcs::GetBlas<DeviceContext, T>(dev_ctx);
-    // a flag to specify whether ddout value has been set, if flag
-    // is false, MatMul beta should be 0 to set ddout, if flag is
-    // true, MatMul beta should be 1 to add result to ddout.
-    bool ddout_flag = false;
-    if (ddx) {
-      auto ddx_mat = ddx->dims().size() > 2
-                         ? framework::ReshapeToMatrix(*ddx, x_num_col_dims)
-                         : static_cast<const Tensor&>(*ddx);
-
-      // dy = ddx' * dout. dy : K x M, ddx' : K x M, dout : M x N
-      if (dy) {
-        dy->set_lod(y->lod());
-        // allocate and reshape dy
-        dy->mutable_data<T>(ctx.GetPlace());
-        Tensor dy_mat = dy->dims().size() > 2
-                            ? framework::ReshapeToMatrix(*dy, y_num_col_dims)
-                            : *dy;
-        blas.MatMul(ddx_mat, true, dout_mat, false, &dy_mat);
-      }
-      // ddout1 = ddx * y. ddx : M x K, y : K x N, ddout1 : M x N
-      if (ddout) {
-        blas.MatMul(ddx_mat, false, y_mat, false, static_cast<T>(1.0),
-                    &ddout_mat, static_cast<T>(ddout_flag));
-        ddout_flag = true;
-      }
-    }
-    if (ddy) {
-      auto ddy_mat = ddy->dims().size() > 2
-                         ? framework::ReshapeToMatrix(*ddy, y_num_col_dims)
-                         : static_cast<const Tensor&>(*ddy);
-      // dx = dout * ddy'. dout : M x N, ddy' : N x K, dx : M x K
-      if (dx) {
-        dx->set_lod(x->lod());
-        // allocate and reshape dx
-        dx->mutable_data<T>(ctx.GetPlace());
-        Tensor dx_mat = dx->dims().size() > 2
-                            ? framework::ReshapeToMatrix(*dx, x_num_col_dims)
-                            : *dx;
-        blas.MatMul(dout_mat, false, ddy_mat, true, &dx_mat);
-      }
-      // ddout2 = x * ddy. x : M x K, ddy : K x N, ddout2 : M x N
-      if (ddout) {
-        blas.MatMul(x_mat, false, ddy_mat, false, static_cast<T>(1.0),
-                    &ddout_mat, static_cast<T>(ddout_flag));
-      }
-    }
-  }
-};
-
-} // namespace operators
-} // namespace paddle
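Both the deleted fluid kernels above and the phi kernels below lean on the same flattening rule (phi::flatten_to_2d, applied by framework::ReshapeToMatrix): fold the leading num_col_dims dimensions into the rows and the trailing ones into the columns. A standalone sketch of the rule, with FlattenTo2d as a hypothetical stand-in name for illustration:

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Hypothetical restatement of phi::flatten_to_2d: split `dims` at
// `num_col_dims`, folding the leading dims into rows and the rest into cols.
std::pair<int64_t, int64_t> FlattenTo2d(const std::vector<int64_t>& dims,
                                        int num_col_dims) {
  int64_t rows = 1, cols = 1;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    (i < num_col_dims ? rows : cols) *= dims[i];
  }
  return {rows, cols};
}

int main() {
  // x of shape [2, 3, 4, 5] with x_num_col_dims = 2 is viewed as 6 x 20.
  auto [rows, cols] = FlattenTo2d({2, 3, 4, 5}, 2);
  std::printf("%lld x %lld\n", static_cast<long long>(rows),
              static_cast<long long>(cols));  // prints: 6 x 20
  return 0;
}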
@@ -15,7 +15,7 @@ limitations under the License. */
 #include <memory>
 #include <string>
-#include "paddle/fluid/operators/mul_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
 namespace paddle {
...
@@ -14,11 +14,11 @@ limitations under the License. */
 #ifdef PADDLE_WITH_XPU
-#include "paddle/fluid/operators/mul_op.h"
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include <vector>
+#include "paddle/fluid/framework/op_registry.h"
 namespace paddle {
 namespace operators {
...
@@ -45,3 +45,17 @@ PD_REGISTER_KERNEL(matmul_triple_grad,
                    double,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+
+PD_REGISTER_KERNEL(matmul_with_flatten_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::MatmulWithFlattenGradKernel,
+                   float,
+                   double) {}
+
+PD_REGISTER_KERNEL(matmul_with_flatten_double_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::MatmulWithFlattenDoubleGradKernel,
+                   float,
+                   double) {}
@@ -28,3 +28,10 @@ PD_REGISTER_KERNEL(matmul,
                    double,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+
+PD_REGISTER_KERNEL(matmul_with_flatten,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::MatmulWithFlattenKernel,
+                   float,
+                   double) {}
@@ -49,3 +49,19 @@ PD_REGISTER_KERNEL(matmul_triple_grad,
                    phi::dtype::float16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+
+PD_REGISTER_KERNEL(matmul_with_flatten_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MatmulWithFlattenGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
+
+PD_REGISTER_KERNEL(matmul_with_flatten_double_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MatmulWithFlattenDoubleGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
@@ -30,3 +30,11 @@ PD_REGISTER_KERNEL(matmul,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+
+PD_REGISTER_KERNEL(matmul_with_flatten,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MatmulWithFlattenKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
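The dtype lists mirror the fluid registrations deleted above: on CPU, float and double, exactly as in the old REGISTER_OP_CPU_KERNEL calls; on GPU, float, double, and phi::dtype::float16, matching the old CUDA registrations for mul and mul_grad and, for the double-grad kernel, slightly extending them, since the old mul_grad_grad CUDA kernel was registered only for float and double.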
@@ -1731,4 +1731,163 @@ void MatmulTripleGradKernel(const Context& dev_ctx,
 }
 }
+
+template <typename T, typename Context>
+void MatmulWithFlattenGradKernel(const Context& dev_ctx,
+                                 const DenseTensor& x,
+                                 const DenseTensor& y,
+                                 const DenseTensor& out_grad,
+                                 int x_num_col_dims,
+                                 int y_num_col_dims,
+                                 DenseTensor* x_grad,
+                                 DenseTensor* y_grad) {
+  auto x_matrix = x.dims().size() > 2
+                      ? paddle::framework::ReshapeToMatrix(x, x_num_col_dims)
+                      : x;
+  auto y_matrix = y.dims().size() > 2
+                      ? paddle::framework::ReshapeToMatrix(y, y_num_col_dims)
+                      : y;
+  auto* dout = &out_grad;
+
+  DenseTensor dout_mat(*dout);
+  dout_mat.Resize({phi::flatten_to_2d(x.dims(), x_num_col_dims)[0],
+                   phi::flatten_to_2d(y.dims(), y_num_col_dims)[1]});
+
+  auto* dx = x_grad;
+  auto* dy = y_grad;
+
+  if (dx != nullptr) {
+    dx->set_lod(x.lod());
+  }
+  if (dy != nullptr) {
+    dy->set_lod(y.lod());
+  }
+
+  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
+  if (dx) {
+    dev_ctx.template Alloc<T>(dx);
+    DenseTensor dx_matrix =
+        dx->dims().size() > 2
+            ? paddle::framework::ReshapeToMatrix(*dx, x_num_col_dims)
+            : *dx;
+
+    // dx = dout * y'. dx: M x K, dout : M x N, y : K x N
+    blas.MatMul(dout_mat, false, y_matrix, true, &dx_matrix);
+  }
+  if (dy) {
+    dev_ctx.template Alloc<T>(dy);
+    DenseTensor dy_matrix =
+        dy->dims().size() > 2
+            ? paddle::framework::ReshapeToMatrix(*dy, y_num_col_dims)
+            : *dy;
+    // dy = x' * dout. dy K x N, dout : M x N, x : M x K
+    blas.MatMul(x_matrix, true, dout_mat, false, &dy_matrix);
+  }
+}
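The two MatMul calls above are the standard product-rule gradients of a matrix product. With the flattened views $X \in \mathbb{R}^{M \times K}$, $Y \in \mathbb{R}^{K \times N}$ and $\mathrm{Out} = XY$:

\[
\frac{\partial L}{\partial X} = \frac{\partial L}{\partial \mathrm{Out}}\, Y^{\top},
\qquad
\frac{\partial L}{\partial Y} = X^{\top}\, \frac{\partial L}{\partial \mathrm{Out}}.
\]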
+
+template <typename T, typename Context>
+void MatmulWithFlattenDoubleGradKernel(
+    const Context& dev_ctx,
+    const DenseTensor& x,
+    const DenseTensor& y,
+    const DenseTensor& out_grad,
+    paddle::optional<const DenseTensor&> x_grad_grad,
+    paddle::optional<const DenseTensor&> y_grad_grad,
+    int x_num_col_dims,
+    int y_num_col_dims,
+    DenseTensor* x_grad,
+    DenseTensor* y_grad,
+    DenseTensor* out_grad_grad) {
+  auto x_mat = x.dims().size() > 2
+                   ? paddle::framework::ReshapeToMatrix(x, x_num_col_dims)
+                   : x;
+  auto y_mat = y.dims().size() > 2
+                   ? paddle::framework::ReshapeToMatrix(y, y_num_col_dims)
+                   : y;
+
+  const int m = phi::flatten_to_2d(x.dims(), x_num_col_dims)[0];
+  const int n = phi::flatten_to_2d(y.dims(), y_num_col_dims)[1];
+
+  auto* dout = &out_grad;
+  DenseTensor dout_mat(*dout);
+  dout_mat.Resize({m, n});
+
+  auto* ddx = x_grad_grad.get_ptr();
+  auto* ddy = y_grad_grad.get_ptr();
+
+  auto* dx = x_grad;
+  auto* dy = y_grad;
+  auto* ddout = out_grad_grad;
+
+  DenseTensor ddout_mat;
+  if (ddout) {
+    ddout->set_lod(dout->lod());
+    // allocate and reshape ddout
+    dev_ctx.template Alloc<T>(ddout);
+    ddout_mat.ShareDataWith(*ddout);
+    ddout_mat.Resize({m, n});
+  }
+
+  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
+  // a flag to specify whether ddout value has been set, if flag
+  // is false, MatMul beta should be 0 to set ddout, if flag is
+  // true, MatMul beta should be 1 to add result to ddout.
+  bool ddout_flag = false;
+  if (ddx) {
+    auto ddx_mat =
+        ddx->dims().size() > 2
+            ? paddle::framework::ReshapeToMatrix(*ddx, x_num_col_dims)
+            : static_cast<const DenseTensor&>(*ddx);
+
+    // dy = ddx' * dout. dy : K x M, ddx' : K x M, dout : M x N
+    if (dy) {
+      dy->set_lod(y.lod());
+      // allocate and reshape dy
+      dev_ctx.template Alloc<T>(dy);
+      DenseTensor dy_mat =
+          dy->dims().size() > 2
+              ? paddle::framework::ReshapeToMatrix(*dy, y_num_col_dims)
+              : *dy;
+      blas.MatMul(ddx_mat, true, dout_mat, false, &dy_mat);
+    }
+    // ddout1 = ddx * y. ddx : M x K, y : K x N, ddout1 : M x N
+    if (ddout) {
+      blas.MatMul(ddx_mat,
+                  false,
+                  y_mat,
+                  false,
+                  static_cast<T>(1.0),
+                  &ddout_mat,
+                  static_cast<T>(ddout_flag));
+      ddout_flag = true;
+    }
+  }
+  if (ddy) {
+    auto ddy_mat =
+        ddy->dims().size() > 2
+            ? paddle::framework::ReshapeToMatrix(*ddy, y_num_col_dims)
+            : static_cast<const DenseTensor&>(*ddy);
+    // dx = dout * ddy'. dout : M x N, ddy' : N x K, dx : M x K
+    if (dx) {
+      dx->set_lod(x.lod());
+      // allocate and reshape dx
+      dev_ctx.template Alloc<T>(dx);
+      DenseTensor dx_mat =
+          dx->dims().size() > 2
+              ? paddle::framework::ReshapeToMatrix(*dx, x_num_col_dims)
+              : *dx;
+      blas.MatMul(dout_mat, false, ddy_mat, true, &dx_mat);
+    }
+    // ddout2 = x * ddy. x : M x K, ddy : K x N, ddout2 : M x N
+    if (ddout) {
+      blas.MatMul(x_mat,
+                  false,
+                  ddy_mat,
+                  false,
+                  static_cast<T>(1.0),
+                  &ddout_mat,
+                  static_cast<T>(ddout_flag));
+    }
+  }
+}
+
 } // namespace phi
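The ddout_flag bookkeeping implements the accumulation ddOut = ddX·Y + X·ddY, the product rule applied to the incoming second-order perturbations: the first GEMM into ddout_mat runs with beta = 0 and overwrites, after which the flag flips so the second GEMM runs with beta = 1 and accumulates. This also handles the cases where only one of ddX and ddY is present.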
@@ -506,4 +506,34 @@ void MatmulKernel(const Context& dev_ctx,
   MatMulFunction<Context, T>(dev_ctx, x, y, out, transpose_x, transpose_y);
 }
+
+template <typename T, typename Context>
+void MatmulWithFlattenKernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const DenseTensor& y,
+                             int x_num_col_dims,
+                             int y_num_col_dims,
+                             DenseTensor* out) {
+  const DenseTensor x_matrix =
+      x.dims().size() > 2
+          ? paddle::framework::ReshapeToMatrix(x, x_num_col_dims)
+          : x;
+  const DenseTensor y_matrix =
+      y.dims().size() > 2
+          ? paddle::framework::ReshapeToMatrix(y, y_num_col_dims)
+          : y;
+
+  dev_ctx.template Alloc<T>(out);
+  auto z_dim = out->dims();
+  if (z_dim.size() != 2) {
+    out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
+  }
+
+  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
+
+  blas.MatMul(x_matrix, y_matrix, out);
+  if (z_dim.size() != 2) {
+    out->Resize(z_dim);
+  }
+}
+
 } // namespace phi
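The Resize round-trip preserves the output's original rank around the 2-D GEMM. With hypothetical shapes: x of [2, 3, 4, 5] and x_num_col_dims = 2 flattens to 6 x 20, y of [4, 5, 7] and y_num_col_dims = 2 flattens to 20 x 7; the GEMM writes a 6 x 7 result into out, whose inferred shape [2, 3, 7] holds the same 42 elements and is restored afterwards.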
@@ -60,4 +60,28 @@ void MatmulTripleGradKernel(const Context& dev_ctx,
                             DenseTensor* out_d_ddx,
                             DenseTensor* out_d_ddy);
+
+template <typename T, typename Context>
+void MatmulWithFlattenGradKernel(const Context& dev_ctx,
+                                 const DenseTensor& x,
+                                 const DenseTensor& y,
+                                 const DenseTensor& out_grad,
+                                 int x_num_col_dims,
+                                 int y_num_col_dims,
+                                 DenseTensor* x_grad,
+                                 DenseTensor* y_grad);
+
+template <typename T, typename Context>
+void MatmulWithFlattenDoubleGradKernel(
+    const Context& dev_ctx,
+    const DenseTensor& x,
+    const DenseTensor& y,
+    const DenseTensor& out_grad,
+    paddle::optional<const DenseTensor&> x_grad_grad,
+    paddle::optional<const DenseTensor&> y_grad_grad,
+    int x_num_col_dims,
+    int y_num_col_dims,
+    DenseTensor* x_grad,
+    DenseTensor* y_grad,
+    DenseTensor* out_grad_grad);
+
 } // namespace phi
@@ -29,6 +29,16 @@ void MatmulKernel(const Context& dev_ctx,
                   bool transpose_y,
                   DenseTensor* out);
+
+// Kept only for compatibility with the fluid `mul` op; this kernel is
+// not exposed through the 2.x API.
+template <typename T, typename Context>
+void MatmulWithFlattenKernel(const Context& dev_ctx,
+                             const DenseTensor& x,
+                             const DenseTensor& y,
+                             int x_num_col_dims,
+                             int y_num_col_dims,
+                             DenseTensor* out);
+
 template <typename T, typename Context>
 DenseTensor Matmul(const Context& dev_ctx,
                    const DenseTensor& x,
...
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature MulGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("matmul_with_flatten_grad",
+                         {"X", "Y", GradVarName("Out")},
+                         {"x_num_col_dims", "y_num_col_dims"},
+                         {GradVarName("X"), GradVarName("Y")});
+}
+
+KernelSignature MulDoubleGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("matmul_with_flatten_double_grad",
+                         {"X", "Y", "DOut", "DDX", "DDY"},
+                         {"x_num_col_dims", "y_num_col_dims"},
+                         {"DX", "DY", "DDOut"});
+}
+
+} // namespace phi
+
+PD_REGISTER_BASE_KERNEL_NAME(mul, matmul_with_flatten);
+PD_REGISTER_BASE_KERNEL_NAME(mul_grad, matmul_with_flatten_grad);
+PD_REGISTER_BASE_KERNEL_NAME(mul_grad_grad, matmul_with_flatten_double_grad);
+
+PD_REGISTER_ARG_MAPPING_FN(mul_grad, phi::MulGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(mul_grad_grad, phi::MulDoubleGradOpArgumentMapping);
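This new compat file is what keeps the old op names working: PD_REGISTER_BASE_KERNEL_NAME renames the fluid op names onto the phi kernel names, and the two argument-mapping functions route the grad ops' named slots into the kernels' parameter lists. The forward op gets no custom mapping function here, presumably because its inputs, attributes, and output already line up with the kernel signature once the name is mapped; spelled out explicitly, such a mapping would look like the following (hypothetical, not part of this change):

KernelSignature MulOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("matmul_with_flatten",
                         {"X", "Y"},
                         {"x_num_col_dims", "y_num_col_dims"},
                         {"Out"});
}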