未验证 提交 931cba2e 编写于 作者: W wangguanzhong 提交者: GitHub

add clamp api, test=develop (#23273)

* add clamp api, test=develop
上级 a28a63a9
......@@ -26,12 +26,6 @@ class ClipOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "clip");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "clip");
auto x_dims = ctx->GetInputDim("X");
auto max = ctx->Attrs().Get<float>("max");
auto min = ctx->Attrs().Get<float>("min");
PADDLE_ENFORCE_LT(min, max, platform::errors::InvalidArgument(
"Max of ClipOp should be greater than min. "
"Received max is %f, received min is %f.",
max, min));
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
......@@ -44,6 +38,14 @@ class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X",
"Tensor, the input of clip op, data type should be float32 or "
"float64.");
AddInput("Min",
"Tensor, the lower bound, data type should be float32 "
"or float64.")
.AsDispensable();
AddInput("Max",
"Tensor, the upper bound, data type should be float32 "
"or float64.")
.AsDispensable();
AddOutput(
"Out",
"Tensor, the clipped tensor, with the same shape and data type as "
......@@ -88,6 +90,12 @@ class ClipGradOpMaker : public framework::SingleGradOpMaker<T> {
void Apply(GradOpPtr<T> op) const override {
op->SetType("clip_grad");
op->SetInput("X", this->Input("X"));
if (this->HasInput("Min")) {
op->SetInput("Min", this->Input("Min"));
}
if (this->HasInput("Max")) {
op->SetInput("Max", this->Input("Max"));
}
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
......
......@@ -18,6 +18,7 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
clip, ops::ClipKernel<paddle::platform::CUDADeviceContext, float>,
ops::ClipKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
clip_grad, ops::ClipGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ClipGradKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -60,8 +60,36 @@ template <typename DeviceContext, typename T>
class ClipKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max");
auto min = context.Attr<T>("min");
auto max = static_cast<T>(context.Attr<float>("max"));
Tensor max_cpu;
if (context.HasInput("Max")) {
auto* max_t = context.Input<Tensor>("Max");
auto* max_data = max_t->data<T>();
if (platform::is_gpu_place(max_t->place())) {
TensorCopySync(*max_t, platform::CPUPlace(), &max_cpu);
max_data = max_cpu.data<T>();
}
max = max_data[0];
}
max = static_cast<T>(max);
auto min = context.Attr<float>("min");
Tensor min_cpu;
if (context.HasInput("Min")) {
auto* min_t = context.Input<Tensor>("Min");
auto* min_data = min_t->data<T>();
if (platform::is_gpu_place(min_t->place())) {
TensorCopySync(*min_t, platform::CPUPlace(), &min_cpu);
min_data = min_cpu.data<T>();
}
min = min_data[0];
}
min = static_cast<T>(min);
PADDLE_ENFORCE_LT(min, max, platform::errors::InvalidArgument(
"max should be greater than min. "
"But received min = %f, max = %f",
min, max));
auto* x_var = context.InputVar("X");
if (x_var->IsType<framework::LoDTensor>()) {
auto* x = context.Input<framework::LoDTensor>("X");
......@@ -75,8 +103,9 @@ class ClipKernel : public framework::OpKernel<T> {
} else if (x_var->IsType<framework::SelectedRows>()) {
auto* x = context.Input<framework::SelectedRows>("X");
auto* out = context.Output<framework::SelectedRows>("Out");
PADDLE_ENFORCE_NE(x, out,
"Inplace clip is not allowed when x is SelectedRows");
PADDLE_ENFORCE_NE(x, out, platform::errors::InvalidArgument(
"Inplace clip is not allowed "
"when x is SelectedRows"));
math::scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(context.template device_context<DeviceContext>(), *x, out);
auto* out_tensor = out->mutable_value();
......@@ -95,8 +124,32 @@ template <typename DeviceContext, typename T>
class ClipGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max");
auto min = context.Attr<T>("min");
auto max = static_cast<T>(context.Attr<float>("max"));
Tensor max_cpu;
if (context.HasInput("Max")) {
auto* max_t = context.Input<Tensor>("Max");
auto* max_data = max_t->data<T>();
if (platform::is_gpu_place(max_t->place())) {
TensorCopySync(*max_t, platform::CPUPlace(), &max_cpu);
max_data = max_cpu.data<T>();
}
max = max_data[0];
}
max = static_cast<T>(max);
auto min = context.Attr<float>("min");
Tensor min_cpu;
if (context.HasInput("Min")) {
auto* min_t = context.Input<Tensor>("Min");
auto* min_data = min_t->data<T>();
if (platform::is_gpu_place(min_t->place())) {
TensorCopySync(*min_t, platform::CPUPlace(), &min_cpu);
min_data = min_cpu.data<T>();
}
min = min_data[0];
}
min = static_cast<T>(min);
auto* d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto* d_x =
......
......@@ -145,6 +145,7 @@ from .tensor.math import log1p #DEFINE_ALIAS
# from .tensor.math import erf #DEFINE_ALIAS
from .tensor.math import addcmul #DEFINE_ALIAS
from .tensor.math import addmm #DEFINE_ALIAS
from .tensor.math import clamp #DEFINE_ALIAS
# from .tensor.attribute import rank #DEFINE_ALIAS
# from .tensor.attribute import shape #DEFINE_ALIAS
# from .tensor.io import save #DEFINE_ALIAS
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.tensor as tensor
import paddle.fluid as fluid
import numpy as np
import unittest
class TestClampAPI(unittest.TestCase):
def test_clamp(self):
data_shape = [1, 9, 9, 4]
data = np.random.random(data_shape).astype('float32')
images = fluid.data(name='image', shape=data_shape, dtype='float32')
min = fluid.data(name='min', shape=[1], dtype='float32')
max = fluid.data(name='max', shape=[1], dtype='float32')
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
out_1 = tensor.clamp(images, min=min, max=max)
out_2 = tensor.clamp(images, min=0.2, max=0.9)
out_3 = tensor.clamp(images, min=0.3)
out_4 = tensor.clamp(images, max=0.7)
out_5 = tensor.clamp(images, min=min)
out_6 = tensor.clamp(images, max=max)
res1, res2, res3, res4, res5, res6 = exe.run(
fluid.default_main_program(),
feed={
"image": data,
"min": np.array([0.2]).astype('float32'),
"max": np.array([0.8]).astype('float32')
},
fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6])
self.assertTrue(np.allclose(res1, data.clip(0.2, 0.8)))
self.assertTrue(np.allclose(res2, data.clip(0.2, 0.9)))
self.assertTrue(np.allclose(res3, data.clip(min=0.3)))
self.assertTrue(np.allclose(res4, data.clip(max=0.7)))
self.assertTrue(np.allclose(res5, data.clip(min=0.2)))
self.assertTrue(np.allclose(res6, data.clip(max=0.8)))
class TestClampError(unittest.TestCase):
def test_errors(self):
x1 = fluid.layers.data(name='x1', shape=[1], dtype="int16")
x2 = fluid.layers.data(name='x2', shape=[1], dtype="int8")
self.assertRaises(TypeError, tensor.clamp, x=x1, min=0.2, max=0.8)
self.assertRaises(TypeError, tensor.clamp, x=x2, min=0.2, max=0.8)
if __name__ == '__main__':
unittest.main()
......@@ -24,19 +24,29 @@ from op_test import OpTest
class TestClipOp(OpTest):
def setUp(self):
self.max_relative_error = 0.006
self.inputs = {}
self.initTestCase()
input = np.random.random(self.shape).astype("float32")
input[np.abs(input - self.min) < self.max_relative_error] = 0.5
input[np.abs(input - self.max) < self.max_relative_error] = 0.5
self.op_type = "clip"
self.inputs = {'X': input, }
self.attrs = {}
self.attrs['min'] = self.min
self.attrs['max'] = self.max
self.outputs = {
'Out': np.clip(self.inputs['X'], self.attrs['min'],
self.attrs['max'])
}
if 'Min' in self.inputs:
min_v = self.inputs['Min']
else:
min_v = self.attrs['min']
if 'Max' in self.inputs:
max_v = self.inputs['Max']
else:
max_v = self.attrs['max']
input = np.random.random(self.shape).astype("float32")
input[np.abs(input - min_v) < self.max_relative_error] = 0.5
input[np.abs(input - max_v) < self.max_relative_error] = 0.5
self.inputs['X'] = input
self.outputs = {'Out': np.clip(self.inputs['X'], min_v, max_v)}
def test_check_output(self):
self.check_output()
......@@ -45,9 +55,11 @@ class TestClipOp(OpTest):
self.check_grad(['X'], 'Out')
def initTestCase(self):
self.shape = (10, 10)
self.max = 0.7
self.min = 0.1
self.shape = (4, 10, 10)
self.max = 0.8
self.min = 0.3
self.inputs['Max'] = np.array([0.8]).astype('float32')
self.inputs['Min'] = np.array([0.1]).astype('float32')
class TestCase1(TestClipOp):
......@@ -71,6 +83,15 @@ class TestCase3(TestClipOp):
self.min = 0.2
class TestCase4(TestClipOp):
def initTestCase(self):
self.shape = (4, 8, 8)
self.max = 0.7
self.min = 0.2
self.inputs['Max'] = np.array([0.8]).astype('float32')
self.inputs['Min'] = np.array([0.3]).astype('float32')
class TestClipOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
......
......@@ -124,6 +124,7 @@ from .math import log1p #DEFINE_ALIAS
# from .math import erf #DEFINE_ALIAS
from .math import addcmul #DEFINE_ALIAS
from .math import addmm #DEFINE_ALIAS
from .math import clamp #DEFINE_ALIAS
# from .attribute import rank #DEFINE_ALIAS
# from .attribute import shape #DEFINE_ALIAS
# from .io import save #DEFINE_ALIAS
......
......@@ -21,6 +21,7 @@ from paddle.common_ops_import import *
from ..fluid import layers
from ..fluid.framework import core
from ..fluid.layers.layer_function_generator import _generate_doc_string_
import sys
# TODO: define math functions
# yapf: disable
......@@ -76,7 +77,8 @@ __all__ = [
'log1p',
# 'erf',
'addcmul',
'addmm'
'addmm',
'clamp',
]
# yapf: enable.
......@@ -947,7 +949,6 @@ def mm(input, mat2, out=None, name=None):
'Y': mat2}, outputs={'Out': out})
return out
def addmm(input, x, y, alpha=1.0, beta=1.0, name=None):
"""
**addmm**
......@@ -1274,17 +1275,14 @@ def log1p(x, out=None, name=None):
helper.append_op(type="log1p", inputs={"X": x}, outputs={"Out": out})
return out
def addcmul(input, tensor1, tensor2, value=1.0, out=None, name=None):
"""
Calculate the element-wise multiplication of tensor1 and tensor2,
then multiply the result by value, and add it to input. The shape of input,
tensor1, tensor2 should be broadcastable.
The equation is:
.. math::
out = input + value * tensor1 * tensor2
Args:
input(Variable): The input to be added. A Tensor with type float32, float64, int32, int64.
tensor1(Variable): The tensor to be multiplied. A Tensor with type float32, float64, int32, int64.
......@@ -1296,13 +1294,10 @@ def addcmul(input, tensor1, tensor2, value=1.0, out=None, name=None):
created to save the output. Default: None.
name(str, Optional): For details, please refer to :ref:`api_guide_Name`.
Generally, no setting is required. Default: None.
Returns:
out(Variable): The output result. A Tensor with the same data type as input's.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
input = fluid.data(name='input', dtype='float32', shape=[3, 4])
......@@ -1324,3 +1319,89 @@ def addcmul(input, tensor1, tensor2, value=1.0, out=None, name=None):
else:
out = layers.elementwise_add(input, layers.elementwise_mul(tensor1, tensor2) * value)
return out
def clamp(input, min=None, max=None, output=None, name=None):
"""
**clampe layer**
This operator clamps all elements in input into the range [ min, max ] and return
a resulting tensor as the following equation:
.. math::
Out = MIN(MAX(x, min), max)
Args:
input (Variable): An input N-D Tensor or LoDTensor
with data type float32, float64.
min (float32|Variable): The lower bound with type ``float32`` or a ``Tensor``
with shape [1] and type ``int32``, ``float32``, ``float64``.
max (float32|Variable): The upper bound with type ``float32`` or a ``Tensor``
with shape [1] and type ``int32``, ``float32``, ``float64``.
output (Variable, optional): A tensor or LoDTensor. If :attr:`output` is None,
a new tensor will be created as :attr:`output`. Default: None.
name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Variable: A Tensor or LodTensor with the same data type and data shape as input's.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
in1 = np.array([[1.2,3.5],
[4.5,6.4]]).astype('float32')
with fluid.dygraph.guard():
x1 = fluid.dygraph.to_variable(in1)
out1 = paddle.tensor.clamp(x1, min=3.5, max=5.0)
out2 = paddle.tensor.clamp(x1, min=2.5)
print(out1.numpy())
# [[3.5, 3.5]
# [4.5, 5.0]]
print(out2.numpy())
# [[2.5, 3.5]
# [[4.5, 6.4]
"""
assert min is not None or max is not None, "either min or max should be defined."
if min is not None:
check_type(min, 'min', (float, Variable), 'clamp')
if isinstance(min, Variable):
check_dtype(min.dtype, 'min', ['float32', 'float64', 'int32'],
'clamp', '(When the type of min in clamp is Variable.)')
if max is not None:
check_type(max, 'max', (float, Variable), 'clamp')
if isinstance(max, Variable):
check_dtype(max.dtype, 'max', ['float32', 'float64', 'int32'],
'clamp', '(When the type of max in clamp is Variable.)')
inputs = {'X': input}
attrs = {'min': sys.float_info.min, 'max': sys.float_info.max}
if isinstance(min, Variable):
min.stop_gradient = True
inputs['Min'] = min
elif min is not None:
attrs['min'] = min
if isinstance(max, Variable):
max.stop_gradient = True
inputs['Max'] = max
elif max is not None:
attrs['max'] = max
helper = LayerHelper('clamp', **locals())
if output is None:
output = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
helper.append_op(
type='clip', inputs=inputs, outputs={'Out': [output]}, attrs=attrs)
return output
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册