Unverified commit 3ed69e0d, authored by H hong, committed by GitHub

move reduce_all_flag from python to c++ (#44926)

* move reduce_all_flag from python to c++

* fix infer shape bug

* fix bug;

* fix sum infer meta bug

* fix reduce sum grad gpu bug

* fix amin amax bug;
Parent 47ea4d87
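The change centralizes a single convention: an empty axis list means "reduce over every dimension". Previously each Python API computed a `reduce_all_flag` and passed it down; after this commit the C++ infer-meta functions and kernels derive the flag themselves from `axis.size() == 0`. A minimal NumPy sketch of the convention (illustrative only, not Paddle code):

```python
import numpy as np

def reduce_sum(x, axis=None, keepdim=False):
    # The convention this commit centralizes: an empty axis list
    # is treated as "reduce over all dimensions".
    axis = list(axis) if axis is not None else []
    reduce_all = len(axis) == 0            # the guard moved into C++
    dims = tuple(range(x.ndim)) if reduce_all else tuple(axis)
    return x.sum(axis=dims, keepdims=keepdim)

x = np.arange(6.0).reshape(2, 3)
assert reduce_sum(x).shape == ()           # empty axis -> scalar
assert reduce_sum(x, axis=[1]).shape == (2,)
```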
@@ -2586,6 +2586,9 @@ void ReduceInferMeta(const MetaTensor& x,
bool keep_dim,
MetaTensor* out) {
bool reduce_all = false;
if (axis.size() == 0) {
reduce_all = true;
}
ReduceInferMetaBase(x, axis, keep_dim, reduce_all, out);
}
@@ -3254,6 +3257,9 @@ void SumInferMeta(const MetaTensor& x,
bool keep_dim,
MetaTensor* out) {
bool reduce_all = false;
if (axis.size() == 0) {
reduce_all = true;
}
SumRawInferMeta(x, axis, keep_dim, reduce_all, dtype, out);
}
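Both infer-meta entry points now set `reduce_all` from the axis list before delegating to the raw variant. A hedged Python sketch of the shape rule they implement (the function name is illustrative, not the phi API):

```python
def infer_reduce_shape(in_shape, axis, keep_dim):
    # Output shape of a reduction; empty axis means reduce-all,
    # matching the guard added to Reduce/SumInferMeta above.
    reduce_all = len(axis) == 0
    ndim = len(in_shape)
    dims = set(range(ndim)) if reduce_all else {d % ndim for d in axis}
    out = []
    for i, extent in enumerate(in_shape):
        if i in dims:
            if keep_dim:
                out.append(1)         # reduced axis kept as size 1
        else:
            out.append(extent)        # untouched axis passes through
    return out

assert infer_reduce_shape([2, 3, 4], [], keep_dim=False) == []
assert infer_reduce_shape([2, 3, 4], [-1], keep_dim=True) == [2, 3, 1]
```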
@@ -38,6 +38,9 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
auto* d_x = x_grad;
// get reduce_dim and reduce_num for reduce_mean_grad
int dim_size = in_x->dims().size();
if (dims.size() == 0) {
reduce_all = true;
}
auto reduce_dims = funcs::details::GetReduceDim(dims, dim_size, reduce_all);
auto update_dims = vectorize(d_x->dims());
int reduce_num = 1;
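The guard runs before the call to `funcs::details::GetReduceDim`, so an empty `dims` list now expands to every axis. A rough Python model of that helper's contract, as assumed from the call site (not the actual implementation):

```python
def get_reduce_dims(dims, dim_size, reduce_all):
    # Assumed contract of GetReduceDim: return the concrete axes
    # to reduce, expanding to all axes when reduce_all is set.
    if len(dims) == 0:
        reduce_all = True              # the guard this commit adds
    if reduce_all:
        return list(range(dim_size))
    return [d % dim_size for d in dims]  # wrap negative axes

assert get_reduce_dims([], dim_size=3, reduce_all=False) == [0, 1, 2]
assert get_reduce_dims([-1], dim_size=3, reduce_all=False) == [2]
```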
@@ -37,6 +37,9 @@ void ReduceSumGradKernel(const Context& dev_ctx,
// get reduce_dim and reduce_num for reduce_mean_grad
int dim_size = in_x->dims().size();
if (dims.size() == 0) {
reduce_all = true;
}
std::vector<int> reduce_dims =
funcs::details::GetReduceDim(dims, dim_size, reduce_all);
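With `reduce_all` resolved, the sum gradient itself is just the upstream gradient broadcast back to the input shape. A NumPy sketch of that relationship (illustrative; the real kernel also handles dtype and layout):

```python
import numpy as np

def reduce_sum_grad(x_shape, out_grad, dims, keep_dim):
    # d(sum(x, dims))/dx: broadcast out_grad over the reduced axes.
    ndim = len(x_shape)
    if len(dims) == 0:                 # mirrors the added guard
        dims = list(range(ndim))
    g = np.asarray(out_grad)
    if not keep_dim:
        for d in sorted(d % ndim for d in dims):
            g = np.expand_dims(g, d)   # restore the squeezed axes
    return np.broadcast_to(g, x_shape)

gx = reduce_sum_grad((2, 3), 1.0, dims=[], keep_dim=False)
assert gx.shape == (2, 3) and (gx == 1.0).all()
```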
@@ -91,6 +91,9 @@ void ReduceGradKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
if (dims.size() == 0) {
reduce_all = true;
}
if (x.dtype() != out_grad.dtype()) {
DenseTensorMeta x_grad_meta(
out_grad.dtype(), x_grad->dims(), x_grad->layout());
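The generic `ReduceGradKernel` gets the same default; the code right after handles the input and upstream gradient having different dtypes (presumably so that, e.g., summing an integer tensor can still yield a floating-point gradient). A small NumPy sketch of that assumed behavior:

```python
import numpy as np

def alloc_grad_buffer(x, out_grad):
    # Sketch of the dtype branch: the gradient buffer is created
    # with out_grad's dtype whenever it differs from x's.
    dtype = out_grad.dtype if out_grad.dtype != x.dtype else x.dtype
    return np.zeros(x.shape, dtype=dtype)

x = np.ones((2, 2), dtype=np.int32)
g = alloc_grad_buffer(x, np.asarray(1.0, dtype=np.float64))
assert g.dtype == np.float64 and g.shape == (2, 2)
```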
@@ -26,6 +26,9 @@ void AMaxKernel(const Context& dev_ctx,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = false;
if (dims.size() == 0) {
reduce_all = true;
}
AMaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
@@ -26,6 +26,9 @@ void AMinKernel(const Context& dev_ctx,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = false;
if (dims.size() == 0) {
reduce_all = true;
}
AMinRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
@@ -26,6 +26,9 @@ void MaxKernel(const Context& dev_ctx,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = false;
if (dims.size() == 0) {
reduce_all = true;
}
MaxRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
@@ -26,6 +26,9 @@ void MinKernel(const Context& dev_ctx,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = false;
if (dims.size() == 0) {
reduce_all = true;
}
MinRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
@@ -27,6 +27,9 @@ void SumKernel(const Context& dev_ctx,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = false;
if (dims.size() == 0) {
reduce_all = true;
}
SumRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out);
}
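All five forward kernels (amax, amin, max, min, sum) follow the same shape: a thin public kernel derives `reduce_all` from the axis list and forwards to a `*RawKernel` that takes the flag explicitly. A compact Python sketch of that wrapper/raw split (names mimic the C++ but are illustrative):

```python
import numpy as np

def sum_raw_kernel(x, dims, keep_dim, reduce_all):
    # "Raw" variant: reduce_all is an explicit argument.
    axes = tuple(range(x.ndim)) if reduce_all else tuple(dims)
    return x.sum(axis=axes, keepdims=keep_dim)

def sum_kernel(x, dims, keep_dim):
    # Public variant: derive reduce_all locally, as the C++
    # kernels above now do, then delegate.
    reduce_all = len(dims) == 0
    return sum_raw_kernel(x, dims, keep_dim, reduce_all)

x = np.arange(4.0).reshape(2, 2)
assert sum_kernel(x, [], keep_dim=False) == 6.0
assert sum_kernel(x, [0], keep_dim=False).tolist() == [2.0, 4.0]
```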
@@ -187,3 +187,7 @@ class TestMaxMinAmaxAminAPI6(TestMaxMinAmaxAminAPI):
self.dtype = 'float64'
self.axis = None
self.keepdim = False


if __name__ == '__main__':
    unittest.main()
@@ -316,4 +316,5 @@ class TestNNFunctionalMseLoss(unittest.TestCase):
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
@@ -540,4 +540,5 @@ class API_TestEmptySplit(unittest.TestCase):
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
@@ -67,12 +67,12 @@ def pairwise_distance(x, y, p=2., epsilon=1e-6, keepdim=False, name=None):
check_type(epsilon, 'epsilon', (float), 'PairwiseDistance')
check_type(keepdim, 'keepdim', (bool), 'PairwiseDistance')
if in_dygraph_mode():
sub = _C_ops.elementwise_sub(x, y)
sub = _C_ops.final_state_subtract(x, y)
        # p_norm op does not use epsilon, so add it to the difference instead.
if epsilon != 0.0:
epsilon = paddle.fluid.dygraph.base.to_variable([epsilon],
dtype=sub.dtype)
sub = _C_ops.elementwise_add(sub, epsilon)
sub = _C_ops.final_state_add(sub, epsilon)
return _C_ops.final_state_p_norm(sub, p, -1, 0., keepdim, False)
if _in_legacy_dygraph():
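Aside from switching to the final-state ops, the hunk keeps the existing epsilon workaround: since the p_norm op takes no epsilon, the value is added to the elementwise difference before the norm. A NumPy sketch of the equivalent computation (illustrative):

```python
import numpy as np

def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False):
    # epsilon is folded into the subtraction result because
    # p_norm itself has no epsilon parameter.
    sub = x - y
    if epsilon != 0.0:
        sub = sub + epsilon
    return np.linalg.norm(sub, ord=p, axis=-1, keepdims=keepdim)

x = np.array([[1.0, 3.0]])
y = np.array([[0.0, 1.0]])
print(pairwise_distance(x, y))   # ~[2.236071], close to sqrt(5)
```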
@@ -1162,12 +1162,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
axis = [axis]
if not axis:
reduce_all_flag = True
else:
if len(axis) == len(x.shape):
reduce_all_flag = True
else:
reduce_all_flag = False
axis = []
dtype_flag = False
if dtype is not None:
@@ -1175,13 +1170,16 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode():
if reduce_all_flag:
axis = range(len(x.shape))
else:
axis = axis if axis != None and axis != [] else [0]
return _C_ops.final_state_sum(x, axis, dtype, keepdim)
if len(axis) == 0:
reduce_all_flag = True
else:
if len(axis) == len(x.shape):
reduce_all_flag = True
else:
reduce_all_flag = False
if _in_legacy_dygraph():
axis = axis if axis != None and axis != [] else [0]
if dtype_flag:
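The net effect on sum: the eager path no longer expands an empty axis into `range(ndim)` in Python; it hands the axis list to `final_state_sum` and lets C++ default `reduce_all`, while the flag is still computed afterwards for the legacy paths. A simplified before/after sketch (assumed semantics):

```python
def old_dygraph_axis(axis, ndim):
    # Before: Python expanded the axis itself when reducing all.
    reduce_all = not axis or len(axis) == ndim
    return list(range(ndim)) if reduce_all else list(axis)

def new_dygraph_axis(axis, ndim):
    # After: the (possibly empty) list is forwarded unchanged;
    # the C++ kernel treats an empty list as reduce-all.
    return list(axis) if axis else []

assert old_dygraph_axis([], 3) == [0, 1, 2]
assert new_dygraph_axis([], 3) == []   # C++ now interprets this
```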
@@ -2056,6 +2054,24 @@ def inverse(x, name=None):
type='inverse', inputs={'Input': [x] }, outputs={'Output': [out]})
return out


def _get_reduce_axis(axis):
"""
Internal function for max, min, amax and amin.
It computes the attribute reduce_all value based on axis.
"""
if axis is not None and not isinstance(axis, list):
if isinstance(axis, tuple):
axis = list(axis)
elif isinstance(axis, int):
            axis = [axis]
else:
raise TypeError(
"The type of axis must be int, list or tuple, but received {}".format(type(axis)))
reduce_all = True if axis == None or axis == [] else False
if axis == None:
axis = []
return reduce_all, axis
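A few examples of what the new helper returns, reproduced standalone so they can be run directly (behavior derived from the code above):

```python
def _get_reduce_axis(axis):
    # Standalone copy of the helper above, for illustration.
    if axis is not None and not isinstance(axis, list):
        if isinstance(axis, tuple):
            axis = list(axis)
        elif isinstance(axis, int):
            axis = [axis]
        else:
            raise TypeError(
                "The type of axis must be int, list or tuple, "
                "but received {}".format(type(axis)))
    reduce_all = axis is None or axis == []
    if axis is None:
        axis = []
    return reduce_all, axis

assert _get_reduce_axis(None) == (True, [])
assert _get_reduce_axis(1) == (False, [1])
assert _get_reduce_axis((0, 1)) == (False, [0, 1])
```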
def _get_reduce_all_value(axis):
"""
Internal function for max, min, amax and amin.
@@ -2152,10 +2168,8 @@ def max(x, axis=None, keepdim=False, name=None):
#[7., 8.], [[[0., 0.], [0., 0.]], [[0., 0.], [1., 1.]]]
"""
reduce_all, axis = _get_reduce_all_value(axis)
reduce_all, axis = _get_reduce_axis(axis)
if in_dygraph_mode():
if reduce_all:
axis = range(len(x.shape))
return _C_ops.final_state_max(x, axis, keepdim)
if _in_legacy_dygraph():
return _C_ops.reduce_max(x, 'dim', axis, 'keep_dim', keepdim,
@@ -2255,10 +2269,8 @@ def min(x, axis=None, keepdim=False, name=None):
#[1., 2.], [[[1., 1.], [0., 0.]], [[0., 0.], [0., 0.]]]
"""
reduce_all, axis = _get_reduce_all_value(axis)
reduce_all, axis = _get_reduce_axis(axis)
if in_dygraph_mode():
if reduce_all:
axis = range(len(x.shape))
return _C_ops.final_state_min(x, axis, keepdim)
if _in_legacy_dygraph():
@@ -2372,10 +2384,8 @@ def amax(x, axis=None, keepdim=False, name=None):
#[0.9., 0.9], [[[0., 0.3333], [0.5, 0.3333]], [[0.5, 0.3333], [1., 1.]]]
"""
reduce_all, axis = _get_reduce_all_value(axis)
reduce_all, axis = _get_reduce_axis(axis)
if in_dygraph_mode():
if reduce_all:
axis = range(len(x.shape))
return _C_ops.final_state_amax(x, axis, keepdim)
if _in_legacy_dygraph():
return _C_ops.reduce_amax(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all)
@@ -2488,10 +2498,8 @@ def amin(x, axis=None, keepdim=False, name=None):
#[0.1., 0.1], [[[0., 0.3333], [0.5, 0.3333]], [[0.5, 0.3333], [1., 1.]]]
"""
reduce_all, axis = _get_reduce_all_value(axis)
    reduce_all, axis = _get_reduce_axis(axis)
if in_dygraph_mode():
if reduce_all:
axis = range(len(x.shape))
return _C_ops.final_state_amin(x, axis, keepdim)
elif _in_legacy_dygraph():
return _C_ops.reduce_amin(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all)
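max, min, amax and amin now share an identical eager path: normalize the axis once via `_get_reduce_axis`, then pass the list straight to the final-state op (empty list means reduce-all, decided in C++). A runnable sketch of that shared dispatch, with NumPy ops standing in for `_C_ops`:

```python
import numpy as np

def _normalize_axis(axis):
    # Same normalization as _get_reduce_axis, inlined here so the
    # demo is self-contained.
    if isinstance(axis, int):
        axis = [axis]
    elif isinstance(axis, tuple):
        axis = list(axis)
    reduce_all = axis is None or axis == []
    return reduce_all, axis or []

def _dygraph_reduce(np_op, x, axis=None, keepdim=False):
    # Shared shape of the four APIs after this commit: normalize
    # once, forward the list; C++ defaults reduce_all from it.
    reduce_all, axis = _normalize_axis(axis)
    dims = tuple(range(x.ndim)) if reduce_all else tuple(axis)
    return np_op(x, axis=dims, keepdims=keepdim)  # stand-in for _C_ops

x = np.array([[0.2, 0.9], [0.9, 0.7]])
assert _dygraph_reduce(np.max, x) == 0.9
assert _dygraph_reduce(np.min, x, axis=[0]).tolist() == [0.2, 0.7]
```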