未验证 提交 63eef763 编写于 作者: Y Yang Zhang 提交者: GitHub

Fix clip input check (#26683)

* Fix clip input check

* Fix default min/max value

* Allow both max and min to be None

* Register op change

* Revert OP signature change
上级 edf5f317
......@@ -66,7 +66,7 @@ template <typename DeviceContext, typename T>
class ClipKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = static_cast<T>(context.Attr<float>("max"));
auto max = context.Attr<T>("max");
Tensor max_cpu;
if (context.HasInput("Max")) {
auto* max_t = context.Input<Tensor>("Max");
......@@ -77,9 +77,8 @@ class ClipKernel : public framework::OpKernel<T> {
}
max = max_data[0];
}
max = static_cast<T>(max);
auto min = context.Attr<float>("min");
auto min = context.Attr<T>("min");
Tensor min_cpu;
if (context.HasInput("Min")) {
auto* min_t = context.Input<Tensor>("Min");
......@@ -90,11 +89,12 @@ class ClipKernel : public framework::OpKernel<T> {
}
min = min_data[0];
}
min = static_cast<T>(min);
PADDLE_ENFORCE_LT(min, max, platform::errors::InvalidArgument(
"max should be greater than min. "
"But received min = %f, max = %f",
min, max));
PADDLE_ENFORCE_LE(min, max,
platform::errors::InvalidArgument(
"max should be greater than or equal to min. "
"But received min = %f, max = %f",
min, max));
auto* x_var = context.InputVar("X");
if (x_var->IsType<framework::LoDTensor>()) {
......@@ -141,7 +141,7 @@ template <typename DeviceContext, typename T>
class ClipGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = static_cast<T>(context.Attr<float>("max"));
auto max = context.Attr<T>("max");
Tensor max_cpu;
if (context.HasInput("Max")) {
auto* max_t = context.Input<Tensor>("Max");
......@@ -152,9 +152,8 @@ class ClipGradKernel : public framework::OpKernel<T> {
}
max = max_data[0];
}
max = static_cast<T>(max);
auto min = context.Attr<float>("min");
auto min = context.Attr<T>("min");
Tensor min_cpu;
if (context.HasInput("Min")) {
auto* min_t = context.Input<Tensor>("Min");
......@@ -165,7 +164,6 @@ class ClipGradKernel : public framework::OpKernel<T> {
}
min = min_data[0];
}
min = static_cast<T>(min);
auto* d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
......
......@@ -93,6 +93,13 @@ class TestCase4(TestClipOp):
self.inputs['Min'] = np.array([0.3]).astype('float32')
class TestCase5(TestClipOp):
def initTestCase(self):
self.shape = (4, 8, 16)
self.max = 0.5
self.min = 0.5
class TestClipOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
......@@ -112,6 +119,7 @@ class TestClipOpError(unittest.TestCase):
class TestClipAPI(unittest.TestCase):
def test_clip(self):
paddle.enable_static()
data_shape = [1, 9, 9, 4]
data = np.random.random(data_shape).astype('float32')
images = fluid.data(name='image', shape=data_shape, dtype='float32')
......@@ -128,15 +136,19 @@ class TestClipAPI(unittest.TestCase):
out_4 = paddle.clip(images, max=0.7)
out_5 = paddle.clip(images, min=min)
out_6 = paddle.clip(images, max=max)
out_7 = paddle.clip(images, max=-1.)
out_8 = paddle.clip(images)
res1, res2, res3, res4, res5, res6 = exe.run(
res1, res2, res3, res4, res5, res6, res7, res8 = exe.run(
fluid.default_main_program(),
feed={
"image": data,
"min": np.array([0.2]).astype('float32'),
"max": np.array([0.8]).astype('float32')
},
fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6])
fetch_list=[
out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8
])
self.assertTrue(np.allclose(res1, data.clip(0.2, 0.8)))
self.assertTrue(np.allclose(res2, data.clip(0.2, 0.9)))
......@@ -144,6 +156,8 @@ class TestClipAPI(unittest.TestCase):
self.assertTrue(np.allclose(res4, data.clip(max=0.7)))
self.assertTrue(np.allclose(res5, data.clip(min=0.2)))
self.assertTrue(np.allclose(res6, data.clip(max=0.8)))
self.assertTrue(np.allclose(res7, data.clip(max=-1)))
self.assertTrue(np.allclose(res8, data))
def test_clip_dygraph(self):
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
......@@ -163,10 +177,8 @@ class TestClipAPI(unittest.TestCase):
paddle.enable_static()
x1 = fluid.data(name='x1', shape=[1], dtype="int16")
x2 = fluid.data(name='x2', shape=[1], dtype="int8")
x3 = fluid.data(name='x3', shape=[1], dtype="float32")
self.assertRaises(TypeError, paddle.clip, x=x1, min=0.2, max=0.8)
self.assertRaises(TypeError, paddle.clip, x=x2, min=0.2, max=0.8)
self.assertRaises(Exception, paddle.clip, x=x3)
if __name__ == '__main__':
......
......@@ -15,6 +15,7 @@
math functions
"""
from __future__ import print_function
import numpy as np
from paddle.common_ops_import import *
from paddle.tensor import cast
......@@ -24,7 +25,6 @@ from ..fluid.framework import core, _varbase_creator, in_dygraph_mode, Variable
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
from ..fluid.layers.layer_function_generator import _generate_doc_string_, generate_activation_fn, generate_layer_fn
import sys
# TODO: define math functions
# yapf: disable
......@@ -1611,11 +1611,15 @@ def clip(x, min=None, max=None, name=None):
# [[4.5, 6.4]
"""
assert min is not None or max is not None, "either min or max should be defined."
np_dtype = np.float32
if x.dtype == VarDesc.VarType.FP64:
np_dtype = np.float64
fmin = float(np.finfo(np_dtype).min)
fmax = float(np.finfo(np_dtype).max)
if in_dygraph_mode():
min = sys.float_info.min if min is None else min
max = sys.float_info.max if max is None else max
min = fmin if min is None else min
max = fmax if max is None else max
return core.ops.clip(x, "min", min, "max", max)
if min is not None:
......@@ -1629,10 +1633,10 @@ def clip(x, min=None, max=None, name=None):
check_dtype(max.dtype, 'max', ['float32', 'float64', 'int32'],
'clip', '(When the type of max in clip is Variable.)')
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'clip')
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'clip')
inputs = {'X': x}
attrs = {'min': sys.float_info.min, 'max': sys.float_info.max}
attrs = {'min': fmin, 'max': fmax}
if isinstance(min, Variable):
min.stop_gradient = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册