未验证 提交 11c9169f 编写于 作者: W wangxinxin08 提交者: GitHub

[cherry-pick]add doublegrad op for matmul (#27800)

* add matmul doublegrad op

* fix compile errors

* modify code according to review

* delete float16

* delete GetDimForInput to be consistent with release/1.8
无相关合并请求
...@@ -318,6 +318,181 @@ class MatMulGradKernel : public framework::OpKernel<T> { ...@@ -318,6 +318,181 @@ class MatMulGradKernel : public framework::OpKernel<T> {
} }
}; };
// Kernel computing the second-order gradient of matmul.
//
// Reads inputs "X", "Y", "DOut" and the optional "DDX" / "DDY", and writes the
// optional outputs "DX", "DY" and "DDOut". Each output is only computed when
// both the output itself and the DDX/DDY input it depends on are present.
template <typename DeviceContext, typename T>
class MatMulDoubleGradKernel : public framework::OpKernel<T> {
public:
// Multiplies `a` by `b` (each optionally transposed) into `out`.
//
// `out` is allocated on the context's place first. The "alpha" attribute is
// forwarded as the scaling factor, and `flag` is cast to T and forwarded as
// the last scalar argument of blas.MatMul.
// NOTE(review): that last argument is presumably the GEMM beta coefficient
// (false -> overwrite `out`, true -> accumulate into `out`) -- confirm against
// the math::Blas interface.
void MatMul(const framework::ExecutionContext &context,
const framework::Tensor &a, bool trans_a,
const framework::Tensor &b, bool trans_b, bool flag,
framework::Tensor *out) const {
out->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context);
auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
// head_number stays 1 unless built with MKLML and without CUDA, in which
// case it is read from the "head_number" attribute.
int head_number = 1;
#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
head_number = context.Attr<int>("head_number");
#endif
// For a 3-D `a` times an at-most-2-D `b`, fold the batch dimension of `a`
// into its height and clear the batch size.
if (head_number <= 1 && a.dims().size() == 3 && b.dims().size() <= 2) {
// Only done when `a` is not transposed; per the original note, the
// transposed case would make the transpose cost much time.
if (!trans_a) {
mat_dim_a.height_ *= mat_dim_a.batch_size_;
mat_dim_a.batch_size_ = 0;
}
}
blas.MatMul(a, mat_dim_a, b, mat_dim_b,
static_cast<T>(context.Attr<float>("alpha")), out,
static_cast<T>(flag));
}
// Computes one gradient term `out = op(a) * op(b)` via MatMul().
//
// Does nothing when `out` is null. When either operand is 3-D but `out` is
// 2-D, both operands are folded before multiplying: FoldInitDims is used when
// the matching is_fold_init_dims_* flag is true, FoldHeadAndLastDims
// otherwise (both helpers are defined outside this block). `flag` is passed
// through to MatMul() unchanged.
void CalcInputGrad(const framework::ExecutionContext &context,
const framework::Tensor &a, bool trans_a,
bool is_fold_init_dims_a, const framework::Tensor &b,
bool trans_b, bool is_fold_init_dims_b, bool flag,
framework::Tensor *out) const {
if (out == nullptr) return;
bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
out->dims().size() == 2;
if (!need_combine) {
MatMul(context, a, trans_a, b, trans_b, flag, out);
} else {
auto &ctx = context.template device_context<DeviceContext>();
MatMul(context, is_fold_init_dims_a
? FoldInitDims(a)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
trans_a, is_fold_init_dims_b
? FoldInitDims(b)
: FoldHeadAndLastDims<DeviceContext, T>(ctx, b),
trans_b, flag, out);
}
}
// Kernel entry point: fills DX, DY and DDOut from X, Y, DOut, DDX and DDY.
void Compute(const framework::ExecutionContext &context) const override {
// x, y and dout are local copies of the input tensors; they are handed to
// ReshapeXYOutIntoMatrixSequence below, which is defined outside this block.
auto x = *context.Input<framework::Tensor>("X");
auto y = *context.Input<framework::Tensor>("Y");
auto dout = *context.Input<framework::LoDTensor>("DOut");
auto *ddx = context.Input<framework::LoDTensor>("DDX");
auto *ddy = context.Input<framework::LoDTensor>("DDY");
auto *dx = context.Output<framework::LoDTensor>("DX");
auto *dy = context.Output<framework::LoDTensor>("DY");
auto *ddout = context.Output<framework::LoDTensor>("DDOut");
bool transpose_x = context.Attr<bool>("transpose_X");
bool transpose_y = context.Attr<bool>("transpose_Y");
ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
// Remember each output's current dims and temporarily resize it to the dims
// of the reshaped x / y / dout; the saved dims are restored at the end.
framework::DDim dx_dims;
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
}
framework::DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
}
framework::DDim ddout_dims;
if (ddout) {
ddout_dims = ddout->dims();
if (ddout_dims != dout.dims()) {
ddout->Resize(dout.dims());
}
}
// ddout_flag is false for the first DDOut term and true afterwards, so the
// DDY term below is issued with flag == true once the DDX term was written.
bool ddout_flag = false;
// Terms driven by DDX: DY and the first part of DDOut.
if (ddx) {
// Local copy of DDX, resized to the dims of the reshaped x.
auto ddx_mat = *ddx;
if (ddx_mat.dims() != x.dims()) {
ddx_mat.Resize(x.dims());
}
if (dy) {
if (transpose_x && transpose_y) {
// dy = dout' * ddx'
CalcInputGrad(context, dout, true, true, ddx_mat, true, false, false,
dy);
} else if (transpose_x) {
// dy = ddx * dout
CalcInputGrad(context, ddx_mat, false, false, dout, false, true,
false, dy);
} else if (transpose_y) {
// dy = dout' * ddx
CalcInputGrad(context, dout, true, true, ddx_mat, false, true, false,
dy);
} else {
// dy = ddx' * dout
CalcInputGrad(context, ddx_mat, true, true, dout, false, true, false,
dy);
}
}
if (ddout) {
// ddout term from DDX and Y, using the forward transpose attributes.
CalcInputGrad(context, ddx_mat, transpose_x, true, y, transpose_y,
false, ddout_flag, ddout);
ddout_flag = true;
}
}
// Terms driven by DDY: DX and the second part of DDOut.
if (ddy) {
// Local copy of DDY, resized to the dims of the reshaped y.
auto ddy_mat = *ddy;
if (ddy_mat.dims() != y.dims()) {
ddy_mat.Resize(y.dims());
}
if (dx) {
if (transpose_x && transpose_y) {
// dx = ddy' * dout'
CalcInputGrad(context, ddy_mat, true, true, dout, true, false, false,
dx);
} else if (transpose_x) {
// dx = ddy * dout'
CalcInputGrad(context, ddy_mat, false, false, dout, true, false,
false, dx);
} else if (transpose_y) {
// dx = dout * ddy
CalcInputGrad(context, dout, false, false, ddy_mat, false, true,
false, dx);
} else {
// dx = dout * ddy'
CalcInputGrad(context, dout, false, false, ddy_mat, true, false,
false, dx);
}
}
if (ddout) {
// ddout term from X and DDY; ddout_flag is true here iff the DDX term
// above was already written into ddout.
CalcInputGrad(context, x, transpose_x, true, ddy_mat, transpose_y,
false, ddout_flag, ddout);
}
}
// Restore the dims each output had on entry.
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
}
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
}
}
if (ddout) {
if (ddout_dims != dout.dims()) {
ddout->Resize(ddout_dims);
}
}
}
};
class MatMulOp : public framework::OperatorWithKernel { class MatMulOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -602,6 +777,61 @@ class MatMulOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -602,6 +777,61 @@ class MatMulOpGradMaker : public framework::SingleGradOpMaker<T> {
retv->SetAttrMap(this->Attrs()); retv->SetAttrMap(this->Attrs());
} }
}; };
// Operator class for the second-order gradient of matmul.
//
// Requires inputs "X", "Y" and "DOut". Output shapes are only inferred for
// outputs that exist and whose driving input (DDY for DX, DDX for DY, either
// one for DDOut) is present.
class MatMulOpDoubleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext *context) const override {
    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matmul");
    OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul");
    OP_INOUT_CHECK(context->HasInput("DOut"), "Input", "DOut", "matmul");

