Commit 10b23a72 authored by lvmengsi, committed by Kaipeng Deng

Double backward elementwise div (#17416)

* double backward, elementwise_div

* fix dx empty. test=develop

* bug fix (#17392)

fix security bug

* Enable stack operator for Ngraph, test=develop (#17406)

* fix sqrt_grad_grad unittest. test=develop (#17410)

* fix sqrt_grad_grad unittest. test=develop

* disable sqrt_grad_grad unittest. test=develop

* test=develop, fix unittest

* test=develop, fix unittest

* test=develop, fix unittest

* test=develop, fix bug

* fix unittest. test=develop

* fix unittest dx. test=develop

* tmp fix! for test... test=develop

* reduce tmp, test=develop

* test=develop, reduce tmp

* fix broadcast unittest. test=develop

* fix format. test=develop

* refine code. test=develop

* refine code. test=develop

* refine GetDoubleGradSafeTensor. test=develop

* fix format. test=develop
Parent 97f0ec23
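For reference, an editorial sketch of the math behind this patch (not part of the original commit message): with Out = X / Y computed elementwise, the first backward pass gives dX = dOut / Y and dY = -dOut * Out / Y. Differentiating these expressions and contracting with the incoming perturbations ddX (of dX) and ddY (of dY) yields the three outputs the new elementwise_div_grad_grad kernel computes:

\[
\mathrm{DDOut} = \frac{ddX - Out \cdot ddY}{Y}, \qquad
DY = \frac{dOut \,(Out \cdot ddY - ddX)}{Y^{2}} = \frac{Out \cdot dX \cdot ddY - dX \cdot ddX}{Y}, \qquad
DOut = -\, dX \cdot ddY,
\]

with all products and divisions taken elementwise (broadcast along "axis" when the shapes differ).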
@@ -44,6 +44,31 @@ class ElementwiseDivGradOpDescMaker : public framework::SingleGradOpDescMaker {
}
};
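// The double-grad op constructed below takes Y and Out from the forward op,
// the first-order gradient DX, and DDX / DDY (the gradients of the
// first-order outputs dX / dY). It produces the gradient w.r.t. Y, DOut
// (the gradient w.r.t. Out), and DDOut (the gradient flowing back to dOut).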
class ElementwiseDivDoubleGradDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("elementwise_div_grad_grad");
op->SetInput("Y", Input("Y"));
op->SetInput("Out", Input("Out"));
op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
op->SetInput("DDY", OutputGrad(framework::GradVarName("Y")));
op->SetInput("DX", Output(framework::GradVarName("X")));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
op->SetOutput("DOut", InputGrad("Out"));
op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
return op;
}
};
} // namespace operators
} // namespace paddle
@@ -53,7 +78,9 @@ REGISTER_OPERATOR(elementwise_div, ops::ElementwiseOp,
ops::ElementwiseDivOpMaker, ops::ElementwiseOpInferVarType,
ops::ElementwiseDivGradOpDescMaker);
REGISTER_OPERATOR(elementwise_div_grad, ops::ElementwiseOpGrad);
REGISTER_OPERATOR(elementwise_div_grad, ops::ElementwiseOpGrad,
ops::ElementwiseDivDoubleGradDescMaker);
REGISTER_OPERATOR(elementwise_div_grad_grad, ops::ElementwiseDivOpDoubleGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_div,
@@ -67,3 +94,14 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
elementwise_div_grad_grad,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
float>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
double>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
@@ -33,3 +33,13 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad_grad,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
float>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
double>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
int>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
@@ -14,8 +14,13 @@ limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
namespace paddle {
namespace operators {
@@ -51,6 +56,13 @@ struct DivGradDY {
}
};
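// DivDoubleDY is invoked by the double-grad kernel below with
// (x, y, out, dout) = (ddX, ddY, Out, dX / Y), so it evaluates
// dY = Out * ddY * dX / Y - ddX * dX / Y.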
template <typename T>
struct DivDoubleDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return y * out * dout - x * dout;
}
};
template <typename DeviceContext, typename T>
class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
public:
@@ -72,5 +84,109 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
}
};
class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext* ctx) const override {
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput("DOut")) {
ctx->ShareDim("DX", "DOut");
ctx->ShareLoD("DX", "DOut");
}
if (ctx->HasOutput(y_grad_name)) {
ctx->ShareDim("Y", y_grad_name);
ctx->ShareLoD("Y", y_grad_name);
}
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("DX", "DDOut");
ctx->ShareLoD("DX", "DDOut");
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = ctx.Input<Tensor>("DDX")->type();
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename DeviceContext, typename T>
class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
using Tensor = framework::Tensor;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* Y = ctx.Input<Tensor>("Y");
auto* Out = ctx.Input<Tensor>("Out");
auto* ddX = ctx.Input<Tensor>("DDX");
auto* ddY = ctx.Input<Tensor>("DDY");
auto* dX = ctx.Input<Tensor>("DX");
auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* dOut = ctx.Output<Tensor>("DOut");
auto* ddOut = ctx.Output<Tensor>("DDOut");
int axis = ctx.Attr<int>("axis");
if (dY) dY->mutable_data<T>(Y->dims(), ctx.GetPlace());
if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
// ddX_safe = (ddX == null) ? zeros_like(Out) : ddX
// ddY_safe = (ddY == null) ? zeros_like(Y) : ddY
Tensor ddX_safe, ddY_safe;
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Out, ddX, &ddX_safe);
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Y, ddY, &ddY_safe);
if (dOut) {
// dOut = - dX * ddY
default_elementwise_mul<DeviceContext, T>(ctx, dX, &ddY_safe, dOut);
auto& place =
*ctx.template device_context<DeviceContext>().eigen_device();
auto dout = framework::EigenVector<T>::Flatten(*dOut);
dout.device(place) = static_cast<T>(-1) * dout;
}
if (dY) {
// dX_div_Y = dX / Y;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
Tensor dX_div_Y =
ctx.AllocateTmpTensor<T, DeviceContext>(Out->dims(), dev_ctx);
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
ctx, dX, Y, axis, DivFunctor<T>(), &dX_div_Y);
// NOTE(dengkaipeng): in the following ElemwiseGradCompute, the first
// output tensor is nullptr, so the branch that computes the first
// output is never taken and DivGradDX is never called; the unused
// first branch has little effect on running speed.
// dY = Out * dX * ddY / Y - dX * ddX / Y
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivDoubleDY<T>>(
ctx, ddX_safe, ddY_safe, *Out, dX_div_Y, axis, nullptr, dY,
DivGradDX<T>(), DivDoubleDY<T>());
}
if (ddOut) {
// ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
default_elementwise_mul<DeviceContext, T>(ctx, Out, &ddY_safe, ddOut);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &ddX_safe, ddOut, 0, SubFunctor<T>(), ddOut);
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
ctx, ddOut, Y, axis, DivFunctor<T>(), ddOut);
}
}
};
} // namespace operators
} // namespace paddle
@@ -1644,7 +1644,8 @@ static inline void GetDoubleGradSafeTensor(
if (ddx) {
*ddx_safe = *ddx;
} else {
ddx_safe->mutable_data<T>(x->dims(), ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<DeviceContext>();
*ddx_safe = ctx.AllocateTmpTensor<T, DeviceContext>(x->dims(), dev_ctx);
math::SetConstant<DeviceContext, T> set_zero;
set_zero(ctx.template device_context<DeviceContext>(), ddx_safe,
static_cast<T>(0));
@@ -378,5 +378,61 @@ class TestMulDoubleGradCheck(unittest.TestCase):
self.func(p)
class TestElementwiseDivDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
# the shape of the input variable should be specified explicitly and must not include -1.
shape = [2, 3, 7, 9]
eps = 0.0001
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.data('y', shape, False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_div(x, y, axis=0)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr[np.abs(y_arr) < 0.005] = 0.02
gradient_checker.double_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps, atol=1e-3)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestElementwiseDivBroadcastDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
# the shape of the input variable should be specified explicitly and must not include -1.
shape = [2, 3, 7, 9]
eps = 0.0001
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.data('y', shape[1:-1], False, dtype)
x.persistable = True
y.persistable = True
out = layers.elementwise_div(x, y, axis=1)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
y_arr = np.random.uniform(-1, 1, shape[1:-1]).astype(dtype)
y_arr[np.abs(y_arr) < 0.005] = 0.02
gradient_checker.double_grad_check(
[x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps, atol=1e-3)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
if __name__ == "__main__":
unittest.main()
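Beyond the gradient_checker tests above, the new second-order kernel can also be exercised end to end by chaining two calls to fluid.gradients. The snippet below is an editorial sketch, not part of this patch; it assumes fluid.gradients is available in this Paddle build and that stop_gradient is cleared on the inputs.

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers

def elementwise_div_double_grad_demo(place=fluid.CPUPlace()):
    shape = [2, 3]
    x = layers.data('x', shape, False, 'float64')
    y = layers.data('y', shape, False, 'float64')
    x.persistable, y.persistable = True, True
    x.stop_gradient, y.stop_gradient = False, False
    out = layers.elementwise_div(x, y)
    # The first backward builds elementwise_div_grad (dx = dout / y); the
    # second backward differentiates it again and runs elementwise_div_grad_grad.
    dx, = fluid.gradients([out], [x])
    d2y, = fluid.gradients([dx], [y])
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    feed = {
        'x': np.random.uniform(1, 2, shape).astype('float64'),
        'y': np.random.uniform(1, 2, shape).astype('float64'),
    }
    return exe.run(feed=feed, fetch_list=[d2y])[0]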