未验证 提交 192b3033 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager, Performance optimization] reduce_all interface move reduce_all flag...

[Eager, Performance optimization] reduce_all interface move reduce_all flag from python to C++ (#45744)

* [Eager, Performance optimization] move reduce_all flag from python to c++

* polish reduce_all

* fix ci error

* fix errors
上级 cd84e1bf
...@@ -26,6 +26,9 @@ void AllKernel(const Context& dev_ctx, ...@@ -26,6 +26,9 @@ void AllKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
DenseTensor* out) { DenseTensor* out) {
bool reduce_all = false; bool reduce_all = false;
if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
reduce_all = true;
}
AllRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out); AllRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
} }
......
...@@ -5032,6 +5032,10 @@ def reduce_all(input, dim=None, keep_dim=False, name=None): ...@@ -5032,6 +5032,10 @@ def reduce_all(input, dim=None, keep_dim=False, name=None):
""" """
if dim is not None and not isinstance(dim, list): if dim is not None and not isinstance(dim, list):
dim = [dim] dim = [dim]
if in_dygraph_mode():
return _C_ops.all(input, dim if dim != None else [], keep_dim)
check_variable_and_dtype(input, 'input', ('bool'), 'reduce_all') check_variable_and_dtype(input, 'input', ('bool'), 'reduce_all')
helper = LayerHelper('reduce_all', **locals()) helper = LayerHelper('reduce_all', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册