未验证 提交 e662d1e0 编写于 作者: Y Yang Zhang 提交者: GitHub

Update `paddle.clamp` (#25906)

* Update `paddle.clamp`

rename to `paddle.clip`
add fast path for dygraph mode
remove `out`
rename `input` -> `x`
update doc sample

* Fix leftover `Variable` wording

* Indent doc with spaces

* Remove `:alias` in docs

* Update `enable_imperative` -> `disable_static`

* Remove `imperative`

also trigger CI

* Update tests for better coverage

* Rebase to fix `cosine_similarity`

* Fix `cosine_similarity` some more
上级 1a72a903
...@@ -181,7 +181,7 @@ from .tensor.math import log1p #DEFINE_ALIAS ...@@ -181,7 +181,7 @@ from .tensor.math import log1p #DEFINE_ALIAS
from .tensor.math import erf #DEFINE_ALIAS from .tensor.math import erf #DEFINE_ALIAS
from .tensor.math import addcmul #DEFINE_ALIAS from .tensor.math import addcmul #DEFINE_ALIAS
from .tensor.math import addmm #DEFINE_ALIAS from .tensor.math import addmm #DEFINE_ALIAS
from .tensor.math import clamp #DEFINE_ALIAS from .tensor.math import clip #DEFINE_ALIAS
from .tensor.math import trace #DEFINE_ALIAS from .tensor.math import trace #DEFINE_ALIAS
from .tensor.math import kron #DEFINE_ALIAS from .tensor.math import kron #DEFINE_ALIAS
from .tensor.math import prod #DEFINE_ALIAS from .tensor.math import prod #DEFINE_ALIAS
......
...@@ -12205,8 +12205,6 @@ def logical_not(x, out=None, name=None): ...@@ -12205,8 +12205,6 @@ def logical_not(x, out=None, name=None):
@templatedoc() @templatedoc()
def clip(x, min, max, name=None): def clip(x, min, max, name=None):
""" """
:alias_main: paddle.nn.clip
:alias: paddle.nn.clip,paddle.nn.clip.clip
:old_api: paddle.fluid.layers.clip :old_api: paddle.fluid.layers.clip
${comment} ${comment}
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.tensor as tensor
import paddle.fluid as fluid
import numpy as np
import unittest
class TestClampAPI(unittest.TestCase):
def test_dygraph_clamp(self):
in1 = np.array([[1.2, 3.5], [4.5, 6.4]]).astype('float32')
with fluid.dygraph.guard():
x1 = fluid.dygraph.to_variable(in1)
out1 = tensor.clamp(x1, min=3.5, max=5.0)
out2 = tensor.clamp(x1, min=2.5)
self.assertTrue(
np.allclose(
out1.numpy(), in1.clip(
min=3.5, max=5.0)))
self.assertTrue(np.allclose(out2.numpy(), in1.clip(min=2.5)))
def test_clamp(self):
data_shape = [1, 9, 9, 4]
data = np.random.random(data_shape).astype('float32')
images = fluid.data(name='image', shape=data_shape, dtype='float32')
min = fluid.data(name='min', shape=[1], dtype='float32')
max = fluid.data(name='max', shape=[1], dtype='float32')
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
out_1 = tensor.clamp(images, min=min, max=max)
out_2 = tensor.clamp(images, min=0.2, max=0.9)
out_3 = tensor.clamp(images, min=0.3)
out_4 = tensor.clamp(images, max=0.7)
out_5 = tensor.clamp(images, min=min)
out_6 = tensor.clamp(images, max=max)
res1, res2, res3, res4, res5, res6 = exe.run(
fluid.default_main_program(),
feed={
"image": data,
"min": np.array([0.2]).astype('float32'),
"max": np.array([0.8]).astype('float32')
},
fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6])
self.assertTrue(np.allclose(res1, data.clip(0.2, 0.8)))
self.assertTrue(np.allclose(res2, data.clip(0.2, 0.9)))
self.assertTrue(np.allclose(res3, data.clip(min=0.3)))
self.assertTrue(np.allclose(res4, data.clip(max=0.7)))
self.assertTrue(np.allclose(res5, data.clip(min=0.2)))
self.assertTrue(np.allclose(res6, data.clip(max=0.8)))
class TestClampError(unittest.TestCase):
def test_errors(self):
x1 = fluid.layers.data(name='x1', shape=[1], dtype="int16")
x2 = fluid.layers.data(name='x2', shape=[1], dtype="int8")
self.assertRaises(TypeError, tensor.clamp, x=x1, min=0.2, max=0.8)
self.assertRaises(TypeError, tensor.clamp, x=x2, min=0.2, max=0.8)
if __name__ == '__main__':
unittest.main()
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
from op_test import OpTest from op_test import OpTest
...@@ -109,5 +110,64 @@ class TestClipOpError(unittest.TestCase): ...@@ -109,5 +110,64 @@ class TestClipOpError(unittest.TestCase):
self.assertRaises(TypeError, test_dtype) self.assertRaises(TypeError, test_dtype)
class TestClipAPI(unittest.TestCase):
def test_clip(self):
data_shape = [1, 9, 9, 4]
data = np.random.random(data_shape).astype('float32')
images = fluid.data(name='image', shape=data_shape, dtype='float32')
min = fluid.data(name='min', shape=[1], dtype='float32')
max = fluid.data(name='max', shape=[1], dtype='float32')
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
out_1 = paddle.clip(images, min=min, max=max)
out_2 = paddle.clip(images, min=0.2, max=0.9)
out_3 = paddle.clip(images, min=0.3)
out_4 = paddle.clip(images, max=0.7)
out_5 = paddle.clip(images, min=min)
out_6 = paddle.clip(images, max=max)
res1, res2, res3, res4, res5, res6 = exe.run(
fluid.default_main_program(),
feed={
"image": data,
"min": np.array([0.2]).astype('float32'),
"max": np.array([0.8]).astype('float32')
},
fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6])
self.assertTrue(np.allclose(res1, data.clip(0.2, 0.8)))
self.assertTrue(np.allclose(res2, data.clip(0.2, 0.9)))
self.assertTrue(np.allclose(res3, data.clip(min=0.3)))
self.assertTrue(np.allclose(res4, data.clip(max=0.7)))
self.assertTrue(np.allclose(res5, data.clip(min=0.2)))
self.assertTrue(np.allclose(res6, data.clip(max=0.8)))
def test_clip_dygraph(self):
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
paddle.disable_static(place)
data_shape = [1, 9, 9, 4]
data = np.random.random(data_shape).astype('float32')
images = paddle.to_variable(data, dtype='float32')
out_1 = paddle.clip(images, min=0.2, max=0.8)
out_2 = paddle.clip(images, min=0.2, max=0.9)
self.assertTrue(np.allclose(out_1.numpy(), data.clip(0.2, 0.8)))
self.assertTrue(np.allclose(out_2.numpy(), data.clip(0.2, 0.9)))
def test_errors(self):
paddle.enable_static()
x1 = fluid.data(name='x1', shape=[1], dtype="int16")
x2 = fluid.data(name='x2', shape=[1], dtype="int8")
x3 = fluid.data(name='x3', shape=[1], dtype="float32")
self.assertRaises(TypeError, paddle.clip, x=x1, min=0.2, max=0.8)
self.assertRaises(TypeError, paddle.clip, x=x2, min=0.2, max=0.8)
self.assertRaises(Exception, paddle.clip, x=x3)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -29,10 +29,10 @@ class TestCosineSimilarityAPI(unittest.TestCase): ...@@ -29,10 +29,10 @@ class TestCosineSimilarityAPI(unittest.TestCase):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0)) self.places.append(paddle.CUDAPlace(0))
def _get_numpy_out(self, x1, x2, dim=1, eps=1e-8): def _get_numpy_out(self, x1, x2, axis=1, eps=1e-8):
w12 = np.sum(x1 * x2, axis=dim) w12 = np.sum(x1 * x2, axis=axis)
w1 = np.sum(x1 * x1, axis=dim) w1 = np.sum(x1 * x1, axis=axis)
w2 = np.sum(x2 * x2, axis=dim) w2 = np.sum(x2 * x2, axis=axis)
n12 = np.sqrt(np.clip(w1 * w2, eps * eps, None)) n12 = np.sqrt(np.clip(w1 * w2, eps * eps, None))
cos_sim = w12 / n12 cos_sim = w12 / n12
return cos_sim return cos_sim
...@@ -42,7 +42,7 @@ class TestCosineSimilarityAPI(unittest.TestCase): ...@@ -42,7 +42,7 @@ class TestCosineSimilarityAPI(unittest.TestCase):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
shape = [10, 15] shape = [10, 15]
dim = 1 axis = 1
eps = 1e-8 eps = 1e-8
np.random.seed(0) np.random.seed(0)
np_x1 = np.random.rand(*shape).astype(np.float32) np_x1 = np.random.rand(*shape).astype(np.float32)
...@@ -50,14 +50,14 @@ class TestCosineSimilarityAPI(unittest.TestCase): ...@@ -50,14 +50,14 @@ class TestCosineSimilarityAPI(unittest.TestCase):
x1 = paddle.data(name="x1", shape=shape) x1 = paddle.data(name="x1", shape=shape)
x2 = paddle.data(name="x2", shape=shape) x2 = paddle.data(name="x2", shape=shape)
result = F.cosine_similarity(x1, x2, dim=dim, eps=eps) result = F.cosine_similarity(x1, x2, axis=axis, eps=eps)
exe = Executor(place) exe = Executor(place)
fetches = exe.run(default_main_program(), fetches = exe.run(default_main_program(),
feed={"x1": np_x1, feed={"x1": np_x1,
"x2": np_x2}, "x2": np_x2},
fetch_list=[result]) fetch_list=[result])
np_out = self._get_numpy_out(np_x1, np_x2, dim=dim, eps=eps) np_out = self._get_numpy_out(np_x1, np_x2, axis=axis, eps=eps)
self.assertTrue(np.allclose(fetches[0], np_out)) self.assertTrue(np.allclose(fetches[0], np_out))
def test_static(self): def test_static(self):
...@@ -68,16 +68,16 @@ class TestCosineSimilarityAPI(unittest.TestCase): ...@@ -68,16 +68,16 @@ class TestCosineSimilarityAPI(unittest.TestCase):
paddle.disable_static() paddle.disable_static()
shape = [10, 15] shape = [10, 15]
dim = 1 axis = 1
eps = 1e-8 eps = 1e-8
np.random.seed(1) np.random.seed(1)
np_x1 = np.random.rand(*shape).astype(np.float32) np_x1 = np.random.rand(*shape).astype(np.float32)
np_x2 = np.random.rand(*shape).astype(np.float32) np_x2 = np.random.rand(*shape).astype(np.float32)
np_out = self._get_numpy_out(np_x1, np_x2, dim=dim, eps=eps) np_out = self._get_numpy_out(np_x1, np_x2, axis=axis, eps=eps)
tesnor_x1 = paddle.to_variable(np_x1) tesnor_x1 = paddle.to_variable(np_x1)
tesnor_x2 = paddle.to_variable(np_x2) tesnor_x2 = paddle.to_variable(np_x2)
y = F.cosine_similarity(tesnor_x1, tesnor_x2, dim=dim, eps=eps) y = F.cosine_similarity(tesnor_x1, tesnor_x2, axis=axis, eps=eps)
self.assertTrue(np.allclose(y.numpy(), np_out)) self.assertTrue(np.allclose(y.numpy(), np_out))
...@@ -85,16 +85,16 @@ class TestCosineSimilarityAPI(unittest.TestCase): ...@@ -85,16 +85,16 @@ class TestCosineSimilarityAPI(unittest.TestCase):
paddle.disable_static() paddle.disable_static()
shape = [12, 13] shape = [12, 13]
dim = 0 axis = 0
eps = 1e-6 eps = 1e-6
np.random.seed(1) np.random.seed(1)
np_x1 = np.random.rand(*shape).astype(np.float32) np_x1 = np.random.rand(*shape).astype(np.float32)
np_x2 = np.random.rand(*shape).astype(np.float32) np_x2 = np.random.rand(*shape).astype(np.float32)
np_out = self._get_numpy_out(np_x1, np_x2, dim=dim, eps=eps) np_out = self._get_numpy_out(np_x1, np_x2, axis=axis, eps=eps)
tesnor_x1 = paddle.to_variable(np_x1) tesnor_x1 = paddle.to_variable(np_x1)
tesnor_x2 = paddle.to_variable(np_x2) tesnor_x2 = paddle.to_variable(np_x2)
y = F.cosine_similarity(tesnor_x1, tesnor_x2, dim=dim, eps=eps) y = F.cosine_similarity(tesnor_x1, tesnor_x2, axis=axis, eps=eps)
self.assertTrue(np.allclose(y.numpy(), np_out)) self.assertTrue(np.allclose(y.numpy(), np_out))
...@@ -103,16 +103,16 @@ class TestCosineSimilarityAPI(unittest.TestCase): ...@@ -103,16 +103,16 @@ class TestCosineSimilarityAPI(unittest.TestCase):
shape1 = [10, 12, 10] shape1 = [10, 12, 10]
shape2 = [10, 1, 10] shape2 = [10, 1, 10]
dim = 2 axis = 2
eps = 1e-6 eps = 1e-6
np.random.seed(1) np.random.seed(1)
np_x1 = np.random.rand(*shape1).astype(np.float32) np_x1 = np.random.rand(*shape1).astype(np.float32)
np_x2 = np.random.rand(*shape2).astype(np.float32) np_x2 = np.random.rand(*shape2).astype(np.float32)
np_out = self._get_numpy_out(np_x1, np_x2, dim=dim, eps=eps) np_out = self._get_numpy_out(np_x1, np_x2, axis=axis, eps=eps)
tesnor_x1 = paddle.to_variable(np_x1) tesnor_x1 = paddle.to_variable(np_x1)
tesnor_x2 = paddle.to_variable(np_x2) tesnor_x2 = paddle.to_variable(np_x2)
y = F.cosine_similarity(tesnor_x1, tesnor_x2, dim=dim, eps=eps) y = F.cosine_similarity(tesnor_x1, tesnor_x2, axis=axis, eps=eps)
self.assertTrue(np.allclose(y.numpy(), np_out)) self.assertTrue(np.allclose(y.numpy(), np_out))
......
...@@ -28,9 +28,9 @@ from ...fluid.layers import assign #DEFINE_ALIAS ...@@ -28,9 +28,9 @@ from ...fluid.layers import assign #DEFINE_ALIAS
from ...fluid.layers import squeeze #DEFINE_ALIAS from ...fluid.layers import squeeze #DEFINE_ALIAS
from ...fluid.layers import unsqueeze #DEFINE_ALIAS from ...fluid.layers import unsqueeze #DEFINE_ALIAS
from ...fluid.layers import elementwise_mul #DEFINE_ALIAS from ...fluid.layers import elementwise_mul #DEFINE_ALIAS
from ...tensor import clamp #DEFINE_ALIAS from ...tensor import clip
from ...tensor import sum #DEFINE_ALIAS from ...tensor import sum
from ...tensor import sqrt #DEFINE_ALIAS from ...tensor import sqrt
#from ...fluid.layers import fc #DEFINE_ALIAS #from ...fluid.layers import fc #DEFINE_ALIAS
from ...fluid.layers import pad_constant_like #DEFINE_ALIAS from ...fluid.layers import pad_constant_like #DEFINE_ALIAS
...@@ -635,17 +635,17 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): ...@@ -635,17 +635,17 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None):
return out return out
def cosine_similarity(x1, x2, dim=1, eps=1e-8): def cosine_similarity(x1, x2, axis=1, eps=1e-8):
""" """
Compute cosine similarity between x1 and x2 along dim. Compute cosine similarity between x1 and x2 along axis.
Parameters: Parameters:
x1 (Tensor): First input. float32/double. x1 (Tensor): First input. float32/double.
x2 (Tensor): Second input. float32/double. x2 (Tensor): Second input. float32/double.
dim (int): Dimension of vectors to compute cosine similarity. Default is 1. axis (int): Dimension of vectors to compute cosine similarity. Default is 1.
eps(float): Small value to avoid division by zero. Default is 1e-8. eps(float): Small value to avoid division by zero. Default is 1e-8.
Returns: a Tensor representing cosine similarity between x1 and x2 along dim. Returns: a Tensor representing cosine similarity between x1 and x2 along axis.
Return Type: Tensor Return Type: Tensor
Examples: Examples:
...@@ -659,7 +659,7 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8): ...@@ -659,7 +659,7 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8):
[0.9098952 0.15715368 0.8671125 0.3156102 ] [0.9098952 0.15715368 0.8671125 0.3156102 ]
[0.4427798 0.54136837 0.5276275 0.32394758] [0.4427798 0.54136837 0.5276275 0.32394758]
[0.3769419 0.8535014 0.48041078 0.9256797 ]] [0.3769419 0.8535014 0.48041078 0.9256797 ]]
dim = 1 axis = 1
eps = 1e-8 eps = 1e-8
Out: [0.5275037 0.8368967 0.75037485 0.9245899] Out: [0.5275037 0.8368967 0.75037485 0.9245899]
...@@ -675,14 +675,14 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8): ...@@ -675,14 +675,14 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8):
x2 = np.random.rand(2,3) x2 = np.random.rand(2,3)
x1 = paddle.to_tensor(x1) x1 = paddle.to_tensor(x1)
x2 = paddle.to_tensor(x2) x2 = paddle.to_tensor(x2)
result = paddle.nn.functional.cosine_similarity(x1, x2, dim=0) result = paddle.nn.functional.cosine_similarity(x1, x2, axis=0)
print(result.numpy()) print(result.numpy())
# [0.99806249 0.9817672 0.94987036] # [0.99806249 0.9817672 0.94987036]
""" """
w12 = sum(elementwise_mul(x1, x2), axis=dim) w12 = sum(elementwise_mul(x1, x2), axis=axis)
w1 = sum(elementwise_mul(x1, x1), axis=dim) w1 = sum(elementwise_mul(x1, x1), axis=axis)
w2 = sum(elementwise_mul(x2, x2), axis=dim) w2 = sum(elementwise_mul(x2, x2), axis=axis)
n12 = sqrt(clamp(w1 * w2, min=eps * eps)) n12 = sqrt(clip(w1 * w2, min=eps * eps))
cos_sim = w12 / n12 cos_sim = w12 / n12
return cos_sim return cos_sim
...@@ -154,7 +154,7 @@ from .math import log1p #DEFINE_ALIAS ...@@ -154,7 +154,7 @@ from .math import log1p #DEFINE_ALIAS
from .math import erf #DEFINE_ALIAS from .math import erf #DEFINE_ALIAS
from .math import addcmul #DEFINE_ALIAS from .math import addcmul #DEFINE_ALIAS
from .math import addmm #DEFINE_ALIAS from .math import addmm #DEFINE_ALIAS
from .math import clamp #DEFINE_ALIAS from .math import clip #DEFINE_ALIAS
from .math import trace #DEFINE_ALIAS from .math import trace #DEFINE_ALIAS
from .math import kron #DEFINE_ALIAS from .math import kron #DEFINE_ALIAS
from .math import prod #DEFINE_ALIAS from .math import prod #DEFINE_ALIAS
......
...@@ -121,7 +121,7 @@ __all__ = [ ...@@ -121,7 +121,7 @@ __all__ = [
'erf', 'erf',
'addcmul', 'addcmul',
'addmm', 'addmm',
'clamp', 'clip',
'trace', 'trace',
'kron' 'kron'
] ]
...@@ -1326,14 +1326,14 @@ def addcmul(input, tensor1, tensor2, value=1.0, name=None): ...@@ -1326,14 +1326,14 @@ def addcmul(input, tensor1, tensor2, value=1.0, name=None):
return out return out
def clamp(input, min=None, max=None, name=None): def clip(x, min=None, max=None, name=None):
""" """
:alias_main: paddle.clamp :alias_main: paddle.clip
:alias: paddle.clamp,paddle.tensor.clamp,paddle.tensor.math.clamp :alias: paddle.clip,paddle.tensor.clip,paddle.tensor.math.clip
**clampe layer** **clip layer**
This operator clamps all elements in input into the range [ min, max ] and return This operator clip all elements in input into the range [ min, max ] and return
a resulting tensor as the following equation: a resulting tensor as the following equation:
.. math:: .. math::
...@@ -1341,38 +1341,35 @@ def clamp(input, min=None, max=None, name=None): ...@@ -1341,38 +1341,35 @@ def clamp(input, min=None, max=None, name=None):
Out = MIN(MAX(x, min), max) Out = MIN(MAX(x, min), max)
Args: Args:
input (Variable): An input N-D Tensor or LoDTensor x (Tensor): An N-D Tensor with data type float32 or float64.
with data type float32, float64. min (float32|Tensor): The lower bound with type ``float32`` or a ``Tensor``
min (float32|Variable): The lower bound with type ``float32`` or a ``Tensor``
with shape [1] and type ``int32``, ``float32``, ``float64``. with shape [1] and type ``int32``, ``float32``, ``float64``.
max (float32|Variable): The upper bound with type ``float32`` or a ``Tensor`` max (float32|Tensor): The upper bound with type ``float32`` or a ``Tensor``
with shape [1] and type ``int32``, ``float32``, ``float64``. with shape [1] and type ``int32``, ``float32``, ``float64``.
name (str, optional): The default value is None. Normally there is no name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`. refer to :ref:`api_guide_Name`.
Returns: Returns:
Variable: A Tensor or LodTensor with the same data type and data shape as input's. Tensor: A Tensor with the same data type and data shape as input.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.fluid as fluid
import numpy as np import numpy as np
in1 = np.array([[1.2,3.5], paddle.disable_static()
[4.5,6.4]]).astype('float32') x = np.array([[1.2,3.5], [4.5,6.4]]).astype('float32')
with fluid.dygraph.guard(): x1 = paddle.to_variable(x)
x1 = fluid.dygraph.to_variable(in1) out1 = paddle.clip(x1, min=3.5, max=5.0)
out1 = paddle.tensor.clamp(x1, min=3.5, max=5.0) out2 = paddle.clip(x1, min=2.5)
out2 = paddle.tensor.clamp(x1, min=2.5) print(out1.numpy())
print(out1.numpy()) # [[3.5, 3.5]
# [[3.5, 3.5] # [4.5, 5.0]]
# [4.5, 5.0]] print(out2.numpy())
print(out2.numpy()) # [[2.5, 3.5]
# [[2.5, 3.5] # [[4.5, 6.4]
# [[4.5, 6.4]
""" """
assert min is not None or max is not None, "either min or max should be defined." assert min is not None or max is not None, "either min or max should be defined."
...@@ -1380,20 +1377,22 @@ def clamp(input, min=None, max=None, name=None): ...@@ -1380,20 +1377,22 @@ def clamp(input, min=None, max=None, name=None):
if in_dygraph_mode(): if in_dygraph_mode():
min = sys.float_info.min if min is None else min min = sys.float_info.min if min is None else min
max = sys.float_info.max if max is None else max max = sys.float_info.max if max is None else max
return core.ops.clip(input, "min", min, "max", max) return core.ops.clip(x, "min", min, "max", max)
if min is not None: if min is not None:
check_type(min, 'min', (float, Variable), 'clamp') check_type(min, 'min', (float, int, Variable), 'clip')
if isinstance(min, Variable): if isinstance(min, Variable):
check_dtype(min.dtype, 'min', ['float32', 'float64', 'int32'], check_dtype(min.dtype, 'min', ['float32', 'float64', 'int32'],
'clamp', '(When the type of min in clamp is Variable.)') 'clip', '(When the type of min in clip is Variable.)')
if max is not None: if max is not None:
check_type(max, 'max', (float, Variable), 'clamp') check_type(max, 'max', (float, int, Variable), 'clip')
if isinstance(max, Variable): if isinstance(max, Variable):
check_dtype(max.dtype, 'max', ['float32', 'float64', 'int32'], check_dtype(max.dtype, 'max', ['float32', 'float64', 'int32'],
'clamp', '(When the type of max in clamp is Variable.)') 'clip', '(When the type of max in clip is Variable.)')
inputs = {'X': input} check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'clip')
inputs = {'X': x}
attrs = {'min': sys.float_info.min, 'max': sys.float_info.max} attrs = {'min': sys.float_info.min, 'max': sys.float_info.max}
if isinstance(min, Variable): if isinstance(min, Variable):
...@@ -1408,9 +1407,9 @@ def clamp(input, min=None, max=None, name=None): ...@@ -1408,9 +1407,9 @@ def clamp(input, min=None, max=None, name=None):
elif max is not None: elif max is not None:
attrs['max'] = max attrs['max'] = max
helper = LayerHelper('clamp', **locals()) helper = LayerHelper('clip', **locals())
output = helper.create_variable_for_type_inference( output = helper.create_variable_for_type_inference(
dtype=helper.input_dtype()) dtype=helper.input_dtype())
helper.append_op( helper.append_op(
type='clip', inputs=inputs, outputs={'Out': [output]}, attrs=attrs) type='clip', inputs=inputs, outputs={'Out': [output]}, attrs=attrs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册