diff --git a/paddle/fluid/operators/clip_op.cu b/paddle/fluid/operators/clip_op.cu
index fd61e4ea61d4ff20656dea842b02958c8c2701b9..846354fcb81c5f07580533c69a598df62e50ddaf 100644
--- a/paddle/fluid/operators/clip_op.cu
+++ b/paddle/fluid/operators/clip_op.cu
@@ -19,10 +19,14 @@
 REGISTER_OP_CUDA_KERNEL(
     clip, ops::ClipKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ClipKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ClipKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::ClipKernel<paddle::platform::CUDADeviceContext, int64_t>);
+    ops::ClipKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::ClipKernel<paddle::platform::CUDADeviceContext,
+                    paddle::platform::float16>);
 REGISTER_OP_CUDA_KERNEL(
     clip_grad, ops::ClipGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ClipGradKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
+    ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::ClipGradKernel<paddle::platform::CUDADeviceContext,
+                        paddle::platform::float16>);
diff --git a/paddle/fluid/operators/clip_op.h b/paddle/fluid/operators/clip_op.h
index 93157ed9d47bbcf3ddffec650f8a4b97a0e2af3f..abf721936b41e3c8403d47a0b6f961fc8b2b5bd0 100644
--- a/paddle/fluid/operators/clip_op.h
+++ b/paddle/fluid/operators/clip_op.h
@@ -54,7 +54,7 @@ class ClipGradFunctor {
  public:
   explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
   HOSTDEVICE T operator()(const T& x, const T& y) const {
-    return (y > min_ && y < max_) ? x : 0;
+    return (y > min_ && y < max_) ? x : static_cast<T>(0);
   }
 
  private:
@@ -79,7 +79,7 @@ class ClipKernel : public framework::OpKernel<T> {
     }
     max = static_cast<T>(max);
 
-    auto min = context.Attr<float>("min");
+    auto min = static_cast<T>(context.Attr<float>("min"));
     Tensor min_cpu;
     if (context.HasInput("Min")) {
       auto* min_t = context.Input<Tensor>("Min");
@@ -156,7 +156,7 @@ class ClipGradKernel : public framework::OpKernel<T> {
     }
     max = static_cast<T>(max);
 
-    auto min = context.Attr<float>("min");
+    auto min = static_cast<T>(context.Attr<float>("min"));
     Tensor min_cpu;
     if (context.HasInput("Min")) {
       auto* min_t = context.Input<Tensor>("Min");
diff --git a/python/paddle/fluid/tests/unittests/test_clip_op.py b/python/paddle/fluid/tests/unittests/test_clip_op.py
index 1833c473d18a967b715bea351ab6b24a23f4bd04..74c5f693a37f1f0a480a14465e87be97cabb8f9f 100644
--- a/python/paddle/fluid/tests/unittests/test_clip_op.py
+++ b/python/paddle/fluid/tests/unittests/test_clip_op.py
@@ -43,7 +43,7 @@ class TestClipOp(OpTest):
         else:
             max_v = self.attrs['max']
 
-        input = np.random.random(self.shape).astype("float32")
+        input = np.random.random(self.shape).astype(self.dtype)
         input[np.abs(input - min_v) < self.max_relative_error] = 0.5
         input[np.abs(input - max_v) < self.max_relative_error] = 0.5
         self.inputs['X'] = input
@@ -60,15 +60,17 @@
         paddle.disable_static()
 
     def initTestCase(self):
+        self.dtype = np.float32
         self.shape = (4, 10, 10)
         self.max = 0.8
         self.min = 0.3
-        self.inputs['Max'] = np.array([0.8]).astype('float32')
-        self.inputs['Min'] = np.array([0.1]).astype('float32')
+        self.inputs['Max'] = np.array([0.8]).astype(self.dtype)
+        self.inputs['Min'] = np.array([0.1]).astype(self.dtype)
 
 
 class TestCase1(TestClipOp):
     def initTestCase(self):
+        self.dtype = np.float32
         self.shape = (8, 16, 8)
         self.max = 0.7
         self.min = 0.0
@@ -76,6 +78,7 @@ class TestCase1(TestClipOp):
 
 class TestCase2(TestClipOp):
     def initTestCase(self):
+        self.dtype = np.float32
         self.shape = (8, 16)
         self.max = 1.0
         self.min = 0.0
@@ -83,6 +86,7 @@ class TestCase2(TestClipOp):
 
 class TestCase3(TestClipOp):
     def initTestCase(self):
+        self.dtype = np.float32
         self.shape = (4, 8, 16)
         self.max = 0.7
         self.min = 0.2
@@ -90,20 +94,32 @@ class TestCase3(TestClipOp):
 
 class TestCase4(TestClipOp):
     def initTestCase(self):
+        self.dtype = np.float32
         self.shape = (4, 8, 8)
         self.max = 0.7
         self.min = 0.2
-        self.inputs['Max'] = np.array([0.8]).astype('float32')
-        self.inputs['Min'] = np.array([0.3]).astype('float32')
+        self.inputs['Max'] = np.array([0.8]).astype(self.dtype)
+        self.inputs['Min'] = np.array([0.3]).astype(self.dtype)
 
 
 class TestCase5(TestClipOp):
     def initTestCase(self):
+        self.dtype = np.float32
         self.shape = (4, 8, 16)
         self.max = 0.5
         self.min = 0.5
 
 
+class TestCase6(TestClipOp):
+    def initTestCase(self):
+        self.dtype = np.float16
+        self.shape = (4, 8, 8)
+        self.max = 0.7
+        self.min = 0.2
+        self.inputs['Max'] = np.array([0.8]).astype(self.dtype)
+        self.inputs['Min'] = np.array([0.3]).astype(self.dtype)
+
+
 class TestClipOpError(unittest.TestCase):
     def test_errors(self):
         paddle.enable_static()
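
Note on the clip_op.h change: `? x : 0` fails to compile once the functor is instantiated with paddle::platform::float16, because that type's integer constructor is explicit and the two branches of the ternary then have no common type; `static_cast<T>(0)` makes the zero branch match T for every registered dtype.

A minimal sketch of exercising the new float16 path end to end (illustrative only, not part of the patch; it assumes a CUDA build of Paddle, since the float16 kernels above are registered via REGISTER_OP_CUDA_KERNEL and are not available on CPU-only builds):

import numpy as np
import paddle

# Assumption: a GPU device is available; the float16 clip kernel added in
# clip_op.cu is CUDA-only, so this sketch will not run on a CPU-only build.
paddle.set_device("gpu")

x = paddle.to_tensor(np.random.random((4, 8, 8)).astype("float16"))
out = paddle.clip(x, min=0.2, max=0.7)

print(out.dtype)                            # paddle.float16
print(float(out.min()), float(out.max()))   # values clamped to [0.2, 0.7]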