未验证 提交 5bd7c82b 编写于 作者: Q qingqing01 提交者: GitHub

[Cherry-pick] Double grad for clip op #31109

Cherry-pick double grad for clip
上级 8177ece5
...@@ -109,6 +109,29 @@ DECLARE_INPLACE_OP_INFERER(ClipGradInplaceInferer, ...@@ -109,6 +109,29 @@ DECLARE_INPLACE_OP_INFERER(ClipGradInplaceInferer,
{framework::GradVarName("Out"), {framework::GradVarName("Out"),
framework::GradVarName("X")}); framework::GradVarName("X")});
// Grad-op maker that registers the second-order gradient of clip.
// It emits another clip_grad op: the incoming grad slot
// (GradVarName("Out")) is wired to ddX (OutputGrad of the first-order
// op's X-grad), and the produced grad (GradVarName("X")) is wired to
// ddOut (InputGrad of the first-order op's Out-grad). NOTE(review):
// this reuses the first-order clip_grad kernel for the double grad,
// which relies on clip_grad being linear in its grad input — standard
// Paddle double-grad wiring; confirm against the clip_grad kernel.
template <typename T>
class ClipDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
 protected:
  // Fills in the description of the generated double-grad op.
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("clip_grad");
    // The clip mask depends on the original forward input X.
    op->SetInput("X", this->Input("X"));
    // Min/Max are optional tensor-valued bounds; forward only if fed.
    if (this->HasInput("Min")) {
      op->SetInput("Min", this->Input("Min"));
    }
    if (this->HasInput("Max")) {
      op->SetInput("Max", this->Input("Max"));
    }
    // Upstream grad for this op is ddX: the grad of the first-order
    // op's output (its X-grad).
    op->SetInput(framework::GradVarName("Out"),
                 this->OutputGrad(framework::GradVarName("X")));
    // Result is ddOut: the grad w.r.t. the first-order op's input
    // grad (its Out-grad slot).
    op->SetOutput(framework::GradVarName("X"),
                  this->InputGrad(framework::GradVarName("Out")));
    // Reuse the forward op's attributes (scalar min/max, etc.).
    op->SetAttrMap(this->Attrs());
  }
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -117,7 +140,9 @@ REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker<float>, ...@@ -117,7 +140,9 @@ REGISTER_OPERATOR(clip, ops::ClipOp, ops::ClipOpMaker<float>,
ops::ClipGradOpMaker<paddle::framework::OpDesc>, ops::ClipGradOpMaker<paddle::framework::OpDesc>,
ops::ClipGradOpMaker<paddle::imperative::OpBase>, ops::ClipGradOpMaker<paddle::imperative::OpBase>,
ops::ClipInplaceInferer); ops::ClipInplaceInferer);
REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer); REGISTER_OPERATOR(clip_grad, ops::ClipOpGrad, ops::ClipGradInplaceInferer,
ops::ClipDoubleGradOpMaker<paddle::framework::OpDesc>,
ops::ClipDoubleGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>, clip, ops::ClipKernel<paddle::platform::CPUDeviceContext, float>,
ops::ClipKernel<paddle::platform::CPUDeviceContext, double>); ops::ClipKernel<paddle::platform::CPUDeviceContext, double>);
......
...@@ -329,6 +329,27 @@ class TestUnsqueezeDoubleGradCheck(unittest.TestCase): ...@@ -329,6 +329,27 @@ class TestUnsqueezeDoubleGradCheck(unittest.TestCase):
self.func(p) self.func(p)
class TestClipDoubleGradCheck(unittest.TestCase):
    """Numerical second-order gradient check for paddle.clip."""

    @prog_scope()
    def func(self, place):
        # Build a small clip program and hand it to the generic
        # double-grad checker on the given place.
        shape = [2, 4, 10]
        dtype = np.float64
        x = layers.data('x', shape, False, dtype)
        x.persistable = True
        out = paddle.clip(x, min=-1., max=1.)
        # Sample well outside [-1, 1] so both the clipped and the
        # pass-through regions of the op are exercised.
        init_value = np.random.uniform(-5., 5., shape).astype(dtype)
        gradient_checker.double_grad_check(
            [x], out, x_init=init_value, place=place)

    def test_grad(self):
        # Always check on CPU; add GPU when the build supports it.
        has_cuda = core.is_compiled_with_cuda()
        places = [fluid.CPUPlace()] + (
            [fluid.CUDAPlace(0)] if has_cuda else [])
        for place in places:
            self.func(place)
class TestTransposeDoubleGradCheck(unittest.TestCase): class TestTransposeDoubleGradCheck(unittest.TestCase):
@prog_scope() @prog_scope()
def func(self, place): def func(self, place):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册