Unverified · Commit 60be66e2 · authored by Kaipeng Deng, committed by GitHub

support fc_op double grad (#17317)

* add double grad for mul_op. test=develop

* fix format. test=develop

* fix format. test=develop

* fix format. test=develop

* refine code. test=develop

* remove setzero. test=develop

* fix dx/dy init bug. test=develop

* fix format. test=develop
Parent commit: ad8bbe58
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/mul_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace paddle {
@@ -178,16 +180,66 @@ class MulOpGradMaker : public framework::SingleGradOpDescMaker {
  }
};
class MulDoubleGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
    PADDLE_ENFORCE(ctx->HasInput("DOut"), "Input(DOut) should not be null");

    if (ctx->HasOutput("DX")) {
      ctx->ShareDim("X", "DX");
    }
    if (ctx->HasOutput("DY")) {
      ctx->ShareDim("Y", "DY");
    }
    if (ctx->HasOutput("DDOut")) {
      ctx->ShareDim("DOut", "DDOut");
    }
  }
};
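In matrix terms (an illustrative reading of the ShareDim calls above, with X flattened to an M x K matrix and Y to K x N; M, K, N are names for this sketch, not identifiers from the patch): DX inherits X's M x K shape, DY inherits Y's K x N shape, and DDOut inherits DOut's M x N shape, which is also the shape of Out.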
class MulDoubleGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> retv(new framework::OpDesc());
    retv->SetType("mul_grad_grad");

    retv->SetInput("X", Input("X"));
    retv->SetInput("Y", Input("Y"));
    retv->SetInput("DOut", Input(framework::GradVarName("Out")));
    retv->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    retv->SetInput("DDY", OutputGrad(framework::GradVarName("Y")));

    retv->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    retv->SetOutput("DX", InputGrad("X"));
    retv->SetOutput("DY", InputGrad("Y"));

    retv->SetAttrMap(Attrs());
    return retv;
  }
};
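For context, a sketch of the calculus behind this wiring (not text from the patch; X is taken as an M x K matrix and Y as K x N). The first-order backward of $Out = XY$ computes

$$dX = dOut\,Y^{\mathsf T}, \qquad dY = X^{\mathsf T}\,dOut.$$

Differentiating that backward pass with upstream gradients $DDX$ (w.r.t. $dX$) and $DDY$ (w.r.t. $dY$) yields

$$DDOut = DDX\,Y + X\,DDY, \qquad DX = dOut\,DDY^{\mathsf T}, \qquad DY = DDX^{\mathsf T}\,dOut,$$

which is exactly the wiring above: DDX and DDY are the OutputGrads of the first-order gradients, while DDOut, DX, and DY are the corresponding InputGrads.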
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpInferVarType,
                  ops::MulOpGradMaker);
REGISTER_OPERATOR(mul_grad, ops::MulGradOp, ops::MulDoubleGradMaker);
REGISTER_OPERATOR(mul_grad_grad, ops::MulDoubleGradOp);
REGISTER_OP_CPU_KERNEL(
    mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MulKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    mul_grad, ops::MulGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MulGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    mul_grad_grad,
    ops::MulDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MulDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
@@ -24,3 +24,7 @@ REGISTER_OP_CUDA_KERNEL(
    mul_grad, ops::MulGradKernel<plat::CUDADeviceContext, float>,
    ops::MulGradKernel<plat::CUDADeviceContext, double>,
    ops::MulGradKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
    mul_grad_grad,
    ops::MulDoubleGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::MulDoubleGradKernel<paddle::platform::CUDADeviceContext, double>);
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
@@ -109,5 +110,95 @@ class MulGradKernel : public framework::OpKernel<T> {
  }
};
template <typename DeviceContext, typename T>
class MulDoubleGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
    int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
    auto* x = ctx.Input<framework::LoDTensor>("X");
    auto* y = ctx.Input<framework::LoDTensor>("Y");

    auto x_mat = x->dims().size() > 2
                     ? framework::ReshapeToMatrix(*x, x_num_col_dims)
                     : static_cast<const Tensor&>(*x);
    auto y_mat = y->dims().size() > 2
                     ? framework::ReshapeToMatrix(*y, y_num_col_dims)
                     : static_cast<const Tensor&>(*y);

    const int m = framework::flatten_to_2d(x->dims(), x_num_col_dims)[0];
    const int n = framework::flatten_to_2d(y->dims(), y_num_col_dims)[1];

    auto* dout = ctx.Input<framework::LoDTensor>("DOut");
    Tensor dout_mat;
    dout_mat.ShareDataWith(*dout);
    dout_mat.Resize({m, n});

    auto* ddx = ctx.Input<framework::LoDTensor>("DDX");
    auto* ddy = ctx.Input<framework::LoDTensor>("DDY");
    auto* dx = ctx.Output<framework::LoDTensor>("DX");
    auto* dy = ctx.Output<framework::LoDTensor>("DY");
    auto* ddout = ctx.Output<framework::LoDTensor>("DDOut");

    Tensor ddout_mat;
    if (ddout) {
      ddout->set_lod(dout->lod());
      // allocate and reshape ddout
      ddout->mutable_data<T>(ctx.GetPlace());
      ddout_mat.ShareDataWith(*ddout);
      ddout_mat.Resize({m, n});
    }

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
    // Tracks whether ddout already holds a partial result: while false,
    // MatMul's beta is 0 so the first product overwrites ddout; once true,
    // beta is 1 so the second product accumulates into it.
    bool ddout_flag = false;
    if (ddx) {
      auto ddx_mat = ddx->dims().size() > 2
                         ? framework::ReshapeToMatrix(*ddx, x_num_col_dims)
                         : static_cast<const Tensor&>(*ddx);

      // dy = ddx' * dout. dy : K x N, ddx' : K x M, dout : M x N
      if (dy) {
        dy->set_lod(y->lod());
        // allocate and reshape dy
        dy->mutable_data<T>(ctx.GetPlace());
        Tensor dy_mat = dy->dims().size() > 2
                            ? framework::ReshapeToMatrix(*dy, y_num_col_dims)
                            : *dy;
        blas.MatMul(ddx_mat, true, dout_mat, false, &dy_mat);
      }
      // ddout1 = ddx * y. ddx : M x K, y : K x N, ddout1 : M x N
      if (ddout) {
        blas.MatMul(ddx_mat, false, y_mat, false, static_cast<T>(1.0),
                    &ddout_mat, static_cast<T>(ddout_flag));
        ddout_flag = true;
      }
    }
    if (ddy) {
      auto ddy_mat = ddy->dims().size() > 2
                         ? framework::ReshapeToMatrix(*ddy, y_num_col_dims)
                         : static_cast<const Tensor&>(*ddy);

      // dx = dout * ddy'. dout : M x N, ddy' : N x K, dx : M x K
      if (dx) {
        dx->set_lod(x->lod());
        // allocate and reshape dx
        dx->mutable_data<T>(ctx.GetPlace());
        Tensor dx_mat = dx->dims().size() > 2
                            ? framework::ReshapeToMatrix(*dx, x_num_col_dims)
                            : *dx;
        blas.MatMul(dout_mat, false, ddy_mat, true, &dx_mat);
      }
      // ddout2 = x * ddy. x : M x K, ddy : K x N, ddout2 : M x N
      if (ddout) {
        blas.MatMul(x_mat, false, ddy_mat, false, static_cast<T>(1.0),
                    &ddout_mat, static_cast<T>(ddout_flag));
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle
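As a sanity check, the three products computed by MulDoubleGradKernel can be reproduced with plain numpy. This is a minimal sketch under the same M x K / K x N layout; the shapes and array names here are illustrative, not part of the Paddle API. The central-difference check at the end confirms that DDOut is the directional derivative of Out = X @ Y along (ddx, ddy):

import numpy as np

# Illustrative shapes: x is M x K, y is K x N.
M, K, N = 7, 11, 9
rng = np.random.RandomState(0)
x = rng.uniform(-1, 1, (M, K))
y = rng.uniform(-1, 1, (K, N))
dout = rng.uniform(-1, 1, (M, N))   # upstream grad of Out
ddx = rng.uniform(-1, 1, (M, K))    # upstream grad of dX
ddy = rng.uniform(-1, 1, (K, N))    # upstream grad of dY

# The three GEMMs from the kernel:
dx = dout @ ddy.T          # M x K, mirrors blas.MatMul(dout_mat, false, ddy_mat, true, ...)
dy = ddx.T @ dout          # K x N, mirrors blas.MatMul(ddx_mat, true, dout_mat, false, ...)
ddout = ddx @ y + x @ ddy  # M x N, accumulated via the ddout_flag/beta trick

# ddout should equal d/dt [(x + t*ddx) @ (y + t*ddy)] at t = 0;
# the h^2 cross term cancels in the central difference.
h = 1e-6
numeric = ((x + h * ddx) @ (y + h * ddy) - (x - h * ddx) @ (y - h * ddy)) / (2 * h)
assert np.allclose(ddout, numeric, atol=1e-8)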
@@ -193,5 +193,33 @@ class TestElementwiseMulBroadcastDoubleGradCheck(unittest.TestCase):
        self.func(p)
class TestMulDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        # The shape of each input variable should be specified explicitly
        # and must not include -1.
        x_shape = [7, 11]
        y_shape = [11, 9]
        eps = 0.005
        dtype = np.float64

        x = layers.data('x', x_shape, False, dtype)
        x.persistable = True
        y = layers.data('y', y_shape, False, dtype)
        y.persistable = True
        out = layers.mul(x, y)
        x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)
        y_arr = np.random.uniform(-1, 1, y_shape).astype(dtype)

        gradient_checker.double_grad_check(
            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)
if __name__ == "__main__":
    unittest.main()