未验证 提交 bd40dd9a 编写于 作者: Z zhangbo9674 提交者: GitHub

[Cherry Pick]Add fp16 kernel for clip_op (#36577) (#36672)

Add fp16 kernel for clip_op.
上级 304fb2b5
...@@ -19,10 +19,14 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -19,10 +19,14 @@ REGISTER_OP_CUDA_KERNEL(
clip, ops::ClipKernel<paddle::platform::CUDADeviceContext, float>, clip, ops::ClipKernel<paddle::platform::CUDADeviceContext, float>,
ops::ClipKernel<paddle::platform::CUDADeviceContext, double>, ops::ClipKernel<paddle::platform::CUDADeviceContext, double>,
ops::ClipKernel<paddle::platform::CUDADeviceContext, int>, ops::ClipKernel<paddle::platform::CUDADeviceContext, int>,
ops::ClipKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::ClipKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ClipKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
clip_grad, ops::ClipGradKernel<paddle::platform::CUDADeviceContext, float>, clip_grad, ops::ClipGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ClipGradKernel<paddle::platform::CUDADeviceContext, double>, ops::ClipGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int>, ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ClipGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
...@@ -54,7 +54,7 @@ class ClipGradFunctor { ...@@ -54,7 +54,7 @@ class ClipGradFunctor {
public: public:
explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {} explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T& x, const T& y) const { HOSTDEVICE T operator()(const T& x, const T& y) const {
return (y > min_ && y < max_) ? x : 0; return (y > min_ && y < max_) ? x : static_cast<T>(0);
} }
private: private:
...@@ -79,7 +79,7 @@ class ClipKernel : public framework::OpKernel<T> { ...@@ -79,7 +79,7 @@ class ClipKernel : public framework::OpKernel<T> {
} }
max = static_cast<T>(max); max = static_cast<T>(max);
auto min = context.Attr<float>("min"); auto min = static_cast<T>(context.Attr<float>("min"));
Tensor min_cpu; Tensor min_cpu;
if (context.HasInput("Min")) { if (context.HasInput("Min")) {
auto* min_t = context.Input<Tensor>("Min"); auto* min_t = context.Input<Tensor>("Min");
...@@ -156,7 +156,7 @@ class ClipGradKernel : public framework::OpKernel<T> { ...@@ -156,7 +156,7 @@ class ClipGradKernel : public framework::OpKernel<T> {
} }
max = static_cast<T>(max); max = static_cast<T>(max);
auto min = context.Attr<float>("min"); auto min = static_cast<T>(context.Attr<float>("min"));
Tensor min_cpu; Tensor min_cpu;
if (context.HasInput("Min")) { if (context.HasInput("Min")) {
auto* min_t = context.Input<Tensor>("Min"); auto* min_t = context.Input<Tensor>("Min");
......
...@@ -43,7 +43,7 @@ class TestClipOp(OpTest): ...@@ -43,7 +43,7 @@ class TestClipOp(OpTest):
else: else:
max_v = self.attrs['max'] max_v = self.attrs['max']
input = np.random.random(self.shape).astype("float32") input = np.random.random(self.shape).astype(self.dtype)
input[np.abs(input - min_v) < self.max_relative_error] = 0.5 input[np.abs(input - min_v) < self.max_relative_error] = 0.5
input[np.abs(input - max_v) < self.max_relative_error] = 0.5 input[np.abs(input - max_v) < self.max_relative_error] = 0.5
self.inputs['X'] = input self.inputs['X'] = input
...@@ -60,15 +60,17 @@ class TestClipOp(OpTest): ...@@ -60,15 +60,17 @@ class TestClipOp(OpTest):
paddle.disable_static() paddle.disable_static()
def initTestCase(self): def initTestCase(self):
self.dtype = np.float32
self.shape = (4, 10, 10) self.shape = (4, 10, 10)
self.max = 0.8 self.max = 0.8
self.min = 0.3 self.min = 0.3
self.inputs['Max'] = np.array([0.8]).astype('float32') self.inputs['Max'] = np.array([0.8]).astype(self.dtype)
self.inputs['Min'] = np.array([0.1]).astype('float32') self.inputs['Min'] = np.array([0.1]).astype(self.dtype)
class TestCase1(TestClipOp): class TestCase1(TestClipOp):
def initTestCase(self): def initTestCase(self):
self.dtype = np.float32
self.shape = (8, 16, 8) self.shape = (8, 16, 8)
self.max = 0.7 self.max = 0.7
self.min = 0.0 self.min = 0.0
...@@ -76,6 +78,7 @@ class TestCase1(TestClipOp): ...@@ -76,6 +78,7 @@ class TestCase1(TestClipOp):
class TestCase2(TestClipOp): class TestCase2(TestClipOp):
def initTestCase(self): def initTestCase(self):
self.dtype = np.float32
self.shape = (8, 16) self.shape = (8, 16)
self.max = 1.0 self.max = 1.0
self.min = 0.0 self.min = 0.0
...@@ -83,6 +86,7 @@ class TestCase2(TestClipOp): ...@@ -83,6 +86,7 @@ class TestCase2(TestClipOp):
class TestCase3(TestClipOp): class TestCase3(TestClipOp):
def initTestCase(self): def initTestCase(self):
self.dtype = np.float32
self.shape = (4, 8, 16) self.shape = (4, 8, 16)
self.max = 0.7 self.max = 0.7
self.min = 0.2 self.min = 0.2
...@@ -90,20 +94,32 @@ class TestCase3(TestClipOp): ...@@ -90,20 +94,32 @@ class TestCase3(TestClipOp):
class TestCase4(TestClipOp): class TestCase4(TestClipOp):
def initTestCase(self): def initTestCase(self):
self.dtype = np.float32
self.shape = (4, 8, 8) self.shape = (4, 8, 8)
self.max = 0.7 self.max = 0.7
self.min = 0.2 self.min = 0.2
self.inputs['Max'] = np.array([0.8]).astype('float32') self.inputs['Max'] = np.array([0.8]).astype(self.dtype)
self.inputs['Min'] = np.array([0.3]).astype('float32') self.inputs['Min'] = np.array([0.3]).astype(self.dtype)
class TestCase5(TestClipOp): class TestCase5(TestClipOp):
def initTestCase(self): def initTestCase(self):
self.dtype = np.float32
self.shape = (4, 8, 16) self.shape = (4, 8, 16)
self.max = 0.5 self.max = 0.5
self.min = 0.5 self.min = 0.5
class TestCase6(TestClipOp):
    """Clip test case configured for float16 data with 'Max'/'Min' inputs."""

    def initTestCase(self):
        # Assign (not compare) the dtype: `self.dtype == np.float16` is a bare
        # comparison expression that configures nothing, so the case would
        # never actually select float16.
        self.dtype = np.float16
        self.shape = (4, 8, 8)
        self.max = 0.7
        self.min = 0.2
        # The 'Max'/'Min' inputs are built in the same dtype as the case.
        self.inputs['Max'] = np.array([0.8]).astype(self.dtype)
        self.inputs['Min'] = np.array([0.3]).astype(self.dtype)
class TestClipOpError(unittest.TestCase): class TestClipOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
paddle.enable_static() paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册