未验证 提交 59fdf4da 编写于 作者: W Weilong Wu 提交者: GitHub

[New features] Add elementwise_mul triple grad kernel (#37152)

* Add elementwise_mul triple grad kernel

* Removed InplaceInferer and polished code
上级 84b0ec97
......@@ -110,6 +110,35 @@ class ElementwiseMulDoubleGradMaker : public framework::SingleGradOpMaker<T> {
}
};
template <typename T>
class ElementwiseMulTripleGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_mul_triple_grad");
// get input from double grad
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput("DOut", this->Input("DOut"));
op->SetInput("DDX", this->Input("DDX"));
op->SetInput("DDY", this->Input("DDY"));
op->SetInput("D_DX", this->OutputGrad(framework::GradVarName("X")));
op->SetInput("D_DY", this->OutputGrad(framework::GradVarName("Y")));
op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
op->SetAttrMap(this->Attrs());
// set outputs
op->SetOutput("D_X", this->InputGrad("X"));
op->SetOutput("D_Y", this->InputGrad("Y"));
op->SetOutput("D_DOut", this->InputGrad("DOut"));
op->SetOutput("D_DDX", this->InputGrad("DDX"));
op->SetOutput("D_DDY", this->InputGrad("DDY"));
}
};
} // namespace operators
} // namespace paddle
......@@ -123,8 +152,13 @@ REGISTER_OPERATOR(
ops::ElementwiseMulDoubleGradMaker<paddle::framework::OpDesc>,
ops::ElementwiseMulDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad,
ops::ElementwiseDoubleGradOpInplaceInferer);
REGISTER_OPERATOR(
elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad,
ops::ElementwiseDoubleGradOpInplaceInferer,
ops::ElementwiseMulTripleGradMaker<paddle::framework::OpDesc>,
ops::ElementwiseMulTripleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_mul_triple_grad, ops::ElementwiseOpTripleGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_mul,
......@@ -164,6 +198,22 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex<float>>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_triple_grad,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
float>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
double>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
bool>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_VERSION(elementwise_mul)
.AddCheckpoint(
......
......@@ -141,3 +141,15 @@ REGISTER_OP_CUDA_KERNEL(
plat::complex<float>>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_triple_grad,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, bool>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
plat::complex<float>>,
ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
plat::complex<double>>);
......@@ -283,5 +283,96 @@ class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class ElementwiseMulTripleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
// get input
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dout = ctx.Input<framework::Tensor>("DOut");
auto* ddx = ctx.Input<framework::Tensor>("DDX");
auto* ddy = ctx.Input<framework::Tensor>("DDY");
auto* d_dx = ctx.Input<framework::Tensor>("D_DX");
auto* d_dy = ctx.Input<framework::Tensor>("D_DY");
auto* d_ddout = ctx.Input<framework::Tensor>("D_DDOut");
// get output
auto* out_d_x = ctx.Output<framework::Tensor>("D_X");
auto* out_d_y = ctx.Output<framework::Tensor>("D_Y");
auto* out_d_dout = ctx.Output<framework::Tensor>("D_DOut");
auto* out_d_ddx = ctx.Output<framework::Tensor>("D_DDX");
auto* out_d_ddy = ctx.Output<framework::Tensor>("D_DDY");
if (out_d_x) out_d_x->mutable_data<T>(x->dims(), ctx.GetPlace());
if (out_d_y) out_d_y->mutable_data<T>(y->dims(), ctx.GetPlace());
if (out_d_dout) out_d_dout->mutable_data<T>(dout->dims(), ctx.GetPlace());
if (out_d_ddx) out_d_ddx->mutable_data<T>(x->dims(), ctx.GetPlace());
if (out_d_ddy) out_d_ddy->mutable_data<T>(y->dims(), ctx.GetPlace());
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
Tensor ddx_safe, ddy_safe;
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, x, ddx, &ddx_safe);
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
if (d_ddout) {
if (out_d_x) {
// out_d_x = ddy * d_ddout
default_elementwise_mul<DeviceContext, T>(ctx, &ddy_safe, d_ddout,
out_d_x);
}
if (out_d_y) {
// out_d_y = ddx * d_ddout
default_elementwise_mul<DeviceContext, T>(ctx, &ddx_safe, d_ddout,
out_d_y);
}
}
if (out_d_dout) {
// get out_d_dout
// out_d_dout = ddy * d_dx + d_dy * ddx
Tensor out_d_dout_tmp;
out_d_dout_tmp.mutable_data<T>(dout->dims(), ctx.GetPlace());
default_elementwise_mul<DeviceContext, T>(ctx, d_dy, &ddx_safe,
out_d_dout);
default_elementwise_mul<DeviceContext, T>(ctx, &ddy_safe, d_dx,
&out_d_dout_tmp);
auto out_d_dout_t = framework::EigenVector<T>::Flatten(*out_d_dout);
auto out_d_dout_tmp_t =
framework::EigenVector<T>::Flatten(out_d_dout_tmp);
out_d_dout_t.device(place) = out_d_dout_t + out_d_dout_tmp_t;
}
if (out_d_ddx) {
// get out_d_ddx
// out_d_ddx = dout * d_dy + y * d_ddout
Tensor out_d_ddx_tmp;
out_d_ddx_tmp.mutable_data<T>(ddx->dims(), ctx.GetPlace());
default_elementwise_mul<DeviceContext, T>(ctx, dout, d_dy, out_d_ddx);
default_elementwise_mul<DeviceContext, T>(ctx, y, d_ddout,
&out_d_ddx_tmp);
auto out_d_ddx_t = framework::EigenVector<T>::Flatten(*out_d_ddx);
auto out_d_ddx_tmp_t = framework::EigenVector<T>::Flatten(out_d_ddx_tmp);
out_d_ddx_t.device(place) = out_d_ddx_t + out_d_ddx_tmp_t;
}
if (out_d_ddy) {
// get out_d_ddy
// out_d_ddy = dout * d_dx + x * d_ddout
Tensor out_d_ddy_tmp;
out_d_ddy_tmp.mutable_data<T>(ddy->dims(), ctx.GetPlace());
default_elementwise_mul<DeviceContext, T>(ctx, dout, d_dx, out_d_ddy);
default_elementwise_mul<DeviceContext, T>(ctx, x, d_ddout,
&out_d_ddy_tmp);
auto out_d_ddy_t = framework::EigenVector<T>::Flatten(*out_d_ddy);
auto out_d_ddy_tmp_t = framework::EigenVector<T>::Flatten(out_d_ddy_tmp);
out_d_ddy_t.device(place) = out_d_ddy_t + out_d_ddy_tmp_t;
}
}
};
} // namespace operators
} // namespace paddle
......@@ -451,6 +451,18 @@ class ElementwiseOpTripleGrad : public framework::OperatorWithKernel {
ctx->ShareDim("DDY", "D_DDY");
ctx->ShareLoD("DDY", "D_DDY");
}
if (ctx->HasOutput("D_X")) {
ctx->ShareDim("X", "D_X");
ctx->ShareLoD("X", "D_X");
}
if (ctx->HasOutput("D_Y")) {
ctx->ShareDim("Y", "D_Y");
ctx->ShareLoD("Y", "D_Y");
}
if (ctx->HasOutput("D_DOut")) {
ctx->ShareDim("DOut", "D_DOut");
ctx->ShareLoD("DOut", "D_DOut");
}
}
framework::OpKernelType GetExpectedKernelType(
......
......@@ -297,5 +297,59 @@ class TestElementwiseAddBroadcastTripleGradCheck(unittest.TestCase):
self.func(p)
class TestElementwiseMulTripleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
# the shape of input variable should be clearly specified, not inlcude -1.
shape = [2, 3, 4, 5]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.data('y', shape, False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_mul(x, y)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape).astype(dtype)
gradient_checker.triple_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestElementwiseMulBroadcastTripleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
# the shape of input variable should be clearly specified, not inlcude -1.
shape = [2, 3, 4, 5]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.data('y', shape[:-1], False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_add(x, y, axis=0)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape[:-1]).astype(dtype)
gradient_checker.triple_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册