未验证 提交 1962d3af 编写于 作者: Z zhangbo9674 提交者: GitHub

add fp16 kernel for clip_op (#36577)

上级 d4906214
......@@ -19,10 +19,14 @@ REGISTER_OP_CUDA_KERNEL(
clip, ops::ClipKernel<paddle::platform::CUDADeviceContext, float>,
ops::ClipKernel<paddle::platform::CUDADeviceContext, double>,
ops::ClipKernel<paddle::platform::CUDADeviceContext, int>,
ops::ClipKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::ClipKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ClipKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
clip_grad, ops::ClipGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ClipGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::ClipGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ClipGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
......@@ -54,7 +54,7 @@ class ClipGradFunctor {
public:
explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T& x, const T& y) const {
return (y > min_ && y < max_) ? x : 0;
return (y > min_ && y < max_) ? x : static_cast<T>(0);
}
private:
......@@ -79,7 +79,7 @@ class ClipKernel : public framework::OpKernel<T> {
}
max = static_cast<T>(max);
auto min = context.Attr<float>("min");
auto min = static_cast<T>(context.Attr<float>("min"));
Tensor min_cpu;
if (context.HasInput("Min")) {
auto* min_t = context.Input<Tensor>("Min");
......@@ -156,7 +156,7 @@ class ClipGradKernel : public framework::OpKernel<T> {
}
max = static_cast<T>(max);
auto min = context.Attr<float>("min");
auto min = static_cast<T>(context.Attr<float>("min"));
Tensor min_cpu;
if (context.HasInput("Min")) {
auto* min_t = context.Input<Tensor>("Min");
......
......@@ -43,7 +43,7 @@ class TestClipOp(OpTest):
else:
max_v = self.attrs['max']
input = np.random.random(self.shape).astype("float32")
input = np.random.random(self.shape).astype(self.dtype)
input[np.abs(input - min_v) < self.max_relative_error] = 0.5
input[np.abs(input - max_v) < self.max_relative_error] = 0.5
self.inputs['X'] = input
......@@ -60,15 +60,17 @@ class TestClipOp(OpTest):
paddle.disable_static()
def initTestCase(self):
self.dtype = np.float32
self.shape = (4, 10, 10)
self.max = 0.8
self.min = 0.3
self.inputs['Max'] = np.array([0.8]).astype('float32')
self.inputs['Min'] = np.array([0.1]).astype('float32')
self.inputs['Max'] = np.array([0.8]).astype(self.dtype)
self.inputs['Min'] = np.array([0.1]).astype(self.dtype)
class TestCase1(TestClipOp):
    """Clip test configuration: float32 data, shape (8, 16, 8), min 0.0, max 0.7."""

    def initTestCase(self):
        # Plain attribute setup; this case adds nothing to self.inputs.
        self.min, self.max = 0.0, 0.7
        self.shape = (8, 16, 8)
        self.dtype = np.float32
......@@ -76,6 +78,7 @@ class TestCase1(TestClipOp):
class TestCase2(TestClipOp):
    """Clip test configuration: float32 data, shape (8, 16), min 0.0, max 1.0."""

    def initTestCase(self):
        # Plain attribute setup; this case adds nothing to self.inputs.
        self.min, self.max = 0.0, 1.0
        self.shape = (8, 16)
        self.dtype = np.float32
......@@ -83,6 +86,7 @@ class TestCase2(TestClipOp):
class TestCase3(TestClipOp):
    """Clip test configuration: float32 data, shape (4, 8, 16), min 0.2, max 0.7."""

    def initTestCase(self):
        # Plain attribute setup; this case adds nothing to self.inputs.
        self.min, self.max = 0.2, 0.7
        self.shape = (4, 8, 16)
        self.dtype = np.float32
......@@ -90,20 +94,32 @@ class TestCase3(TestClipOp):
class TestCase4(TestClipOp):
def initTestCase(self):
self.dtype = np.float32
self.shape = (4, 8, 8)
self.max = 0.7
self.min = 0.2
self.inputs['Max'] = np.array([0.8]).astype('float32')
self.inputs['Min'] = np.array([0.3]).astype('float32')
self.inputs['Max'] = np.array([0.8]).astype(self.dtype)
self.inputs['Min'] = np.array([0.3]).astype(self.dtype)
class TestCase5(TestClipOp):
    """Clip test configuration where min and max are equal (both 0.5)."""

    def initTestCase(self):
        # Lower and upper bound coincide in this case.
        self.max = self.min = 0.5
        self.shape = (4, 8, 16)
        self.dtype = np.float32
class TestCase6(TestClipOp):
    """Clip test configuration for float16 data.

    Uses shape (4, 8, 8), min 0.2 and max 0.7, and also supplies 'Max' (0.8)
    and 'Min' (0.3) entries in self.inputs, cast to the configured dtype.
    """

    def initTestCase(self):
        # This must be an assignment. The previous `self.dtype == np.float16`
        # only evaluated a comparison and threw the result away, so this
        # method never selected float16 and the astype(self.dtype) calls
        # below did not produce float16 arrays.
        self.dtype = np.float16
        self.shape = (4, 8, 8)
        self.max = 0.7
        self.min = 0.2
        # Input entries are cast to the same dtype as the test data.
        self.inputs['Max'] = np.array([0.8]).astype(self.dtype)
        self.inputs['Min'] = np.array([0.3]).astype(self.dtype)
class TestClipOpError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册