diff --git a/paddle/fluid/operators/clip_op.cc b/paddle/fluid/operators/clip_op.cc index bb04d00a2c8403010b7ad846ba918107ac831db4..f727f63eb61d652ed322719bd1251591ecfca68f 100644 --- a/paddle/fluid/operators/clip_op.cc +++ b/paddle/fluid/operators/clip_op.cc @@ -26,12 +26,6 @@ class ClipOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "clip"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "clip"); auto x_dims = ctx->GetInputDim("X"); - auto max = ctx->Attrs().Get("max"); - auto min = ctx->Attrs().Get("min"); - PADDLE_ENFORCE_LT(min, max, platform::errors::InvalidArgument( - "Max of ClipOp should be greater than min. " - "Received max is %f, received min is %f.", - max, min)); ctx->SetOutputDim("Out", x_dims); ctx->ShareLoD("X", /*->*/ "Out"); } @@ -44,6 +38,14 @@ class ClipOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "Tensor, the input of clip op, data type should be float32 or " "float64."); + AddInput("Min", + "Tensor, the lower bound, data type should be float32 " + "or float64.") + .AsDispensable(); + AddInput("Max", + "Tensor, the upper bound, data type should be float32 " + "or float64.") + .AsDispensable(); AddOutput( "Out", "Tensor, the clipped tensor, with the same shape and data type as " @@ -88,6 +90,12 @@ class ClipGradOpMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr op) const override { op->SetType("clip_grad"); op->SetInput("X", this->Input("X")); + if (this->HasInput("Min")) { + op->SetInput("Min", this->Input("Min")); + } + if (this->HasInput("Max")) { + op->SetInput("Max", this->Input("Max")); + } op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetAttrMap(this->Attrs()); diff --git a/paddle/fluid/operators/clip_op.h b/paddle/fluid/operators/clip_op.h index daf06f370ffb591e25ad846b94c8284aad19a8dd..efe57f8b3062ead7fab94d64cd2a0044052d4ade 100644 --- a/paddle/fluid/operators/clip_op.h +++ b/paddle/fluid/operators/clip_op.h @@ -60,8 +60,36 @@ template class ClipKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto max = context.Attr("max"); - auto min = context.Attr("min"); + auto max = static_cast(context.Attr("max")); + Tensor max_cpu; + if (context.HasInput("Max")) { + auto* max_t = context.Input("Max"); + auto* max_data = max_t->data(); + if (platform::is_gpu_place(max_t->place())) { + TensorCopySync(*max_t, platform::CPUPlace(), &max_cpu); + max_data = max_cpu.data(); + } + max = max_data[0]; + } + max = static_cast(max); + + auto min = context.Attr("min"); + Tensor min_cpu; + if (context.HasInput("Min")) { + auto* min_t = context.Input("Min"); + auto* min_data = min_t->data(); + if (platform::is_gpu_place(min_t->place())) { + TensorCopySync(*min_t, platform::CPUPlace(), &min_cpu); + min_data = min_cpu.data(); + } + min = min_data[0]; + } + min = static_cast(min); + PADDLE_ENFORCE_LT(min, max, platform::errors::InvalidArgument( + "max should be greater than min. " + "But received min = %f, max = %f", + min, max)); + auto* x_var = context.InputVar("X"); if (x_var->IsType()) { auto* x = context.Input("X"); @@ -75,8 +103,9 @@ class ClipKernel : public framework::OpKernel { } else if (x_var->IsType()) { auto* x = context.Input("X"); auto* out = context.Output("Out"); - PADDLE_ENFORCE_NE(x, out, - "Inplace clip is not allowed when x is SelectedRows"); + PADDLE_ENFORCE_NE( + x, out, platform::errors::InvalidArgument( + "Inplace clip is not allowed when x is SelectedRows")); math::scatter::MergeAdd merge_func; merge_func(context.template device_context(), *x, out); auto* out_tensor = out->mutable_value(); @@ -95,8 +124,32 @@ template class ClipGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto max = context.Attr("max"); - auto min = context.Attr("min"); + auto max = static_cast(context.Attr("max")); + Tensor max_cpu; + if (context.HasInput("Max")) { + auto* max_t = context.Input("Max"); + auto* max_data = max_t->data(); + if (platform::is_gpu_place(max_t->place())) { + TensorCopySync(*max_t, platform::CPUPlace(), &max_cpu); + max_data = max_cpu.data(); + } + max = max_data[0]; + } + max = static_cast(max); + + auto min = context.Attr("min"); + Tensor min_cpu; + if (context.HasInput("Min")) { + auto* min_t = context.Input("Min"); + auto* min_data = min_t->data(); + if (platform::is_gpu_place(min_t->place())) { + TensorCopySync(*min_t, platform::CPUPlace(), &min_cpu); + min_data = min_cpu.data(); + } + min = min_data[0]; + } + min = static_cast(min); + auto* d_out = context.Input(framework::GradVarName("Out")); auto* d_x = diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 2dd755ecafd890289dd6f6aae76da86a0d69f7d9..68203b866f0112835dd377fdf443a70c37ee935f 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -145,6 +145,7 @@ from .tensor.math import log1p #DEFINE_ALIAS # from .tensor.math import erf #DEFINE_ALIAS from .tensor.math import addcmul #DEFINE_ALIAS from .tensor.math import addmm #DEFINE_ALIAS +from .tensor.math import clamp #DEFINE_ALIAS # from .tensor.attribute import rank #DEFINE_ALIAS # from .tensor.attribute import shape #DEFINE_ALIAS # from .tensor.io import save #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_clamp.py b/python/paddle/fluid/tests/unittests/test_clamp.py new file mode 100644 index 0000000000000000000000000000000000000000..ce18321ca9f5f5c66e13221830b22f0bba74d6fc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_clamp.py @@ -0,0 +1,67 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import paddle.tensor as tensor +import paddle.fluid as fluid +import numpy as np +import unittest + + +class TestClampAPI(unittest.TestCase): + def test_clamp(self): + data_shape = [1, 9, 9, 4] + data = np.random.random(data_shape).astype('float32') + images = fluid.data(name='image', shape=data_shape, dtype='float32') + min = fluid.data(name='min', shape=[1], dtype='float32') + max = fluid.data(name='max', shape=[1], dtype='float32') + + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + + out_1 = tensor.clamp(images, min=min, max=max) + out_2 = tensor.clamp(images, min=0.2, max=0.9) + out_3 = tensor.clamp(images, min=0.3) + out_4 = tensor.clamp(images, max=0.7) + out_5 = tensor.clamp(images, min=min) + out_6 = tensor.clamp(images, max=max) + + res1, res2, res3, res4, res5, res6 = exe.run( + fluid.default_main_program(), + feed={ + "image": data, + "min": np.array([0.2]).astype('float32'), + "max": np.array([0.8]).astype('float32') + }, + fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6]) + + self.assertTrue(np.allclose(res1, data.clip(0.2, 0.8))) + self.assertTrue(np.allclose(res2, data.clip(0.2, 0.9))) + self.assertTrue(np.allclose(res3, data.clip(min=0.3))) + self.assertTrue(np.allclose(res4, data.clip(max=0.7))) + self.assertTrue(np.allclose(res5, data.clip(min=0.2))) + self.assertTrue(np.allclose(res6, data.clip(max=0.8))) + + +class TestClampError(unittest.TestCase): + def test_errors(self): + x1 = fluid.layers.data(name='x1', shape=[1], dtype="int16") + x2 = fluid.layers.data(name='x2', shape=[1], dtype="int8") + self.assertRaises(TypeError, tensor.clamp, x=x1, min=0.2, max=0.8) + self.assertRaises(TypeError, tensor.clamp, x=x2, min=0.2, max=0.8) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_clip_op.py b/python/paddle/fluid/tests/unittests/test_clip_op.py index 9b03a95ea6a4c636187db6f4f6dd3515d087d717..c98842816932e1dbc3e224b437b05584dd1e36df 100644 --- a/python/paddle/fluid/tests/unittests/test_clip_op.py +++ b/python/paddle/fluid/tests/unittests/test_clip_op.py @@ -24,6 +24,7 @@ from op_test import OpTest class TestClipOp(OpTest): def setUp(self): self.max_relative_error = 0.006 + self.inputs = {} self.initTestCase() input = np.random.random(self.shape).astype("float32") input[np.abs(input - self.min) < self.max_relative_error] = 0.5 @@ -33,10 +34,21 @@ class TestClipOp(OpTest): self.attrs = {} self.attrs['min'] = self.min self.attrs['max'] = self.max - self.outputs = { - 'Out': np.clip(self.inputs['X'], self.attrs['min'], - self.attrs['max']) - } + if 'Min' in self.inputs: + min_v = self.inputs['Min'] + else: + min_v = self.attrs['min'] + + if 'Max' in self.inputs: + max_v = self.inputs['Max'] + else: + max_v = self.attrs['max'] + + input = np.random.random(self.shape).astype("float32") + input[np.abs(input - min_v) < self.max_relative_error] = 0.5 + input[np.abs(input - max_v) < self.max_relative_error] = 0.5 + self.inputs['X'] = input + self.outputs = {'Out': np.clip(self.inputs['X'], min_v, max_v)} def test_check_output(self): self.check_output() @@ -46,8 +58,10 @@ class TestClipOp(OpTest): def initTestCase(self): self.shape = (10, 10) - self.max = 0.7 - self.min = 0.1 + self.max = 0.8 + self.min = 0.3 + self.inputs['Max'] = np.array([0.8]).astype('float32') + self.inputs['Min'] = np.array([0.1]).astype('float32') class TestCase1(TestClipOp): @@ -71,6 +85,15 @@ class TestCase3(TestClipOp): self.min = 0.2 +class TestCase4(TestClipOp): + def initTestCase(self): + self.shape = (4, 8, 8) + self.max = 0.7 + self.min = 0.2 + self.inputs['Max'] = np.array([0.8]).astype('float32') + self.inputs['Min'] = np.array([0.3]).astype('float32') + + class TestClipOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 84d5abfdaf68f7c19bd3ef81ddf9fd514f0f2c43..be9c01b32dbe01e0bf66e6c848c29fa0007f63ff 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -123,6 +123,7 @@ from .math import log1p #DEFINE_ALIAS # from .math import erf #DEFINE_ALIAS from .math import addcmul #DEFINE_ALIAS from .math import addmm #DEFINE_ALIAS +from .math import clamp #DEFINE_ALIAS # from .attribute import rank #DEFINE_ALIAS # from .attribute import shape #DEFINE_ALIAS # from .io import save #DEFINE_ALIAS diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index ac9a56a90fef1a1facda4980ce6d8230315da4f8..20be0e0c919dc06631232ffdf6077eb8cf4b0079 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -21,6 +21,7 @@ from paddle.common_ops_import import * from ..fluid import layers from ..fluid.framework import core, _varbase_creator from ..fluid.layers.layer_function_generator import _generate_doc_string_ +import sys # TODO: define math functions # yapf: disable @@ -76,7 +77,8 @@ __all__ = [ 'log1p', # 'erf', 'addcmul', - 'addmm' + 'addmm', + 'clamp', ] # yapf: enable. @@ -1302,6 +1304,7 @@ def addcmul(input, tensor1, tensor2, value=1.0, out=None, name=None): Examples: .. code-block:: python + import paddle import paddle.fluid as fluid input = fluid.data(name='input', dtype='float32', shape=[3, 4]) @@ -1323,3 +1326,89 @@ def addcmul(input, tensor1, tensor2, value=1.0, out=None, name=None): else: out = layers.elementwise_add(input, layers.elementwise_mul(tensor1, tensor2) * value) return out + + +def clamp(input, min=None, max=None, output=None, name=None): + """ + **clampe layer** + + This operator clamps all elements in input into the range [ min, max ] and return + a resulting tensor as the following equation: + + .. math:: + + Out = MIN(MAX(x, min), max) + + Args: + input (Variable): An input N-D Tensor or LoDTensor + with data type float32, float64. + min (float32|Variable): The lower bound with type ``float32`` or a ``Tensor`` + with shape [1] and type ``int32``, ``float32``, ``float64``. + max (float32|Variable): The upper bound with type ``float32`` or a ``Tensor`` + with shape [1] and type ``int32``, ``float32``, ``float64``. + output (Variable, optional): A tensor or LoDTensor. If :attr:`output` is None, + a new tensor will be created as :attr:`output`. Default: None. + name (str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + + Returns: + Variable: A Tensor or LodTensor with the same data type and data shape as input's. + + Examples: + .. code-block:: python + + import paddle + import paddle.fluid as fluid + import numpy as np + + in1 = np.array([[1.2,3.5], + [4.5,6.4]]).astype('float32') + with fluid.dygraph.guard(): + x1 = fluid.dygraph.to_variable(in1) + out1 = paddle.tensor.clamp(x1, min=3.5, max=5.0) + out2 = paddle.tensor.clamp(x1, min=2.5) + print(out1.numpy()) + # [[3.5, 3.5] + # [4.5, 5.0]] + print(out2.numpy()) + # [[2.5, 3.5] + # [[4.5, 6.4] + """ + + assert min is not None or max is not None, "either min or max should be defined." + + if min is not None: + check_type(min, 'min', (float, Variable), 'clamp') + if isinstance(min, Variable): + check_dtype(min.dtype, 'min', ['float32', 'float64', 'int32'], + 'clamp', '(When the type of min in clamp is Variable.)') + if max is not None: + check_type(max, 'max', (float, Variable), 'clamp') + if isinstance(max, Variable): + check_dtype(max.dtype, 'max', ['float32', 'float64', 'int32'], + 'clamp', '(When the type of max in clamp is Variable.)') + + inputs = {'X': input} + attrs = {'min': sys.float_info.min, 'max': sys.float_info.max} + + if isinstance(min, Variable): + min.stop_gradient = True + inputs['Min'] = min + elif min is not None: + attrs['min'] = min + + if isinstance(max, Variable): + max.stop_gradient = True + inputs['Max'] = max + elif max is not None: + attrs['max'] = max + + helper = LayerHelper('clamp', **locals()) + if output is None: + output = helper.create_variable_for_type_inference( + dtype=helper.input_dtype()) + helper.append_op( + type='clip', inputs=inputs, outputs={'Out': [output]}, attrs=attrs) + + return output