未验证 提交 8d549fc8 编写于 作者: Q qingqing01 提交者: GitHub

Add clip double grad (#29590)

上级 81acc327
......@@ -109,6 +109,29 @@ DECLARE_INPLACE_OP_INFERER(ClipGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
template <typename T>
class ClipDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("clip_grad");
op->SetInput("X", this->Input("X"));
if (this->HasInput("Min")) {
op->SetInput("Min", this->Input("Min"));
}
if (this->HasInput("Max")) {
op->SetInput("Max", this->Input("Max"));
}
op->SetInput(framework::GradVarName("Out"),
this->OutputGrad(framework::GradVarName("X")));
op->SetOutput(framework::GradVarName("X"),
this->InputGrad(framework::GradVarName("Out")));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
......@@ -117,7 +140,9 @@ REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker<float>,
ops::ClipGradOpMaker<paddle::framework::OpDesc>,
ops::ClipGradOpMaker<paddle::imperative::OpBase>,
ops::ClipInplaceInferer);
REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer);
REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer,
ops::ClipDoubleGradOpMaker<paddle::framework::OpDesc>,
ops::ClipDoubleGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>,
ops::ClipKernel<paddle::platform::CPUDeviceContext, double>);
......
......@@ -329,5 +329,26 @@ class TestUnsqueezeDoubleGradCheck(unittest.TestCase):
self.func(p)
class TestClipDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
x_shape = [2, 4, 10]
dtype = np.float64
x = layers.data('x', x_shape, False, dtype)
x.persistable = True
out = paddle.clip(x, min=-1., max=1.)
x_arr = np.random.uniform(-5., 5., x_shape).astype(dtype)
gradient_checker.double_grad_check([x], out, x_init=x_arr, place=place)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册