Unverified · Commit 29b844ad authored by Yang Zhang, committed by GitHub

Fix clip op attr (#26924)

Parent 26c698e2
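This commit fixes how the clip op reads its min/max attributes. The attributes are carried as float (the updated kernels read them with Attr<float>), but the old code read them as Attr<T>, which goes wrong whenever the kernel is instantiated for a non-float dtype such as double; the kernels now read float and cast to T, and the Python clip wrapper stops widening its default bounds to float64. A minimal sketch of the user-visible effect, distilled from the new test case further down (static-graph API of this era; tensor names are illustrative, and on builds where dygraph is the default you would first call paddle.enable_static()):

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    data = np.random.random([2, 3]).astype('float32')
    images = fluid.data(name='image', shape=[2, 3], dtype='float32')
    # Clipping a float64 tensor exercises the fixed attr-reading path.
    out = paddle.clip(paddle.cast(images, 'float64'), min=0.2, max=0.9)

    exe = fluid.Executor(fluid.CPUPlace())
    res, = exe.run(fluid.default_main_program(),
                   feed={'image': data}, fetch_list=[out])
    assert np.allclose(res, data.astype(np.float64).clip(0.2, 0.9))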
@@ -66,7 +66,7 @@ template <typename DeviceContext, typename T>
 class ClipKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto max = context.Attr<T>("max");
+    auto max = static_cast<T>(context.Attr<float>("max"));
     Tensor max_cpu;
     if (context.HasInput("Max")) {
       auto* max_t = context.Input<Tensor>("Max");
@@ -77,8 +77,9 @@ class ClipKernel : public framework::OpKernel<T> {
       }
       max = max_data[0];
     }
+    max = static_cast<T>(max);
-    auto min = context.Attr<T>("min");
+    auto min = context.Attr<float>("min");
     Tensor min_cpu;
     if (context.HasInput("Min")) {
       auto* min_t = context.Input<Tensor>("Min");
@@ -141,7 +142,7 @@ template <typename DeviceContext, typename T>
 class ClipGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto max = context.Attr<T>("max");
+    auto max = static_cast<T>(context.Attr<float>("max"));
     Tensor max_cpu;
     if (context.HasInput("Max")) {
       auto* max_t = context.Input<Tensor>("Max");
@@ -152,8 +153,9 @@ class ClipGradKernel : public framework::OpKernel<T> {
       }
       max = max_data[0];
     }
+    max = static_cast<T>(max);
-    auto min = context.Attr<T>("min");
+    auto min = context.Attr<float>("min");
     Tensor min_cpu;
     if (context.HasInput("Min")) {
       auto* min_t = context.Input<Tensor>("Min");
@@ -164,6 +166,7 @@ class ClipGradKernel : public framework::OpKernel<T> {
       }
       min = min_data[0];
     }
+    min = static_cast<T>(min);
     auto* d_out =
         context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
......
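As the HasInput("Max") branches above show, the kernels let optional Min/Max tensor inputs override the float attributes. A sketch of driving that path from Python, mirroring the feed used in the test below (the placeholder construction via fluid.data is an assumption; only the feed of a 1-element float32 array is taken from the test):

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    data = np.random.random([2, 3]).astype('float32')
    images = fluid.data(name='image', shape=[2, 3], dtype='float32')
    max_t = fluid.data(name='max', shape=[1], dtype='float32')
    # max comes from the fed tensor; min falls back to the float attribute.
    out = paddle.clip(images, min=0.2, max=max_t)

    exe = fluid.Executor(fluid.CPUPlace())
    res, = exe.run(fluid.default_main_program(),
                   feed={'image': data,
                         'max': np.array([0.8]).astype('float32')},
                   fetch_list=[out])
    assert np.allclose(res, data.clip(0.2, 0.8))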
@@ -138,8 +138,9 @@ class TestClipAPI(unittest.TestCase):
         out_6 = paddle.clip(images, max=max)
         out_7 = paddle.clip(images, max=-1.)
         out_8 = paddle.clip(images)
+        out_9 = paddle.clip(paddle.cast(images, 'float64'), min=0.2, max=0.9)
-        res1, res2, res3, res4, res5, res6, res7, res8 = exe.run(
+        res1, res2, res3, res4, res5, res6, res7, res8, res9 = exe.run(
             fluid.default_main_program(),
             feed={
                 "image": data,
@@ -147,7 +148,7 @@ class TestClipAPI(unittest.TestCase):
                 "max": np.array([0.8]).astype('float32')
             },
             fetch_list=[
-                out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8
+                out_1, out_2, out_3, out_4, out_5, out_6, out_7, out_8, out_9
             ])
         self.assertTrue(np.allclose(res1, data.clip(0.2, 0.8)))
@@ -158,6 +159,8 @@ class TestClipAPI(unittest.TestCase):
         self.assertTrue(np.allclose(res6, data.clip(max=0.8)))
         self.assertTrue(np.allclose(res7, data.clip(max=-1)))
         self.assertTrue(np.allclose(res8, data))
+        self.assertTrue(
+            np.allclose(res9, data.astype(np.float64).clip(0.2, 0.9)))

     def test_clip_dygraph(self):
         place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
......
@@ -1611,11 +1611,8 @@ def clip(x, min=None, max=None, name=None):
             #      [[4.5, 6.4]
     """
-    np_dtype = np.float32
-    if x.dtype == VarDesc.VarType.FP64:
-        np_dtype = np.float64
-    fmin = float(np.finfo(np_dtype).min)
-    fmax = float(np.finfo(np_dtype).max)
+    fmin = float(np.finfo(np.float32).min)
+    fmax = float(np.finfo(np.float32).max)

     if in_dygraph_mode():
         if isinstance(min, Variable):
......
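The removed float64 branch points at the Python-side issue: since the op carries min/max in single-precision attributes, the float64 limits the old code picked for FP64 inputs could not survive the trip through the attribute. A quick numpy check of that rationale (plain numpy, no Paddle required):

    import numpy as np

    # The float64 extremes overflow a float32 slot and become +/-inf...
    print(np.float32(np.finfo(np.float64).min))  # -inf
    # ...while the float32 limits used as the new defaults are representable.
    print(np.float32(float(np.finfo(np.float32).min)))  # -3.4028235e+38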