From a647641883ea25e0ba282e27e6f892a3b6fb0ac7 Mon Sep 17 00:00:00 2001
From: Weilong Wu
Date: Tue, 6 Sep 2022 13:36:05 +0800
Subject: [PATCH] [Eager, Performance optimization] Reduce min/max kernel
 polish (#45755)

* [Eager, Performance optimization] reduce_max / min polish

* polish reduce_max / min

* update min/max kernel reduce_all logic

* fix a mistake

* fix ci errors

* fix errors
---
 paddle/phi/kernels/reduce_max_kernel.cc |   2 +-
 paddle/phi/kernels/reduce_min_kernel.cc |   2 +-
 python/paddle/fluid/layers/nn.py        |   9 ++
 python/paddle/tensor/linalg.py          | 124 +++++++++++++-----------
 4 files changed, 78 insertions(+), 59 deletions(-)
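Note: the kernel-side change makes a reduction whose dims list names every
axis take the same reduce_all fast path as an empty dims list, and the
Python-side changes let eager (dygraph) calls return directly from the
final-state _C_ops API instead of building ops through LayerHelper. A
minimal sketch of the newly covered case (tensor values are illustrative,
not taken from the patch):

    import paddle

    x = paddle.to_tensor([[2., 3., 5., 9.],
                          [1., 2., 6., 7.]])

    # An empty axis list already meant "reduce over everything"; with this
    # patch an explicit list covering every axis hits reduce_all as well.
    full = paddle.max(x)                 # dims == [] -> reduce_all
    listed = paddle.max(x, axis=[0, 1])  # dims covers every axis -> reduce_all
    assert float(full) == float(listed) == 9.0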
diff --git a/paddle/phi/kernels/reduce_max_kernel.cc b/paddle/phi/kernels/reduce_max_kernel.cc
index 72dd515fc43..cd38f7cbcea 100644
--- a/paddle/phi/kernels/reduce_max_kernel.cc
+++ b/paddle/phi/kernels/reduce_max_kernel.cc
@@ -26,7 +26,7 @@ void MaxKernel(const Context& dev_ctx,
                bool keep_dim,
                DenseTensor* out) {
   bool reduce_all = false;
-  if (dims.size() == 0) {
+  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
     reduce_all = true;
   }
   MaxRawKernel<T, Context>(dev_ctx, x, dims, keep_dim, reduce_all, out);
diff --git a/paddle/phi/kernels/reduce_min_kernel.cc b/paddle/phi/kernels/reduce_min_kernel.cc
index 11f11b772ef..4d3041adf46 100644
--- a/paddle/phi/kernels/reduce_min_kernel.cc
+++ b/paddle/phi/kernels/reduce_min_kernel.cc
@@ -26,7 +26,7 @@ void MinKernel(const Context& dev_ctx,
                bool keep_dim,
                DenseTensor* out) {
   bool reduce_all = false;
-  if (dims.size() == 0) {
+  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
     reduce_all = true;
   }
   MinRawKernel<T, Context>(dev_ctx, x, dims, keep_dim, reduce_all, out);
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 0dd33a70810..53994eed80f 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -4808,8 +4808,13 @@ def reduce_max(input, dim=None, keep_dim=False, name=None):
     """
     helper = LayerHelper('reduce_max', **locals())
     out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
+
     if dim is not None and not isinstance(dim, list):
         dim = [dim]
+
+    if in_dygraph_mode():
+        return _C_ops.max(input, dim if dim != None else [], keep_dim)
+
     helper.append_op(type='reduce_max',
                      inputs={'X': input},
                      outputs={'Out': out},
@@ -4878,6 +4883,10 @@ def reduce_min(input, dim=None, keep_dim=False, name=None):
     out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
     if dim is not None and not isinstance(dim, list):
         dim = [dim]
+
+    if in_dygraph_mode():
+        return _C_ops.min(input, dim if dim != None else [], keep_dim)
+
     helper.append_op(type='reduce_min',
                      inputs={'X': input},
                      outputs={'Out': out},
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 7b7adc1eeca..700c6c340dc 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -766,15 +766,20 @@ def cond(x, p=None, name=None):
         axis = axis if axis != None and axis != [] else [0]
         keepdim = False
-        if _non_static_mode():
-            if in_dygraph_mode():
-                abs_out = _C_ops.abs(input)
-                sum_out = _C_ops.sum(abs_out, axis, None, keepdim)
-            else:
-                abs_out = _legacy_C_ops.abs(input)
-                sum_out = _legacy_C_ops.reduce_sum(abs_out, 'dim', axis,
-                                                   'keepdim', keepdim,
-                                                   'reduce_all', reduce_all)
+        if in_dygraph_mode():
+            abs_out = _C_ops.abs(input)
+            sum_out = _C_ops.sum(abs_out, axis, None, keepdim)
+
+            if porder == 1 or porder == np.inf:
+                return _C_ops.max(sum_out, [-1], keepdim)
+            if porder == -1 or porder == -np.inf:
+                return _C_ops.min(sum_out, [-1], keepdim)
+
+        elif _in_legacy_dygraph():
+            abs_out = _legacy_C_ops.abs(input)
+            sum_out = _legacy_C_ops.reduce_sum(abs_out, 'dim', axis, 'keepdim',
+                                               keepdim, 'reduce_all',
+                                               reduce_all)
         if porder == 1 or porder == np.inf:
             return _legacy_C_ops.reduce_max(sum_out, 'dim', [-1],
                                             'keepdim', keepdim,
                                             'reduce_all',
@@ -783,44 +788,44 @@ def cond(x, p=None, name=None):
             return _legacy_C_ops.reduce_min(sum_out, 'dim', [-1],
                                             'keepdim', keepdim,
                                             'reduce_all', reduce_all)
-
-        block = LayerHelper('norm', **locals())
-        abs_out = block.create_variable_for_type_inference(
-            dtype=block.input_dtype())
-        sum_out = block.create_variable_for_type_inference(
-            dtype=block.input_dtype())
-        out = block.create_variable_for_type_inference(
-            dtype=block.input_dtype())
-        block.append_op(type='abs',
-                        inputs={'X': input},
-                        outputs={'Out': abs_out})
-        block.append_op(type='reduce_sum',
-                        inputs={'X': abs_out},
-                        outputs={'Out': sum_out},
-                        attrs={
-                            'dim': axis,
-                            'keep_dim': keepdim,
-                            'reduce_all': reduce_all
-                        })
-        if porder == 1 or porder == np.inf:
-            block.append_op(type='reduce_max',
-                            inputs={'X': sum_out},
-                            outputs={'Out': out},
-                            attrs={
-                                'dim': [-1],
-                                'keep_dim': keepdim,
-                                'reduce_all': reduce_all
-                            })
-        if porder == -1 or porder == -np.inf:
-            block.append_op(type='reduce_min',
-                            inputs={'X': sum_out},
-                            outputs={'Out': out},
+        else:
+            block = LayerHelper('norm', **locals())
+            abs_out = block.create_variable_for_type_inference(
+                dtype=block.input_dtype())
+            sum_out = block.create_variable_for_type_inference(
+                dtype=block.input_dtype())
+            out = block.create_variable_for_type_inference(
+                dtype=block.input_dtype())
+            block.append_op(type='abs',
+                            inputs={'X': input},
+                            outputs={'Out': abs_out})
+            block.append_op(type='reduce_sum',
+                            inputs={'X': abs_out},
+                            outputs={'Out': sum_out},
                             attrs={
-                                'dim': [-1],
+                                'dim': axis,
                                 'keep_dim': keepdim,
                                 'reduce_all': reduce_all
                             })
-        return out
+            if porder == 1 or porder == np.inf:
+                block.append_op(type='reduce_max',
+                                inputs={'X': sum_out},
+                                outputs={'Out': out},
+                                attrs={
+                                    'dim': [-1],
+                                    'keep_dim': keepdim,
+                                    'reduce_all': reduce_all
+                                })
+            if porder == -1 or porder == -np.inf:
+                block.append_op(type='reduce_min',
+                                inputs={'X': sum_out},
+                                outputs={'Out': out},
+                                attrs={
+                                    'dim': [-1],
+                                    'keep_dim': keepdim,
+                                    'reduce_all': reduce_all
+                                })
+            return out
 
     def fro_norm(input, porder=2, axis=[-1]):
         """
@@ -899,22 +904,27 @@ def cond(x, p=None, name=None):
                     return _legacy_C_ops.reduce_sum(s, 'dim', axis, 'keepdim',
                                                     keepdim, 'reduce_all',
                                                     reduce_all)
-            max_out = _legacy_C_ops.reduce_max(s, 'dim', axis, 'keepdim',
-                                               keepdim, 'reduce_all',
-                                               reduce_all)
-            min_out = _legacy_C_ops.reduce_min(s, 'dim', axis, 'keepdim',
-                                               keepdim, 'reduce_all',
-                                               reduce_all)
-            if porder == 2:
-                if in_dygraph_mode():
+            if in_dygraph_mode():
+                max_out = _C_ops.max(s, axis, keepdim)
+                min_out = _C_ops.min(s, axis, keepdim)
+                if porder == 2:
                     return _C_ops.divide(max_out, min_out)
-                return _legacy_C_ops.elementwise_div(max_out, min_out, 'aixs',
-                                                     axis, 'use_mkldnn', False)
-            if porder == -2:
-                if in_dygraph_mode():
+                if porder == -2:
                     return _C_ops.divide(min_out, max_out)
-                return _legacy_C_ops.elementwise_div(min_out, max_out, 'aixs',
-                                                     axis, 'use_mkldnn', False)
+
+            else:
+                max_out = _legacy_C_ops.reduce_max(s, 'dim', axis, 'keepdim',
+                                                   keepdim, 'reduce_all',
+                                                   reduce_all)
+                min_out = _legacy_C_ops.reduce_min(s, 'dim', axis, 'keepdim',
+                                                   keepdim, 'reduce_all',
+                                                   reduce_all)
+                if porder == 2:
+                    return _legacy_C_ops.elementwise_div(
+                        max_out, min_out, 'aixs', axis, 'use_mkldnn', False)
+                if porder == -2:
+                    return _legacy_C_ops.elementwise_div(
+                        min_out, max_out, 'aixs', axis, 'use_mkldnn', False)
 
         block = LayerHelper('norm', **locals())
         out = block.create_variable_for_type_inference(
-- GitLab
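A hedged usage sketch of the reworked cond branches (the matrix below is
illustrative only, not part of the patch):

    import numpy as np
    import paddle

    x = paddle.to_tensor([[1., 0., -1.],
                          [0., 1., 0.],
                          [1., 0., 1.]])

    # In eager mode, p = +/-1 and p = +/-inf now return straight from
    # _C_ops.max / _C_ops.min over the absolute row/column sums.
    print(paddle.linalg.cond(x, p=np.inf))
    print(paddle.linalg.cond(x, p=-np.inf))

    # p = +/-2 goes through the singular values: _C_ops.max / _C_ops.min
    # on s, then a single _C_ops.divide.
    print(paddle.linalg.cond(x, p=2))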