未验证 提交 8bae8590 编写于 作者: K Kaipeng Deng 提交者: GitHub

add double grad for elementwise_mul op (#17255)

* add double grad for elementwise_mul. test=develop

* remove comment. test=develop

* fix grad sum. test=develop

* fix for axis expand. test=develop

* add test for axis expand. test=develop
上级 11d3a38f
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include <memory>
#include <string> #include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h"
...@@ -43,6 +44,30 @@ class ElementwiseMulOpMaker : public ElementwiseOpMaker { ...@@ -43,6 +44,30 @@ class ElementwiseMulOpMaker : public ElementwiseOpMaker {
virtual std::string GetEquation() const { return "Out = X \\\\odot Y"; } virtual std::string GetEquation() const { return "Out = X \\\\odot Y"; }
}; };
class ElementwiseMulDoubleGradDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("elementwise_mul_grad_grad");
op->SetInput("X", Input("X"));
op->SetInput("Y", Input("Y"));
op->SetInput("DOut", Input(framework::GradVarName("Out")));
op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
op->SetInput("DDY", OutputGrad(framework::GradVarName("Y")));
op->SetAttrMap(Attrs());
op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
return op;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -50,7 +75,9 @@ namespace ops = paddle::operators; ...@@ -50,7 +75,9 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseOp, REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseOp,
ops::ElementwiseMulOpMaker, ops::ElementwiseOpInferVarType, ops::ElementwiseMulOpMaker, ops::ElementwiseOpInferVarType,
ops::ElementwiseMulOpGradDescMaker); ops::ElementwiseMulOpGradDescMaker);
REGISTER_OPERATOR(elementwise_mul_grad, ops::ElementwiseOpGrad); REGISTER_OPERATOR(elementwise_mul_grad, ops::ElementwiseOpGrad,
ops::ElementwiseMulDoubleGradDescMaker);
REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_mul, elementwise_mul,
...@@ -64,3 +91,13 @@ REGISTER_OP_CPU_KERNEL( ...@@ -64,3 +91,13 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>, ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>, ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad_grad,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
float>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
double>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
...@@ -88,3 +88,9 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -88,3 +88,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>, ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>, ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>); ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad_grad,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>);
...@@ -123,5 +123,56 @@ class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> { ...@@ -123,5 +123,56 @@ class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> {
ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>()); ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX<T>(), MulGradDY<T>());
} }
}; };
template <typename DeviceContext, typename T>
class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>("DOut");
auto* ddx = ctx.Input<Tensor>("DDX");
auto* ddy = ctx.Input<Tensor>("DDY");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* ddout = ctx.Output<Tensor>("DDOut");
if (ddout) ddout->mutable_data<T>(ctx.GetPlace());
// dx = dout * ddy
// dy = dout * ddx
Tensor ddx_safe, ddy_safe;
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, x, ddx, &ddx_safe);
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, MulGradDX<T>, MulGradDY<T>>(
ctx, ddx_safe, ddy_safe, *dout, *dout, axis, dx, dy, MulGradDX<T>(),
MulGradDY<T>());
// ddout = ddx * y + x * ddy
if (ddout) {
if (ddx && ddy) {
Tensor ddout_tmp;
ddout_tmp.mutable_data<T>(ddout->dims(), ctx.GetPlace());
default_elementwise_mul<DeviceContext, T>(ctx, ddx, y, ddout);
default_elementwise_mul<DeviceContext, T>(ctx, x, ddy, &ddout_tmp);
auto& place =
*ctx.template device_context<DeviceContext>().eigen_device();
auto ddout_t = framework::EigenVector<T>::Flatten(*ddout);
auto ddout_tmp_t = framework::EigenVector<T>::Flatten(ddout_tmp);
ddout_t.device(place) = ddout_t + ddout_tmp_t;
} else {
if (ddx) default_elementwise_mul<DeviceContext, T>(ctx, ddx, y, ddout);
if (ddy) default_elementwise_mul<DeviceContext, T>(ctx, x, ddy, ddout);
}
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -212,6 +212,43 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ...@@ -212,6 +212,43 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
} }
}; };
class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->ShareDim("X", x_grad_name);
ctx->ShareLoD("X", x_grad_name);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->ShareDim("Y", y_grad_name);
ctx->ShareLoD("Y", y_grad_name);
}
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("DOut", "DDOut");
ctx->ShareLoD("DOut", "DDOut");
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = ctx.Input<Tensor>("DDX")->type();
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
// For Add, Sub op, the X, Out is not needed. // For Add, Sub op, the X, Out is not needed.
class ElementwiseOpExplicitGrad : public ElementwiseOpGrad { class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
public: public:
......
...@@ -1636,5 +1636,20 @@ void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx, ...@@ -1636,5 +1636,20 @@ void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
} }
} }
} }
template <typename DeviceContext, typename T>
static inline void GetDoubleGradSafeTensor(
const framework::ExecutionContext &ctx, const framework::Tensor *x,
const framework::Tensor *ddx, framework::Tensor *ddx_safe) {
if (ddx) {
*ddx_safe = *ddx;
} else {
ddx_safe->mutable_data<T>(x->dims(), ctx.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
set_zero(ctx.template device_context<DeviceContext>(), ddx_safe,
static_cast<T>(0));
}
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -139,5 +139,59 @@ class TestSquareDoubleGradCheck(unittest.TestCase): ...@@ -139,5 +139,59 @@ class TestSquareDoubleGradCheck(unittest.TestCase):
self.func(p) self.func(p)
class TestElementwiseMulDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
# the shape of input variable shoule be clearly specified, not inlcude -1.
shape = [7, 9]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.data('y', shape, False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_mul(x, y)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape).astype(dtype)
gradient_checker.double_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestElementwiseMulBroadcastDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
# the shape of input variable shoule be clearly specified, not inlcude -1.
shape = [7, 9]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.data('y', shape[:-1], False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_mul(x, y, axis=0)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape[:-1]).astype(dtype)
gradient_checker.double_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册