Unverified commit a6476418, authored by Weilong Wu, committed by GitHub

[Eager, Performance optimization] Reduce min/max kernel polish (#45755)

* [Eager, Performance optimization] reduce_max / min polish

* polish reduce_max / min

* update min/max kernel reduce_all logic

* fix a mistake

* fix ci errors

* fix errors
Parent 5e77346f
@@ -26,7 +26,7 @@ void MaxKernel(const Context& dev_ctx,
                bool keep_dim,
                DenseTensor* out) {
   bool reduce_all = false;
-  if (dims.size() == 0) {
+  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
     reduce_all = true;
   }
   MaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
...
@@ -26,7 +26,7 @@ void MinKernel(const Context& dev_ctx,
                bool keep_dim,
                DenseTensor* out) {
   bool reduce_all = false;
-  if (dims.size() == 0) {
+  if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
     reduce_all = true;
   }
   MinRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
...
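Both kernel changes widen the reduce_all shortcut: the full-reduction path used to fire only for an empty dims, and now also fires when dims explicitly names every axis of x, which is the same computation. A minimal sketch of that equivalence, assuming a PaddlePaddle build with the eager API:

    import paddle

    x = paddle.rand([2, 3, 4])

    # Listing every axis is the same computation as a full reduction,
    # so both calls can dispatch to the reduce_all fast path.
    full = paddle.max(x)                      # dims == [] -> reduce_all
    explicit = paddle.max(x, axis=[0, 1, 2])  # dims covers all axes
    assert float(full) == float(explicit)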
@@ -4808,8 +4808,13 @@ def reduce_max(input, dim=None, keep_dim=False, name=None):
     """
     helper = LayerHelper('reduce_max', **locals())
     out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
     if dim is not None and not isinstance(dim, list):
         dim = [dim]
+    if in_dygraph_mode():
+        return _C_ops.max(input, dim if dim != None else [], keep_dim)
     helper.append_op(type='reduce_max',
                      inputs={'X': input},
                      outputs={'Out': out},
@@ -4878,6 +4883,10 @@ def reduce_min(input, dim=None, keep_dim=False, name=None):
     out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
     if dim is not None and not isinstance(dim, list):
         dim = [dim]
+    if in_dygraph_mode():
+        return _C_ops.min(input, dim if dim != None else [], keep_dim)
     helper.append_op(type='reduce_min',
                      inputs={'X': input},
                      outputs={'Out': out},
...
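With these two hunks, reduce_max and reduce_min short-circuit to the eager _C_ops kernels instead of appending a static-graph op; a dim of None is mapped to an empty list, which the kernel change above treats as a full reduction. A usage sketch through the public paddle.max/paddle.min, which route to the same eager ops in dynamic mode:

    import paddle

    x = paddle.to_tensor([[1.0, 5.0], [3.0, 2.0]])
    print(paddle.max(x))          # no axis -> empty dims -> reduce_all: 5.0
    print(paddle.max(x, axis=1))  # per-row maxima: [5.0, 3.0]
    print(paddle.min(x, axis=0))  # per-column minima: [1.0, 2.0]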
@@ -766,15 +766,20 @@ def cond(x, p=None, name=None):
         axis = axis if axis != None and axis != [] else [0]
         keepdim = False
-        if _non_static_mode():
-            if in_dygraph_mode():
-                abs_out = _C_ops.abs(input)
-                sum_out = _C_ops.sum(abs_out, axis, None, keepdim)
-            else:
-                abs_out = _legacy_C_ops.abs(input)
-                sum_out = _legacy_C_ops.reduce_sum(abs_out, 'dim', axis,
-                                                   'keepdim', keepdim,
-                                                   'reduce_all', reduce_all)
+        if in_dygraph_mode():
+            abs_out = _C_ops.abs(input)
+            sum_out = _C_ops.sum(abs_out, axis, None, keepdim)
+            if porder == 1 or porder == np.inf:
+                return _C_ops.max(sum_out, [-1], keepdim)
+            if porder == -1 or porder == -np.inf:
+                return _C_ops.min(sum_out, [-1], keepdim)
+        elif _in_legacy_dygraph():
+            abs_out = _legacy_C_ops.abs(input)
+            sum_out = _legacy_C_ops.reduce_sum(abs_out, 'dim', axis, 'keepdim',
+                                               keepdim, 'reduce_all',
+                                               reduce_all)
             if porder == 1 or porder == np.inf:
                 return _legacy_C_ops.reduce_max(sum_out, 'dim', [-1], 'keepdim',
                                                 keepdim, 'reduce_all',
@@ -783,7 +788,7 @@ def cond(x, p=None, name=None):
             return _legacy_C_ops.reduce_min(sum_out, 'dim', [-1], 'keepdim',
                                             keepdim, 'reduce_all',
                                             reduce_all)
-        block = LayerHelper('norm', **locals())
-        abs_out = block.create_variable_for_type_inference(
-            dtype=block.input_dtype())
+        else:
+            block = LayerHelper('norm', **locals())
+            abs_out = block.create_variable_for_type_inference(
+                dtype=block.input_dtype())
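For porder in {1, inf, -1, -inf}, the norm inside cond is an absolute-value sum along one axis followed by a max (positive orders) or min (negative orders) over the result, which the new branch now issues directly through _C_ops. The same arithmetic sketched with public ops:

    import paddle

    x = paddle.to_tensor([[1.0, -2.0], [3.0, 4.0]])

    # porder = 1: largest absolute column sum;
    # porder = inf: largest absolute row sum.
    col_sums = paddle.sum(paddle.abs(x), axis=0)  # [4.0, 6.0]
    row_sums = paddle.sum(paddle.abs(x), axis=1)  # [3.0, 7.0]
    print(float(col_sums.max()))  # 6.0 == 1-norm of x
    print(float(row_sums.max()))  # 7.0 == inf-norm of x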
@@ -899,6 +904,15 @@ def cond(x, p=None, name=None):
             return _legacy_C_ops.reduce_sum(s, 'dim', axis, 'keepdim',
                                             keepdim, 'reduce_all',
                                             reduce_all)
-        max_out = _legacy_C_ops.reduce_max(s, 'dim', axis, 'keepdim',
-                                           keepdim, 'reduce_all',
-                                           reduce_all)
+        if in_dygraph_mode():
+            max_out = _C_ops.max(s, axis, keepdim)
+            min_out = _C_ops.min(s, axis, keepdim)
+            if porder == 2:
+                return _C_ops.divide(max_out, min_out)
+            if porder == -2:
+                return _C_ops.divide(min_out, max_out)
+        else:
+            max_out = _legacy_C_ops.reduce_max(s, 'dim', axis, 'keepdim',
+                                               keepdim, 'reduce_all',
+                                               reduce_all)
@@ -906,15 +920,11 @@ def cond(x, p=None, name=None):
                                                keepdim, 'reduce_all',
                                                reduce_all)
             if porder == 2:
-                if in_dygraph_mode():
-                    return _C_ops.divide(max_out, min_out)
-                return _legacy_C_ops.elementwise_div(max_out, min_out, 'aixs',
-                                                     axis, 'use_mkldnn', False)
+                return _legacy_C_ops.elementwise_div(
+                    max_out, min_out, 'aixs', axis, 'use_mkldnn', False)
             if porder == -2:
-                if in_dygraph_mode():
-                    return _C_ops.divide(min_out, max_out)
-                return _legacy_C_ops.elementwise_div(min_out, max_out, 'aixs',
-                                                     axis, 'use_mkldnn', False)
+                return _legacy_C_ops.elementwise_div(
+                    min_out, max_out, 'aixs', axis, 'use_mkldnn', False)
         block = LayerHelper('norm', **locals())
         out = block.create_variable_for_type_inference(
...
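For porder == 2 and porder == -2, cond returns the ratio of the extreme singular values in s; the eager branch above obtains both extremes via the new _C_ops.max/_C_ops.min and divides, while the legacy branch keeps elementwise_div. A sketch of the same quantity through public APIs (the return layout of paddle.linalg.svd is assumed from current releases):

    import paddle

    x = paddle.to_tensor([[1.0, 0.0], [0.0, 0.5]])
    s = paddle.linalg.svd(x)[1]      # singular values: [1.0, 0.5]
    print(float(s.max() / s.min()))  # porder ==  2 -> 2.0
    print(float(s.min() / s.max()))  # porder == -2 -> 0.5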