未验证 提交 29b844ad 编写于 作者: Y Yang Zhang 提交者: GitHub

Fix clip op attr (#26924)

上级 26c698e2
...@@ -66,7 +66,7 @@ template <typename DeviceContext, typename T> ...@@ -66,7 +66,7 @@ template <typename DeviceContext, typename T>
class ClipKernel : public framework::OpKernel<T> { class ClipKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max"); auto max = static_cast<T>(context.Attr<float>("max"));
Tensor max_cpu; Tensor max_cpu;
if (context.HasInput("Max")) { if (context.HasInput("Max")) {
auto* max_t = context.Input<Tensor>("Max"); auto* max_t = context.Input<Tensor>("Max");
...@@ -77,8 +77,9 @@ class ClipKernel : public framework::OpKernel<T> { ...@@ -77,8 +77,9 @@ class ClipKernel : public framework::OpKernel<T> {
} }
max = max_data[0]; max = max_data[0];
} }
max = static_cast<T>(max);
auto min = context.Attr<T>("min"); auto min = context.Attr<float>("min");
Tensor min_cpu; Tensor min_cpu;
if (context.HasInput("Min")) { if (context.HasInput("Min")) {
auto* min_t = context.Input<Tensor>("Min"); auto* min_t = context.Input<Tensor>("Min");
...@@ -141,7 +142,7 @@ template <typename DeviceContext, typename T> ...@@ -141,7 +142,7 @@ template <typename DeviceContext, typename T>
class ClipGradKernel : public framework::OpKernel<T> { class ClipGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max"); auto max = static_cast<T>(context.Attr<float>("max"));
Tensor max_cpu; Tensor max_cpu;
if (context.HasInput("Max")) { if (context.HasInput("Max")) {
auto* max_t = context.Input<Tensor>("Max"); auto* max_t = context.Input<Tensor>("Max");
...@@ -152,8 +153,9 @@ class ClipGradKernel : public framework::OpKernel<T> { ...@@ -152,8 +153,9 @@ class ClipGradKernel : public framework::OpKernel<T> {
} }
max = max_data[0]; max = max_data[0];
} }
max = static_cast<T>(max);
auto min = context.Attr<T>("min"); auto min = context.Attr<float>("min");
Tensor min_cpu; Tensor min_cpu;
if (context.HasInput("Min")) { if (context.HasInput("Min")) {
auto* min_t = context.Input<Tensor>("Min"); auto* min_t = context.Input<Tensor>("Min");
...@@ -164,6 +166,7 @@ class ClipGradKernel : public framework::OpKernel<T> { ...@@ -164,6 +166,7 @@ class ClipGradKernel : public framework::OpKernel<T> {
} }
min = min_data[0]; min = min_data[0];
} }
min = static_cast<T>(min);
auto* d_out = auto* d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out")); context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
......
...@@ -138,8 +138,9 @@ class TestClipAPI(unittest.TestCase): ...@@ -138,8 +138,9 @@ class TestClipAPI(unittest.TestCase):
out_6 = paddle.clip(images, max=max) out_6 = paddle.clip(images, max=max)
out_7 = paddle.clip(images, max=-1.) out_7 = paddle.clip(images, max=-1.)
out_8 = paddle.clip(images) out_8 = paddle.clip(images)
out_9 = paddle.clip(paddle.cast(images, 'float64'), min=0.2, max=0.9)
res1, res2, res3, res4, res5, res6, res7, res8 = exe.run( res1, res2, res3, res4, res5, res6, res7, res8, res9 = exe.run(
fluid.default_main_program(), fluid.default_main_program(),
feed={ feed={
"image": data, "image": data,
...@@ -147,7 +148,7 @@ class TestClipAPI(unittest.TestCase): ...@@ -147,7 +148,7 @@ class TestClipAPI(unittest.TestCase):
"max": np.array([0.8]).astype('float32') "max": np.array([0.8]).astype('float32')
}, },
fetch_list=[ fetch_list=[
out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8 out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8, out_9
]) ])
self.assertTrue(np.allclose(res1, data.clip(0.2, 0.8))) self.assertTrue(np.allclose(res1, data.clip(0.2, 0.8)))
...@@ -158,6 +159,8 @@ class TestClipAPI(unittest.TestCase): ...@@ -158,6 +159,8 @@ class TestClipAPI(unittest.TestCase):
self.assertTrue(np.allclose(res6, data.clip(max=0.8))) self.assertTrue(np.allclose(res6, data.clip(max=0.8)))
self.assertTrue(np.allclose(res7, data.clip(max=-1))) self.assertTrue(np.allclose(res7, data.clip(max=-1)))
self.assertTrue(np.allclose(res8, data)) self.assertTrue(np.allclose(res8, data))
self.assertTrue(
np.allclose(res9, data.astype(np.float64).clip(0.2, 0.9)))
def test_clip_dygraph(self): def test_clip_dygraph(self):
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
......
...@@ -1611,11 +1611,8 @@ def clip(x, min=None, max=None, name=None): ...@@ -1611,11 +1611,8 @@ def clip(x, min=None, max=None, name=None):
# [[4.5, 6.4] # [[4.5, 6.4]
""" """
np_dtype = np.float32 fmin = float(np.finfo(np.float32).min)
if x.dtype == VarDesc.VarType.FP64: fmax = float(np.finfo(np.float32).max)
np_dtype = np.float64
fmin = float(np.finfo(np_dtype).min)
fmax = float(np.finfo(np_dtype).max)
if in_dygraph_mode(): if in_dygraph_mode():
if isinstance(min, Variable): if isinstance(min, Variable):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册