    const bool has_ddx = context->HasInput("DDX");
    const bool has_ddy = context->HasInput("DDY");

    // DX takes the shape of X and is driven by DDY.
    if (has_ddy && context->HasOutput("DX")) {
      context->ShareDim("X", "DX");
    }

    // DY takes the shape of Y and is driven by DDX.
    if (has_ddx && context->HasOutput("DY")) {
      context->ShareDim("Y", "DY");
    }

    // DDOut takes the shape of DOut and needs at least one of DDX / DDY.
    if ((has_ddy || has_ddx) && context->HasOutput("DDOut")) {
      context->ShareDim("DOut", "DDOut");
    }
  }
};
// Builds the "matmul_grad_grad" op description from a matmul_grad op.
//
// Wires X, Y and the forward output gradient (as "DOut") plus the gradients of
// the first-order outputs (as "DDX" / "DDY") into the new op, and only
// requests the outputs that have a non-empty driving gradient.
template <typename T>
class MatMulOpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> retv) const override {
    retv->SetType("matmul_grad_grad");

    retv->SetInput("X", this->Input("X"));
    retv->SetInput("Y", this->Input("Y"));
    retv->SetInput("DOut", this->Input(framework::GradVarName("Out")));

    // Gradients flowing into the first-order outputs X@GRAD and Y@GRAD.
    auto grad_of_dx = this->OutputGrad(framework::GradVarName("X"));
    retv->SetInput("DDX", grad_of_dx);
    auto grad_of_dy = this->OutputGrad(framework::GradVarName("Y"));
    retv->SetInput("DDY", grad_of_dy);

