Unverified · Commit 6ea8bfc6 authored by heyanru, committed by GitHub

[Fluid Clean] Remove paddle.fluid.layers.nn.reduce_max/min (#48236)

Parent baa1f663
@@ -530,9 +530,9 @@ class Categorical(Distribution):
         """
         check_type(other, 'other', Categorical, 'kl_divergence')
-        logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
-        other_logits = other.logits - nn.reduce_max(
-            other.logits, dim=-1, keep_dim=True
+        logits = self.logits - paddle.max(self.logits, axis=-1, keepdim=True)
+        other_logits = other.logits - paddle.max(
+            other.logits, axis=-1, keepdim=True
         )
         e_logits = paddle.exp(logits)
         other_e_logits = paddle.exp(other_logits)
@@ -554,7 +554,7 @@ class Categorical(Distribution):
             Variable: Shannon entropy of Categorical distribution. The data type is float32.
         """
-        logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
+        logits = self.logits - paddle.max(self.logits, axis=-1, keepdim=True)
         e_logits = paddle.exp(logits)
         z = nn.reduce_sum(e_logits, dim=-1, keep_dim=True)
         prob = e_logits / z
......
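Migration note: both call sites above use the max-shift trick for numerical stability, and `paddle.max` is a drop-in replacement once `dim`/`keep_dim` are renamed to `axis`/`keepdim`. A minimal sketch with illustrative values:

```python
import paddle

# Illustrative logits; any float Tensor works.
logits = paddle.to_tensor([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]])

# Old (removed): nn.reduce_max(logits, dim=-1, keep_dim=True)
# New:           paddle.max(logits, axis=-1, keepdim=True)
shifted = logits - paddle.max(logits, axis=-1, keepdim=True)

# The shift cancels in the normalized probabilities, so the softmax
# is unchanged, but paddle.exp can no longer overflow on large logits.
e = paddle.exp(shifted)
prob = e / paddle.sum(e, axis=-1, keepdim=True)
print(prob)  # each row sums to 1
```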
@@ -79,8 +79,6 @@ __all__ = [
     'data_norm',
     'reduce_sum',
     'reduce_mean',
-    'reduce_max',
-    'reduce_min',
     'reduce_all',
     'reduce_any',
     'dropout',
@@ -190,7 +188,7 @@ OP_NAMEMAPPING = {
 def _get_reduce_dim(dim, input):
     """
-    Internal function for reduce_sum, reduce_mean, reduce_max, reduce_min, reduce_prod.
+    Internal function for reduce_sum, reduce_mean, reduce_prod.
     It computes the attribute reduce_all value based on axis.
     """
     if dim is not None and not isinstance(dim, list):
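For context, the rule this helper applies (the same attribute logic that appears verbatim in the deleted `reduce_max`/`reduce_min` bodies below) can be sketched standalone; `_reduce_all_sketch` is a hypothetical name used here for illustration only:

```python
def _reduce_all_sketch(dim, input_rank):
    # Normalize a scalar dim to a list, as _get_reduce_dim does.
    if dim is not None and not isinstance(dim, list):
        dim = [dim]
    # reduce_all is True when no axis is given, the axis list is
    # empty, or every axis of the input is being reduced.
    return dim is None or dim == [] or len(dim) == input_rank

assert _reduce_all_sketch(None, 2)       # no axis: reduce everything
assert _reduce_all_sketch([0, 1], 2)     # all axes named: reduce everything
assert not _reduce_all_sketch(1, 2)      # scalar axis: partial reduction
```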
@@ -3938,150 +3936,6 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None):
     return paddle.mean(x=input, axis=dim, keepdim=keep_dim, name=name)
-
-
-def reduce_max(input, dim=None, keep_dim=False, name=None):
-    """
-    Computes the maximum of tensor elements over the given dimension.
-
-    Args:
-        input (Variable): The input variable which is a Tensor, the data type is float32,
-            float64, int32, int64.
-        dim (list|int, optional): The dimension along which the maximum is computed.
-            If :attr:`None`, compute the maximum over all elements of
-            :attr:`input` and return a Tensor variable with a single element,
-            otherwise must be in the range :math:`[-rank(input), rank(input))`.
-            If :math:`dim[i] < 0`, the dimension to reduce is :math:`rank + dim[i]`.
-        keep_dim (bool, optional): Whether to reserve the reduced dimension in the
-            output Tensor. The result tensor will have one fewer dimension
-            than the :attr:`input` unless :attr:`keep_dim` is true, default
-            value is False.
-        name (str, optional): The default value is None. Normally there is no need for
-            the user to set this property. For more information, please refer to :ref:`api_guide_Name`.
-
-    Returns:
-        Variable: Tensor, result of the maximum on the specified dim of the input tensor;
-        its data type is the same as the input's.
-
-    Examples:
-        .. code-block:: python
-
-            import paddle.fluid as fluid
-            import paddle
-            paddle.enable_static()
-
-            # x is a Tensor variable with the following elements:
-            #    [[0.2, 0.3, 0.5, 0.9]
-            #     [0.1, 0.2, 0.6, 0.7]]
-            # Each example is followed by the corresponding output tensor.
-            x = fluid.data(name='x', shape=[2, 4], dtype='float32')
-            fluid.layers.reduce_max(x)  # [0.9]
-            fluid.layers.reduce_max(x, dim=0)  # [0.2, 0.3, 0.6, 0.9]
-            fluid.layers.reduce_max(x, dim=-1)  # [0.9, 0.7]
-            fluid.layers.reduce_max(x, dim=1, keep_dim=True)  # [[0.9], [0.7]]
-
-            # y is a Tensor variable with shape [2, 2, 2] and elements as below:
-            #    [[[1.0, 2.0], [3.0, 4.0]],
-            #     [[5.0, 6.0], [7.0, 8.0]]]
-            # Each example is followed by the corresponding output tensor.
-            y = fluid.data(name='y', shape=[2, 2, 2], dtype='float32')
-            fluid.layers.reduce_max(y, dim=[1, 2])  # [4.0, 8.0]
-            fluid.layers.reduce_max(y, dim=[0, 1])  # [7.0, 8.0]
-    """
-    helper = LayerHelper('reduce_max', **locals())
-    out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
-    if dim is not None and not isinstance(dim, list):
-        dim = [dim]
-
-    if in_dygraph_mode():
-        return _C_ops.max(input, dim if dim is not None else [], keep_dim)
-
-    helper.append_op(
-        type='reduce_max',
-        inputs={'X': input},
-        outputs={'Out': out},
-        attrs={
-            'dim': dim if dim is not None and dim != [] else [0],
-            'keep_dim': keep_dim,
-            'reduce_all': True
-            if dim is None or dim == [] or len(dim) == len(input.shape)
-            else False,
-        },
-    )
-    return out
-
-
-def reduce_min(input, dim=None, keep_dim=False, name=None):
-    """
-    Computes the minimum of tensor elements over the given dimension.
-
-    Args:
-        input (Variable): The input variable which is a Tensor, the data type is float32,
-            float64, int32, int64.
-        dim (list|int, optional): The dimensions along which the minimum is computed.
-            If :attr:`None`, compute the minimum over all elements of
-            :attr:`input` and return a Tensor variable with a single element,
-            otherwise must be in the range :math:`[-rank(input), rank(input))`.
-            If :math:`dim[i] < 0`, the dimension to reduce is :math:`rank + dim[i]`.
-        keep_dim (bool, optional): Whether to reserve the reduced dimension in the
-            output Tensor. The result tensor will have one fewer dimension
-            than the :attr:`input` unless :attr:`keep_dim` is true, default
-            value is False.
-        name (str, optional): The default value is None. Normally there is no need for
-            the user to set this property. For more information, please refer to :ref:`api_guide_Name`.
-
-    Returns:
-        Variable: Tensor, result of the minimum on the specified dim of the input tensor;
-        its data type is the same as the input's.
-
-    Examples:
-        .. code-block:: python
-
-            import paddle.fluid as fluid
-            import paddle
-            paddle.enable_static()
-
-            # x is a Tensor variable with the following elements:
-            #    [[0.2, 0.3, 0.5, 0.9]
-            #     [0.1, 0.2, 0.6, 0.7]]
-            # Each example is followed by the corresponding output tensor.
-            x = fluid.data(name='x', shape=[2, 4], dtype='float32')
-            fluid.layers.reduce_min(x)  # [0.1]
-            fluid.layers.reduce_min(x, dim=0)  # [0.1, 0.2, 0.5, 0.7]
-            fluid.layers.reduce_min(x, dim=-1)  # [0.2, 0.1]
-            fluid.layers.reduce_min(x, dim=1, keep_dim=True)  # [[0.2], [0.1]]
-
-            # y is a Tensor variable with shape [2, 2, 2] and elements as below:
-            #    [[[1.0, 2.0], [3.0, 4.0]],
-            #     [[5.0, 6.0], [7.0, 8.0]]]
-            # Each example is followed by the corresponding output tensor.
-            y = fluid.data(name='y', shape=[2, 2, 2], dtype='float32')
-            fluid.layers.reduce_min(y, dim=[1, 2])  # [1.0, 5.0]
-            fluid.layers.reduce_min(y, dim=[0, 1])  # [1.0, 2.0]
-    """
-    helper = LayerHelper('reduce_min', **locals())
-    out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
-    if dim is not None and not isinstance(dim, list):
-        dim = [dim]
-
-    if in_dygraph_mode():
-        return _C_ops.min(input, dim if dim is not None else [], keep_dim)
-
-    helper.append_op(
-        type='reduce_min',
-        inputs={'X': input},
-        outputs={'Out': out},
-        attrs={
-            'dim': dim if dim is not None and dim != [] else [0],
-            'keep_dim': keep_dim,
-            'reduce_all': True
-            if dim is None or dim == [] or len(dim) == len(input.shape)
-            else False,
-        },
-    )
-    return out


 def reduce_all(input, dim=None, keep_dim=False, name=None):
     """
......
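For downstream code still calling the deleted APIs, the mapping is mechanical: `fluid.layers.reduce_max(input, dim=..., keep_dim=...)` becomes `paddle.max(x, axis=..., keepdim=...)`, and likewise `reduce_min` becomes `paddle.min`. A sketch reproducing the removed docstring examples in dynamic graph mode (no `paddle.enable_static()` needed):

```python
import paddle

x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9],
                      [0.1, 0.2, 0.6, 0.7]])

paddle.max(x)                        # [0.9]
paddle.max(x, axis=0)                # [0.2, 0.3, 0.6, 0.9]
paddle.max(x, axis=-1)               # [0.9, 0.7]
paddle.max(x, axis=1, keepdim=True)  # [[0.9], [0.7]]

paddle.min(x)                        # [0.1]
paddle.min(x, axis=0)                # [0.1, 0.2, 0.5, 0.7]

y = paddle.to_tensor([[[1.0, 2.0], [3.0, 4.0]],
                      [[5.0, 6.0], [7.0, 8.0]]])
paddle.max(y, axis=[1, 2])           # [4.0, 8.0]
paddle.min(y, axis=[0, 1])           # [1.0, 2.0]
```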
@@ -52,7 +52,7 @@ class SimpleConvPool(fluid.dygraph.Layer):
     def forward(self, inputs):
         x = paddle.tanh(self._conv2d(inputs))
-        x = fluid.layers.reduce_max(x, dim=-1)
+        x = paddle.max(x, axis=-1)
         x = paddle.reshape(x, shape=[self.batch_size, -1])
         return x
@@ -194,7 +194,7 @@ class GRU(fluid.dygraph.Layer):
         emb = paddle.reshape(emb, shape=[self.batch_size, -1, self.hid_dim])
         fc_1 = self._fc1(emb)
         gru_hidden = self._gru(fc_1)
-        gru_hidden = fluid.layers.reduce_max(gru_hidden, dim=1)
+        gru_hidden = paddle.max(gru_hidden, axis=1)
         tanh_1 = paddle.tanh(gru_hidden)
         fc_2 = self._fc2(tanh_1)
         prediction = self._fc_prediction(fc_2)
@@ -254,7 +254,7 @@ class BiGRU(fluid.dygraph.Layer):
         encoded_vector = fluid.layers.concat(
             input=[gru_forward_tanh, gru_backward_tanh], axis=2
         )
-        encoded_vector = fluid.layers.reduce_max(encoded_vector, dim=1)
+        encoded_vector = paddle.max(encoded_vector, axis=1)
         fc_2 = self._fc2(encoded_vector)
         prediction = self._fc_prediction(fc_2)
         # TODO(Aurelius84): Uncomment the following codes when we support return variable-length vars.
......
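The three dygraph models above all perform max-over-time pooling, which the new API expresses as a single `paddle.max` along the sequence axis. A minimal sketch; the shapes are made up for illustration:

```python
import paddle

# Hypothetical encoder output: [batch, seq_len, hidden].
hidden = paddle.randn([4, 16, 128])

# Old: fluid.layers.reduce_max(hidden, dim=1)
# New: collapse the time axis, keeping the strongest activation
# per hidden unit.
pooled = paddle.max(hidden, axis=1)
print(pooled.shape)  # [4, 128]
```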
@@ -208,7 +208,7 @@ class TSM_ResNet(fluid.dygraph.Layer):
         y = self.pool2d_avg(y)
         y = fluid.layers.dropout(y, dropout_prob=0.5)
         y = paddle.reshape(y, [-1, self.seg_num, y.shape[1]])
-        y = fluid.layers.reduce_mean(y, dim=1)
+        y = paddle.mean(y, axis=1)
         y = paddle.reshape(y, shape=[-1, 2048])
         y = self.out(y)
         return y
......
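Same one-for-one swap for the mean reduction: TSM averages per-segment features into a clip-level feature, and `paddle.mean(y, axis=1)` replaces `fluid.layers.reduce_mean(y, dim=1)`. A sketch with illustrative shapes:

```python
import paddle

# Hypothetical per-segment features: [batch, seg_num, feat_dim].
y = paddle.randn([2, 8, 2048])

# Old: fluid.layers.reduce_mean(y, dim=1)
# New: average the segments (temporal consensus).
clip_feat = paddle.mean(y, axis=1)
print(clip_feat.shape)  # [2, 2048]
```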
@@ -123,12 +123,12 @@ class TestMean(IPUOpTest):
 class TestMax(TestMean):
     def set_test_op(self):
-        self.op = paddle.fluid.layers.reduce_max
+        self.op = paddle.max


 class TestMin(TestMean):
     def set_test_op(self):
-        self.op = paddle.fluid.layers.reduce_min
+        self.op = paddle.min


 class TestSum(TestMean):
......
@@ -70,7 +70,7 @@ class TestAssertOp(unittest.TestCase):
     def test_assert_summary(self):
         def net_func():
             x = layers.fill_constant(shape=[10], dtype='float32', value=2.0)
-            condition = layers.reduce_max(x) < 1.0
+            condition = paddle.max(x) < 1.0
             layers.Assert(condition, (x,), 5)

         print("test_assert_summary")
@@ -80,7 +80,7 @@ class TestAssertOp(unittest.TestCase):
     def test_assert_summary_greater_than_size(self):
         def net_func():
             x = layers.fill_constant(shape=[2, 3], dtype='float32', value=2.0)
-            condition = layers.reduce_max(x) < 1.0
+            condition = paddle.max(x) < 1.0
             layers.Assert(condition, [x], 10, name="test")

         print("test_assert_summary_greater_than_size")
......
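In these tests the full reduction `paddle.max(x)` returns a single-element tensor, so comparing it with a Python float yields the boolean condition `layers.Assert` expects. A sketch of just the condition construction in static mode, with illustrative values:

```python
import paddle

paddle.enable_static()

# Mirrors the test setup: a tensor whose max is known to be 2.0.
x = paddle.full(shape=[10], fill_value=2.0, dtype='float32')

# paddle.max over all elements replaces layers.reduce_max(x);
# the comparison gives a bool tensor usable as an Assert condition
# (here it evaluates to False, so the Assert in the test fires).
condition = paddle.max(x) < 1.0
```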