未验证 提交 192b3033 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager, Performance optimization] reduce_all interface move reduce_all flag from python to C++

[Eager, Performance optimization] reduce_all interface move reduce_all flag from python to C++ (#45744)

* [Eager, Performance optimization] move reduce_all flag from python to c++

* polish reduce_all

* fix ci error

* fix errors
上级 cd84e1bf
......@@ -26,6 +26,9 @@ void AllKernel(const Context& dev_ctx,
bool keep_dim,
DenseTensor* out) {
// Decide the reduce_all flag here in C++ (moved from Python, see #45744):
// reduce over every axis when no dims are given, or when the dims list
// names as many axes as the input tensor has.
bool reduce_all = false;
if (dims.size() == 0 || static_cast<int>(dims.size()) == x.dims().size()) {
reduce_all = true;
}
// Delegate to the raw kernel with the computed reduce_all flag.
AllRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
......
......@@ -5032,6 +5032,10 @@ def reduce_all(input, dim=None, keep_dim=False, name=None):
"""
if dim is not None and not isinstance(dim, list):
dim = [dim]
# Eager (dygraph) fast path: call the C++ `all` op directly. An empty
# dims list tells the C++ kernel to reduce over all axes — the
# reduce_all flag is now computed on the C++ side (see AllKernel).
if in_dygraph_mode():
# NOTE(review): `dim != None` works, but `dim is not None` is the
# idiomatic None check (PEP 8) — confirm before changing behavior.
return _C_ops.all(input, dim if dim != None else [], keep_dim)
# Static-graph path: validate the input dtype, then build the op
# through LayerHelper.
check_variable_and_dtype(input, 'input', ('bool'), 'reduce_all')
helper = LayerHelper('reduce_all', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册