    const bool no_ddx = grad_of_dx.empty();
    const bool no_ddy = grad_of_dy.empty();

    // DDOut is produced as soon as either DDX or DDY is available.
    if (!(no_ddx && no_ddy)) {
      retv->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
    }

    // DX depends on DDY only; emit an empty output when DDY is absent.
    if (no_ddy) {
      retv->SetOutput("DX", this->EmptyInputGrad());
    } else {
      retv->SetOutput("DX", this->InputGrad("X"));
    }

    // DY depends on DDX only; emit an empty output when DDX is absent.
    if (no_ddx) {
      retv->SetOutput("DY", this->EmptyInputGrad());
    } else {
      retv->SetOutput("DY", this->InputGrad("Y"));
    }

    retv->SetAttrMap(this->Attrs());
  }
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -609,7 +839,10 @@ namespace ops = paddle::operators; ...@@ -609,7 +839,10 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker, REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
ops::MatMulOpGradMaker<paddle::framework::OpDesc>, ops::MatMulOpGradMaker<paddle::framework::OpDesc>,
ops::MatMulOpGradMaker<paddle::imperative::OpBase>); ops::MatMulOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad); REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad,
ops::MatMulOpDoubleGradMaker<paddle::framework::OpDesc>,
ops::MatMulOpDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(matmul_grad_grad, ops::MatMulOpDoubleGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>, matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulKernel<paddle::platform::CPUDeviceContext, double>); ops::MatMulKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -618,6 +851,11 @@ REGISTER_OP_CPU_KERNEL( ...@@ -618,6 +851,11 @@ REGISTER_OP_CPU_KERNEL(
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>, ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, double>); ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, double>);
// Register the CPU kernels of matmul_grad_grad for float and double.
REGISTER_OP_CPU_KERNEL(
matmul_grad_grad,
ops::MatMulDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>, matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>,
...@@ -630,4 +868,8 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -630,4 +868,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, double>, ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, ops::MatMulGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>); paddle::platform::float16>);
// Register the CUDA kernels of matmul_grad_grad for float and double; no
// float16 instantiation is registered for this op.
REGISTER_OP_CUDA_KERNEL(
matmul_grad_grad,
ops::MatMulDoubleGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::MatMulDoubleGradKernel<paddle::platform::CUDADeviceContext, double>);
#endif #endif
...@@ -152,6 +152,38 @@ class TestMulDoubleGradCheck(unittest.TestCase): ...@@ -152,6 +152,38 @@ class TestMulDoubleGradCheck(unittest.TestCase):
self.func(p) self.func(p)
class TestMatmulDoubleGradCheck(unittest.TestCase):
    """Runs the double-gradient checker on layers.matmul for several cases."""

    @prog_scope()
    def func(self, place):
        """Build one matmul per case and check its second-order gradient."""
        eps = 0.005
        # Each case: (x_shape, y_shape, transpose_x, transpose_y,
        #             numpy dtype for the initial values, parameter dtype name)
        cases = [
            ([2], [2], False, False, np.float64, "float64"),
            ([2, 3], [3, 2], True, True, np.float64, "float64"),
            ([2, 4, 3], [2, 4, 5], True, False, np.float32, "float32"),
            ([2, 3, 4, 5], [2, 3, 3, 5], False, True, np.float32, "float32"),
            ([2, 3, 4], [4, 3], False, False, np.float64, "float64"),
        ]
        for idx, case in enumerate(cases):
            x_shape, y_shape, trans_x, trans_y, np_dtype, dtype_name = case
            x = layers.create_parameter(
                dtype=dtype_name, shape=x_shape, name='x{}'.format(idx))
            y = layers.create_parameter(
                dtype=dtype_name, shape=y_shape, name='y{}'.format(idx))
            out = layers.matmul(
                x, y, trans_x, trans_y, name='out{}'.format(idx))

            # Random initial values in [-1, 1), drawn for x first, then y.
            x_init = np.random.uniform(-1, 1, x_shape).astype(np_dtype)
            y_init = np.random.uniform(-1, 1, y_shape).astype(np_dtype)

            gradient_checker.double_grad_check(
                [x, y], out, x_init=[x_init, y_init], place=place, eps=eps)

    def test_grad(self):
        """Run the check on CPU and, when compiled with CUDA, on GPU 0."""
        devices = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            devices += [fluid.CUDAPlace(0)]
        for device in devices:
            self.func(device)
class TestReshapeDoubleGradCheck(unittest.TestCase): class TestReshapeDoubleGradCheck(unittest.TestCase):
@prog_scope() @prog_scope()
def func(self, place): def func(self, place):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部