Unverified commit 51c97d9f, authored by Weilong Wu, committed by GitHub

Support elementwise_add triple grad Kernel (#36508)

* Support elementwise_add triple grad Kernel

* Change code-format to follow CI std
Parent 999242e3
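Background for the kernel added below: for Out = X + Y, the double-grad output is itself an elementwise add, so the triple-grad pass is just the ordinary add gradient applied one level up. In the op's own input/output names:

    DDOut = DDX + DDY
    D_DDX = reduce(D_DDOut)   # sum over the axes DDX was broadcast along
    D_DDY = reduce(D_DDOut)   # sum over the axes DDY was broadcast along

Here reduce(.) is the identity when no broadcasting occurred, which is why the kernel can use a plain TensorCopy in the non-broadcast special cases.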
@@ -110,6 +110,25 @@ class ElementwiseAddDoubleGradMaker : public framework::SingleGradOpMaker<T> {
}
};
template <typename T>
class ElementwiseAddTripleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("elementwise_add_triple_grad");
    op->SetInput("DDX", this->Input("DDX"));
    op->SetInput("DDY", this->Input("DDY"));
    op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
    op->SetAttrMap(this->Attrs());
    op->SetOutput("D_DDX", this->InputGrad("DDX"));
    op->SetOutput("D_DDY", this->InputGrad("DDY"));
  }
};
} // namespace operators
} // namespace paddle
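Read schematically, the maker above emits a grad node wired as follows — a plain-Python illustration of the SetInput/SetOutput calls, not a Paddle API:

# Illustration only: the node ElementwiseAddTripleGradMaker describes.
triple_grad_node = {
    "type": "elementwise_add_triple_grad",
    "inputs": {
        "DDX": "double-grad input for X (reused as a forward input here)",
        "DDY": "double-grad input for Y (reused as a forward input here)",
        "D_DDOut": "incoming gradient of DDOut",
    },
    "outputs": {
        "D_DDX": "gradient propagated back to DDX",
        "D_DDY": "gradient propagated back to DDY",
    },
    "attrs": "copied from the forward op via SetAttrMap",
}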
@@ -123,10 +142,16 @@ REGISTER_OPERATOR(
ops::ElementwiseAddDoubleGradMaker<paddle::framework::OpDesc>,
ops::ElementwiseAddDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    elementwise_add_grad_grad, ops::ElementwiseOpDoubleGradWithoutDXDY,
    ops::ElementwiseDoubleGradOpInplaceInferer,
    ops::ElementwiseDoubleGradNoBufVarsInferer,
    ops::ElementwiseAddTripleGradMaker<paddle::framework::OpDesc>,
    ops::ElementwiseAddTripleGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(elementwise_add_triple_grad, ops::ElementwiseOpTripleGrad,
                  ops::ElementwiseTripleGradOpInplaceInferer,
                  ops::ElementwiseTripleGradNoBufVarsInferer);
REGISTER_OP_CPU_KERNEL(
elementwise_add,
@@ -162,6 +187,20 @@ REGISTER_OP_CPU_KERNEL(
paddle::platform::complex<float>>,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
elementwise_add_triple_grad,
ops::ElementwiseAddTripleGradKernel<paddle::platform::CPUDeviceContext,
float>,
ops::ElementwiseAddTripleGradKernel<paddle::platform::CPUDeviceContext,
double>,
ops::ElementwiseAddTripleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseAddTripleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>,
ops::ElementwiseAddTripleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::ElementwiseAddTripleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
// A specialization of the elementwise_add operator, used for gradient
// accumulation with in-place addto.
......
@@ -196,6 +196,17 @@ REGISTER_OP_CUDA_KERNEL(
plat::complex<float>>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
elementwise_add_triple_grad,
ops::ElementwiseAddTripleGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseAddTripleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddTripleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddTripleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddTripleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddTripleGradKernel<plat::CUDADeviceContext,
plat::complex<float>>,
ops::ElementwiseAddTripleGradKernel<plat::CUDADeviceContext,
plat::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
grad_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
......
@@ -205,5 +205,44 @@ class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class ElementwiseAddTripleGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using Tensor = framework::Tensor;

    auto *ddx = ctx.Input<Tensor>("DDX");
    auto *ddy = ctx.Input<Tensor>("DDY");
    auto *d_ddout = ctx.Input<Tensor>("D_DDOut");
    auto *d_ddx = ctx.Output<Tensor>("D_DDX");
    auto *d_ddy = ctx.Output<Tensor>("D_DDY");
    // The forward output Out is not used by the add gradient; pass d_ddout
    // in its place to the shared grad helpers below.
    auto *out = d_ddout;

    // Special case when d_ddy is not needed and d_ddx doesn't reduce
    if (d_ddx != nullptr && d_ddy == nullptr &&
        d_ddx->dims() == d_ddout->dims()) {
      VLOG(4) << "Special case when d_ddy is not needed and d_ddx doesn't "
                 "reduce";
      framework::TensorCopy(
          *d_ddout, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), d_ddx);
    } else if (d_ddx == nullptr && d_ddy != nullptr &&
               d_ddy->dims() == d_ddout->dims()) {
      VLOG(4) << "Special case when d_ddx is not needed and d_ddy doesn't "
                 "reduce";
      framework::TensorCopy(
          *d_ddout, ctx.GetPlace(),
          ctx.template device_context<platform::DeviceContext>(), d_ddy);
    } else if (d_ddx != nullptr && d_ddy != nullptr &&
               (d_ddx->dims() == d_ddy->dims())) {
      elementwise_add_grad<DeviceContext, T>(ctx, ddx, ddy, out, d_ddout,
                                             d_ddx, d_ddy);
    } else {
      default_elementwise_add_grad<DeviceContext, T>(ctx, ddx, ddy, out,
                                                     d_ddout, d_ddx, d_ddy);
    }
  }
};
} // namespace operators
} // namespace paddle
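For readers who want the dispatch logic above in runnable form, here is a minimal NumPy sketch. It assumes NumPy's trailing-axis broadcasting, and reduce_to_shape is a hypothetical stand-in for the broadcast reduction that default_elementwise_add_grad performs:

import numpy as np

def reduce_to_shape(grad, shape):
    # Hypothetical helper: sum grad over its broadcast axes until it
    # matches `shape` (NumPy trailing-axis broadcasting assumed).
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)  # axes prepended by broadcasting
    for axis, size in enumerate(shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)  # size-1 axes
    return grad

def elementwise_add_triple_grad(d_ddout, ddx_shape=None, ddy_shape=None):
    # Mirrors the C++ dispatch: an output whose shape equals d_ddout's is a
    # plain copy (the TensorCopy branches); otherwise it is reduced over its
    # broadcast axes (the elementwise_add_grad / default_... branches).
    d_ddx = d_ddy = None
    if ddx_shape is not None:
        ddx_shape = tuple(ddx_shape)
        d_ddx = (d_ddout.copy() if ddx_shape == d_ddout.shape
                 else reduce_to_shape(d_ddout, ddx_shape))
    if ddy_shape is not None:
        ddy_shape = tuple(ddy_shape)
        d_ddy = (d_ddout.copy() if ddy_shape == d_ddout.shape
                 else reduce_to_shape(d_ddout, ddy_shape))
    return d_ddx, d_ddy

# Example: Y was broadcast from shape (4,) up to X's shape (2, 3, 4).
d_ddx, d_ddy = elementwise_add_triple_grad(np.ones((2, 3, 4)),
                                           ddx_shape=(2, 3, 4),
                                           ddy_shape=(4,))
assert d_ddx.shape == (2, 3, 4) and d_ddy.shape == (4,)
assert (d_ddy == 6).all()  # summed over the 2 * 3 broadcast positions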
@@ -426,6 +426,62 @@ class ElementwiseOpDoubleGradWithoutDXDY
}
};
class ElementwiseOpTripleGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  using Tensor = framework::Tensor;

  void InferShape(framework::InferShapeContext *ctx) const override {
    if (ctx->HasOutput("D_DDX")) {
      ctx->ShareDim("DDX", "D_DDX");
      ctx->ShareLoD("DDX", "D_DDX");
    }
    if (ctx->HasOutput("D_DDY")) {
      ctx->ShareDim("DDY", "D_DDY");
      ctx->ShareLoD("DDY", "D_DDY");
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::proto::VarType::Type input_data_type;
    if (ctx.HasInput("DDX") == false) {
      OP_INOUT_CHECK(ctx.HasInput("DDY"), "Input", "DDY",
                     "ElementwiseOpTripleGrad");
      input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDY");
    } else if (ctx.HasInput("DDY") == false) {
      OP_INOUT_CHECK(ctx.HasInput("DDX"), "Input", "DDX",
                     "ElementwiseOpTripleGrad");
      input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
    } else {
      input_data_type =
          OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "DDX", "DDY");
    }

#ifdef PADDLE_WITH_MKLDNN
    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                     framework::DataLayout::kMKLDNN,
                                     framework::LibraryType::kMKLDNN);
    }
#endif
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const {
    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
      // only promote the input's type when the op has a complex input
      return framework::OpKernelType(tensor.type(), tensor.place(),
                                     tensor.layout());
    } else {
      return framework::OpKernelType(expected_kernel_type.data_type_,
                                     tensor.place(), tensor.layout());
    }
  }
};
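The three-way branch in GetExpectedKernelType picks the kernel dtype from whichever of DDX/DDY is present, and promotes when both are. A hedged Python sketch of that selection, assuming IndicateOrPromoteVarDataTypes promotes like NumPy's result_type (e.g. float32 + complex64 -> complex64):

import numpy as np

def expected_kernel_dtype(ddx_dtype=None, ddy_dtype=None):
    # Mirrors the three-way branch: fall back to whichever input exists,
    # otherwise promote the pair (assumption: NumPy-like promotion).
    if ddx_dtype is None:
        assert ddy_dtype is not None, "at least one of DDX/DDY must exist"
        return np.dtype(ddy_dtype)
    if ddy_dtype is None:
        return np.dtype(ddx_dtype)
    return np.result_type(ddx_dtype, ddy_dtype)

assert expected_kernel_dtype(np.float32, np.complex64) == np.complex64
assert expected_kernel_dtype(ddy_dtype=np.float64) == np.float64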
template <typename T>
class ElemwiseGradKernel : public framework::OpKernel<T> {
public:
@@ -447,9 +503,14 @@ DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplaceInferer,
DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplaceInferer,
{"DDX", "DDOut"});
DECLARE_INPLACE_OP_INFERER(ElementwiseTripleGradOpInplaceInferer,
{"D_DDOut", "D_DDX"});
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseGradNoBufVarsInferer, "X", "Y");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseDoubleGradNoBufVarsInferer, "Y",
"DOut");
DECLARE_NO_NEED_BUFFER_VARS_INFERER(ElementwiseTripleGradNoBufVarsInferer,
"DDX", "DDY");
} // namespace operators
} // namespace paddle
......
@@ -486,20 +486,26 @@ def triple_grad_check(x,
        var_to_np_array_in_scope(scope, place, v.name)
        for v in x_grads_grads
    ]

    x += y_grads
    x_init = _as_list(x_init)
    x_init += y_grads_init

    # append second order grads
    target_grads_grads = fluid.gradients(target_grads, x, x_grads_grads)

    # filter out the None entries in target_grads_grads, since Dy/Dx may be
    # None in the kernel
    filted = [(i, dyi) for i, dyi in enumerate(target_grads_grads)
              if dyi is not None]
    filted_idx, filted_target_grads_grads = zip(*filted)

    x += x_grads_grads
    x_init += x_grads_grads_init

    # x <=> [x, dout, ddx]
    grad_check(
        x=x,
        y=filted_target_grads_grads,
        x_init=x_init,
        place=place,
        program=program,
......
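Since zip(*filted) in the hunk above is doing double duty as an "unzip", a tiny standalone illustration of the pattern:

# Stand-ins for gradient tensors; Dy/Dx may be None in the kernel.
grads = ["dx", None, "ddx"]
filted = [(i, g) for i, g in enumerate(grads) if g is not None]
filted_idx, filted_grads = zip(*filted)  # "unzip" back into two tuples
print(filted_idx, filted_grads)          # (0, 2) ('dx', 'ddx')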
@@ -243,5 +243,59 @@ class TestElementwiseDivBroadcastDoubleGradCheck(unittest.TestCase):
            self.func(p)


class TestElementwiseAddTripleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        # the shape of the input variable must be fully specified and must
        # not include -1.
        shape = [2, 3, 4, 5]
        eps = 0.005
        dtype = np.float64

        x = layers.data('x', shape, False, dtype)
        y = layers.data('y', shape, False, dtype)
        x.persistable = True
        y.persistable = True
        out = layers.elementwise_add(x, y)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
        y_arr = np.random.uniform(-1, 1, shape).astype(dtype)

        gradient_checker.triple_grad_check(
            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)


class TestElementwiseAddBroadcastTripleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        # the shape of the input variable must be fully specified and must
        # not include -1.
        shape = [2, 3, 4, 5]
        eps = 0.005
        dtype = np.float64

        x = layers.data('x', shape, False, dtype)
        y = layers.data('y', shape[:-1], False, dtype)
        x.persistable = True
        y.persistable = True
        out = layers.elementwise_add(x, y, axis=0)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
        y_arr = np.random.uniform(-1, 1, shape[:-1]).astype(dtype)

        gradient_checker.triple_grad_check(
            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)


if __name__ == "__main__":
    unittest.main()