Unverified · Commit bd9bef5a authored by Kaipeng Deng, committed by GitHub

add elementwise_add_grad_grad op (#17366)

* add elementwise_add_grad_grad op. test=develop

* use defined GradMaker. test=develop
Parent 1c6d0646
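For reference, the second-order gradient implemented by this commit follows directly from the linearity of addition; this summary is inferred from the kernel below (its "ddOut = ddx + ddy" comment), not stated in the commit message:

$$Out = X + Y \quad\Longrightarrow\quad ddOut = \frac{\partial Out}{\partial X}\,ddX + \frac{\partial Out}{\partial Y}\,ddY = ddX + ddY$$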
@@ -13,10 +13,48 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"

namespace ops = paddle::operators;

namespace paddle {
namespace operators {
class ElementwiseAddDoubleGradDescMaker
    : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDesc> Apply() const override {
    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
    op->SetType("elementwise_add_grad_grad");
    op->SetInput("Y", Input("Y"));
    op->SetInput("DOut", Input(framework::GradVarName("Out")));
    op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
    op->SetInput("DDY", OutputGrad(framework::GradVarName("Y")));

    op->SetAttrMap(Attrs());

    op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
    return op;
  }
};
} // namespace operators
} // namespace paddle
REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add);
REGISTER_ELEMWISE_EXPLICIT_OP(elementwise_add, "Add", "Out = X + Y");
REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_add, "Add",
                                           "Out = X + Y");

namespace ops = paddle::operators;
REGISTER_OPERATOR(elementwise_add_grad, ops::ElementwiseOpExplicitGrad,
                  ops::ElementwiseGradOpInplace,
                  ops::ElementwiseGradNoBufVarsInference,
                  ops::ElementwiseAddDoubleGradDescMaker);
REGISTER_OPERATOR(elementwise_add_grad_grad,
                  ops::ElementwiseOpDoubleGradWithoutDXDY);
REGISTER_OP_CPU_KERNEL(
    elementwise_add,
@@ -30,3 +68,13 @@ REGISTER_OP_CPU_KERNEL(
    ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
    ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);

REGISTER_OP_CPU_KERNEL(
    elementwise_add_grad_grad,
    ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                        float>,
    ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                        double>,
    ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                        int>,
    ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                        int64_t>);
@@ -31,3 +31,9 @@ REGISTER_OP_CUDA_KERNEL(
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>,
    ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>);

REGISTER_OP_CUDA_KERNEL(
    elementwise_add_grad_grad,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, float>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, double>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int>,
    ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int64_t>);
@@ -161,5 +161,31 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
  }
};

template <typename DeviceContext, typename T>
class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using Tensor = framework::Tensor;

    auto *y = ctx.Input<Tensor>("Y");
    auto *dout = ctx.Input<Tensor>("DOut");
    auto *ddx = ctx.Input<Tensor>("DDX");
    auto *ddy = ctx.Input<Tensor>("DDY");

    auto *ddout = ctx.Output<Tensor>("DDOut");

    // ddOut = ddx + ddy
    if (ddout) {
      Tensor ddx_safe, ddy_safe;
      GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dout, ddx, &ddx_safe);
      GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);

      ddout->mutable_data<T>(ctx.GetPlace());
      default_elementwise_add<DeviceContext, T>(ctx, &ddx_safe, &ddy_safe,
                                                ddout);
    }
  }
};

}  // namespace operators
}  // namespace paddle
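A minimal NumPy sketch of what the kernel above computes, assuming GetDoubleGradSafeTensor substitutes a zero tensor shaped like its reference argument when the incoming grad is absent (the Python function names and shapes below are illustrative, not part of the Paddle API):

import numpy as np

def double_grad_safe(ref, dd):
    # Assumed behavior of GetDoubleGradSafeTensor: fall back to a zero
    # tensor shaped like the reference when the incoming grad is None.
    return np.zeros_like(ref) if dd is None else dd

def elementwise_add_double_grad(y, dout, ddx, ddy):
    # Mirrors ElementwiseAddDoubleGradKernel: ddOut = ddX + ddY,
    # with missing inputs treated as zeros.
    ddx_safe = double_grad_safe(dout, ddx)
    ddy_safe = double_grad_safe(y, ddy)
    return ddx_safe + ddy_safe

# Example: only ddX is provided, so ddOut equals ddX.
y = np.ones((7, 9))
dout = np.ones((7, 9))
ddx = np.full((7, 9), 0.5)
assert np.allclose(elementwise_add_double_grad(y, dout, ddx, None), 0.5)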
@@ -236,7 +236,35 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type = ctx.Input<Tensor>("DDX")->type();
    auto input_data_type = ctx.Input<Tensor>("DOut")->type();

#ifdef PADDLE_WITH_MKLDNN
    if (platform::CanMKLDNNBeUsed(ctx)) {
      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

class ElementwiseOpDoubleGradWithoutDXDY
    : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  using Tensor = framework::Tensor;

  void InferShape(framework::InferShapeContext *ctx) const override {
    if (ctx->HasOutput("DDOut")) {
      ctx->ShareDim("DOut", "DDOut");
      ctx->ShareLoD("DOut", "DDOut");
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type = ctx.Input<Tensor>("DOut")->type();

#ifdef PADDLE_WITH_MKLDNN
    if (platform::CanMKLDNNBeUsed(ctx)) {
@@ -359,3 +387,16 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseGradNoBufVarsInference, "Y");
                    ::paddle::operators::ElementwiseOpExplicitGrad,            \
                    ::paddle::operators::ElementwiseGradOpInplace,             \
                    ::paddle::operators::ElementwiseGradNoBufVarsInference)

#define REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(op_type, op_name, equation) \
  class __ElemwiseOp##op_type##Maker__                                         \
      : public ::paddle::operators::ElementwiseOpMaker {                       \
   protected:                                                                  \
    virtual std::string GetName() const { return op_name; }                    \
    virtual std::string GetEquation() const { return equation; }               \
  };                                                                           \
  REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp,               \
                    __ElemwiseOp##op_type##Maker__,                            \
                    ::paddle::operators::ElementwiseOpInferVarType,            \
                    op_type##GradMaker,                                        \
                    ::paddle::operators::ElementwiseOpInplace);
@@ -196,17 +196,23 @@ def _compute_analytical_jacobian(program, x, y, place, scope):
    x = _as_list(x)
    jacobian = make_jacobian(x, y_size, np_type)

    # filter out the None entries in dx, since DX/DY may be None in the kernel;
    # only the non-None dx are fetched in exe.run
    filted = [(i, dxi) for i, dxi in enumerate(dx) if dxi is not None]
    filted_idx, filted_dx = zip(*filted)

    for i in six.moves.xrange(y_size):
        _set_item(dy_t, i, 1, np_type)

        dx_res = exe.run(program, scope=scope, fetch_list=dx)
        dx_res = exe.run(program, scope=scope, fetch_list=filted_dx)

        for j in six.moves.xrange(len(x)):
        for j in six.moves.xrange(len(filted_dx)):
            dx_idx = filted_idx[j]
            if dx_res[j] is not None:
                jacobian[j][:, i] = dx_res[j].flatten()
                jacobian[dx_idx][:, i] = dx_res[j].flatten()
            else:
                jacobian[j][:, i] = np.zeros(
                    dx[j].shape, dtype=np_type).flatten()
                jacobian[dx_idx][:, i] = np.zeros(
                    dx[dx_idx].shape, dtype=np_type).flatten()

        _set_item(dy_t, i, 0, np_type)
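The change above can be illustrated with a self-contained NumPy sketch of the filter-and-remap pattern: grads that are None are skipped when fetching, and the fetched results are written back into their original Jacobian slots, leaving zero rows for the missing grads. All names and shapes here are illustrative stand-ins, and the executor call is replaced by a plain list:

import numpy as np

x_size, y_size = 4, 3
dx = [np.ones(x_size), None, 2 * np.ones(x_size)]  # dx[1] is None in the kernel
jacobian = [np.zeros((x_size, y_size)) for _ in dx]

filted = [(idx, dxi) for idx, dxi in enumerate(dx) if dxi is not None]
filted_idx, filted_dx = zip(*filted)

for i in range(y_size):
    dx_res = list(filted_dx)  # stands in for exe.run(..., fetch_list=filted_dx)
    for j, dx_idx in enumerate(filted_idx):
        # write each fetched grad back to its original index
        jacobian[dx_idx][:, i] = dx_res[j].flatten()
# jacobian[1] stays all zeros because dx[1] was None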
@@ -193,6 +193,60 @@ class TestElementwiseMulBroadcastDoubleGradCheck(unittest.TestCase):
            self.func(p)


class TestElementwiseAddDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        # the shape of the input variable should be clearly specified and not include -1.
        shape = [7, 9]
        eps = 0.005
        dtype = np.float64

        x = layers.data('x', shape, False, dtype)
        y = layers.data('y', shape, False, dtype)
        x.persistable = True
        y.persistable = True
        out = layers.elementwise_add(x, y)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
        y_arr = np.random.uniform(-1, 1, shape).astype(dtype)

        gradient_checker.double_grad_check(
            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)


class TestElementwiseAddBroadcastDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        # the shape of the input variable should be clearly specified and not include -1.
        shape = [7, 9]
        eps = 0.005
        dtype = np.float64

        x = layers.data('x', shape, False, dtype)
        y = layers.data('y', shape[:-1], False, dtype)
        x.persistable = True
        y.persistable = True
        out = layers.elementwise_add(x, y, axis=0)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
        y_arr = np.random.uniform(-1, 1, shape[:-1]).astype(dtype)

        gradient_checker.double_grad_check(
            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)


class TestMulDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):