Unverified · Commit ddf94ae4, authored by zhouweiwei2014, committed by GitHub

[Zero-Dim] Support paddle.sum/mean/loss api output 0D,test=allcase (#52739)

Parent: 0271d9e7
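The change, as a minimal dygraph sketch (the APIs are Paddle's public Python API; the shapes shown assume a build that includes this commit): full reductions such as paddle.sum/paddle.mean, and losses with a 'mean'/'sum' reduction, now return a 0-D tensor of shape [] instead of a 1-D tensor of shape [1].

    import paddle

    x = paddle.rand([3, 5])

    # Full reductions now produce 0-D tensors (shape []), not shape [1].
    print(paddle.sum(x).shape)   # []
    print(paddle.mean(x).shape)  # []

    # Reduced losses are scalars too, matching the updated unit tests below.
    loss = paddle.nn.functional.l1_loss(x, paddle.rand([3, 5]), reduction='mean')
    print(loss.shape)            # []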
@@ -365,7 +365,7 @@ void sum_grad(const Tensor& x,
   if (!keepdim) {
     auto axis_ = std::vector<int64_t>();
     if (reduce_all) {
-      for (int64_t i = 1; i < x_dim_size; i++) {
+      for (int64_t i = 0; i < x_dim_size; i++) {
         axis_.push_back(i);
       }
     } else {
...
@@ -4004,9 +4004,6 @@ DDim OriginReduceInferDim(const MetaTensor& x,
       out_dim_vector.push_back(x.dims().at(i));
     }
   }
-  if (x_rank > 0 && out_dim_vector.size() == 0) {
-    out_dim_vector.push_back(1);
-  }
   DDim out_dim = phi::make_ddim(out_dim_vector);
   return out_dim;
@@ -4023,14 +4020,14 @@ DDim OriginReduceInferDimForIntArrayAxis(const MetaTensor& x,
     if (keep_dim) {
       vec_dim = std::vector<int64_t>(x.dims().size(), 1);
     } else {
-      vec_dim = {1};
+      vec_dim = {};
     }
   } else {
     if (keep_dim) {
       vec_dim = std::vector<int64_t>(x.dims().size(), -1);
     } else {
       auto x_rank = static_cast<size_t>(x.dims().size());
-      if (vec_axis.size() >= x_rank) {
+      if (vec_axis.size() > x_rank) {
         vec_dim = {-1};
       } else {
         vec_dim = std::vector<int64_t>(x.dims().size() - vec_axis.size(), -1);
...
@@ -1688,7 +1688,7 @@ class Completer:
                     world_ranks
                 )
                 out_dist_attr.dims_mapping = [
-                    -1 for _ in range(len(out_var.shape))
+                    -1 for _ in out_var.shape
                 ]
                 self._dist_context.set_tensor_dist_attr_for_program(
                     out_var, out_dist_attr
@@ -1732,7 +1732,9 @@ class Completer:
                         len(out_var.shape) == 1
                         and out_var.shape[0] == 1
                     )
-                    out_dist_attr.dims_mapping = [-1]
+                    out_dist_attr.dims_mapping = [
+                        -1 for _ in out_var.shape
+                    ]
                     self._dist_context.set_tensor_dist_attr_for_program(
                         out_var, out_dist_attr
                     )
@@ -1802,16 +1804,20 @@ class Completer:
                     param.name, ref_dims_mapping
                 )
                 learning_var = vars[op.input("LearningRate")[0]]
-                op_dist_attr.set_input_dims_mapping(learning_var.name, [-1])
+                op_dist_attr.set_input_dims_mapping(
+                    learning_var.name, [-1 for _ in learning_var.shape]
+                )
                 op_dist_attr.set_output_dims_mapping(
-                    learning_var.name, [-1]
+                    learning_var.name, [-1 for _ in learning_var.shape]
                 )
                 if not learning_rate_completed:
                     learning_rate_completed = True
                     var_dist_attr = TensorDistAttr()
                     var_dist_attr.process_mesh = ProcessMesh(world_ranks)
-                    var_dist_attr.dims_mapping = [-1]
+                    var_dist_attr.dims_mapping = [
+                        -1 for _ in learning_var.shape
+                    ]
                     self._dist_context.set_tensor_dist_attr_for_program(
                         learning_var, var_dist_attr
                     )
@@ -1841,10 +1847,10 @@ class Completer:
                     ):
                         input_var_attr.dims_mapping = [-1]
                         op_dist_attr.set_input_dims_mapping(
-                            input_var.name, [-1]
+                            input_var.name, [-1 for _ in input_var.shape]
                         )
                         op_dist_attr.set_output_dims_mapping(
-                            input_var.name, [-1]
+                            input_var.name, [-1 for _ in input_var.shape]
                         )
                     else:
                         input_var_attr.dims_mapping = ref_dims_mapping
...
@@ -511,7 +511,7 @@ class Engine:
                 loss_indices = fetch_indices[group_idx]
                 assert len(loss_indices) <= 1
                 for idx in loss_indices:
-                    logs["loss"] = outs[idx][0]
+                    logs["loss"] = outs[idx]
                 group_idx += 1
             # logging metrics
             dist_context = self._dist_contexts[mode]
...
@@ -393,7 +393,7 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank):
     for var_name in act_grad_names:
         var_dim_mapping = op_dist_attr.get_input_dims_mapping(var_name)
-        # consider that the variable's shape is None
+        # consider that the variable's shape is [], which is 0D
         # TODO utilize the batch_dim attr instead of "0" in future
         batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
...
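Because a 0-D tensor has an empty dims_mapping, indexing element 0 unconditionally would fail. The guard repeated throughout the auto-parallel operator implementations below can be sketched as follows (the helper name is illustrative, not taken from the diff):

    def infer_batch_size_axis(var_dim_mapping):
        # A 0-D tensor has dims_mapping == [], so there is no batch axis;
        # fall back to -1 (unsharded) instead of indexing an empty list.
        return var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1

    assert infer_batch_size_axis([0, -1]) == 0  # batch dim sharded on mesh axis 0
    assert infer_batch_size_axis([]) == -1      # 0-D tensor: no batch dim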
@@ -159,7 +159,9 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
            ):
                var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
                mesh_shape = process_mesh.shape
-                batch_size_axis = var_dim_mapping[0]
+                batch_size_axis = (
+                    var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+                )
                if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                    need_gradient_allreduce = True
                    break
...
@@ -101,7 +101,9 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
            ):
                var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
                mesh_shape = process_mesh.shape
-                batch_size_axis = var_dim_mapping[0]
+                batch_size_axis = (
+                    var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+                )
                if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                    need_gradient_allreduce = True
                    break
...
@@ -252,7 +252,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
                backward_op.input("Ids")[0]
            )
            mesh_shape = process_mesh.shape
-            batch_size_axis = var_dim_mapping[0]
+            batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
            if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                parallel_axis = batch_size_axis
                attrs = {"use_calc_stream": True}
...
@@ -651,7 +651,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
                backward_op.input("X")[0]
            )
            mesh_shape = process_mesh.shape
-            batch_size_axis = var_dim_mapping[0]
+            batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
            if (
                batch_size_axis > -1
                and mesh_shape[batch_size_axis] > 1
@@ -1028,7 +1028,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
                backward_op.input("X")[0]
            )
            mesh_shape = process_mesh.shape
-            batch_size_axis = var_dim_mapping[0]
+            batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
            if (
                batch_size_axis > -1
                and mesh_shape[batch_size_axis] > 1
@@ -1365,7 +1365,7 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
                backward_op.input("X")[0]
            )
            mesh_shape = process_mesh.shape
-            batch_size_axis = var_dim_mapping[0]
+            batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
            if (
                batch_size_axis > -1
                and mesh_shape[batch_size_axis] > 1
@@ -1552,7 +1552,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
                backward_op.input("X")[0]
            )
            mesh_shape = process_mesh.shape
-            batch_size_axis = var_dim_mapping[0]
+            batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
            if (
                batch_size_axis > -1
                and mesh_shape[batch_size_axis] > 1
@@ -1929,7 +1929,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
                backward_op.input("X")[0]
            )
            mesh_shape = process_mesh.shape
-            batch_size_axis = var_dim_mapping[0]
+            batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
            if (
                batch_size_axis > -1
                and mesh_shape[batch_size_axis] > 1
@@ -2264,7 +2264,7 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
                backward_op.input("X")[0]
            )
            mesh_shape = process_mesh.shape
-            batch_size_axis = var_dim_mapping[0]
+            batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
            if (
                batch_size_axis > -1
                and mesh_shape[batch_size_axis] > 1
@@ -2449,7 +2449,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
                backward_op.input("X")[0]
            )
            mesh_shape = process_mesh.shape
-            batch_size_axis = var_dim_mapping[0]
+            batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
            if (
                batch_size_axis > -1
                and mesh_shape[batch_size_axis] > 1
@@ -2832,7 +2832,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
                backward_op.input("X")[0]
            )
            mesh_shape = process_mesh.shape
-            batch_size_axis = var_dim_mapping[0]
+            batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
            if (
                batch_size_axis > -1
                and mesh_shape[batch_size_axis] > 1
@@ -3178,7 +3178,7 @@ class DistributedMulImpl2(DistributedOperatorImpl):
                backward_op.input("X")[0]
            )
            mesh_shape = process_mesh.shape
-            batch_size_axis = var_dim_mapping[0]
+            batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
            if (
                batch_size_axis > -1
                and mesh_shape[batch_size_axis] > 1
...
@@ -120,7 +120,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
                var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
                mesh_shape = process_mesh.shape
-                batch_size_axis = var_dim_mapping[0]
+                batch_size_axis = (
+                    var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+                )
                if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                    parallel_axis = batch_size_axis
                    attrs = {"use_calc_stream": True}
@@ -377,7 +379,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
                var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
                mesh_shape = process_mesh.shape
-                batch_size_axis = var_dim_mapping[0]
+                batch_size_axis = (
+                    var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+                )
                if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                    parallel_axis = batch_size_axis
                    attrs = {"use_calc_stream": True}
@@ -637,7 +641,9 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
                var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
                mesh_shape = process_mesh.shape
-                batch_size_axis = var_dim_mapping[0]
+                batch_size_axis = (
+                    var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+                )
                if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                    parallel_axis = batch_size_axis
                    attrs = {"use_calc_stream": True}
...
@@ -100,7 +100,9 @@ class DistributedScaleImpl(DistributedOperatorImpl):
            ):
                var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
                mesh_shape = process_mesh.shape
-                batch_size_axis = var_dim_mapping[0]
+                batch_size_axis = (
+                    var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+                )
                if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                    need_gradient_allreduce = True
                    break
...
@@ -94,7 +94,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
                var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
                mesh_shape = process_mesh.shape
-                batch_size_axis = var_dim_mapping[0]
+                batch_size_axis = (
+                    var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+                )
                if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                    parallel_axis = batch_size_axis
                    attrs = {"use_calc_stream": True}
...
@@ -183,7 +183,9 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
                var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
                mesh_shape = process_mesh.shape
-                batch_size_axis = var_dim_mapping[0]
+                batch_size_axis = (
+                    var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1
+                )
                if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                    parallel_axis = batch_size_axis
                    attrs = {"use_calc_stream": True}
...
@@ -1727,7 +1727,9 @@ class RuleBasedTuner:
                        len(out_var.shape) == 1
                        and out_var.shape[0] == 1
                    )
-                    out_dist_attr.dims_mapping = [-1]
+                    out_dist_attr.dims_mapping = [
+                        -1 for _ in out_var.shape
+                    ]
                    sub_program_dist_context.set_tensor_dist_attr_for_program(
                        out_var, out_dist_attr
                    )
@@ -1798,17 +1800,19 @@ class RuleBasedTuner:
                )
                learning_var = vars[op.input("LearningRate")[0]]
                op_dist_attr.set_input_dims_mapping(
-                    learning_var.name, [-1]
+                    learning_var.name, [-1 for i in learning_var.shape]
                )
                op_dist_attr.set_output_dims_mapping(
-                    learning_var.name, [-1]
+                    learning_var.name, [-1 for i in learning_var.shape]
                )
                if not learning_rate_completed:
                    learning_rate_completed = True
                    var_dist_attr = TensorDistAttr()
                    var_dist_attr.process_mesh = world_ranks
-                    var_dist_attr.dims_mapping = [-1]
+                    var_dist_attr.dims_mapping = [
+                        -1 for i in learning_var.shape
+                    ]
                    sub_program_dist_context.set_tensor_dist_attr_for_program(
                        learning_var, var_dist_attr
                    )
...
@@ -1466,7 +1466,8 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
            ), "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part.".format(
                op_desc.type(), idx, mapping
            )
-        batch_dim_mappings.append(dims_mapping[0])
+        if len(dims_mapping) >= 1:
+            batch_dim_mappings.append(dims_mapping[0])
    for arg_name in op_desc.output_arg_names():
        serial_tensor = dist_op.get_serial_output(arg_name)
        if serial_tensor.is_parameter:
@@ -1480,7 +1481,8 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
                ), "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part.".format(
                    op_desc.type(), idx, mapping
                )
-            batch_dim_mappings.append(dims_mapping[0])
+            if len(dims_mapping) >= 1:
+                batch_dim_mappings.append(dims_mapping[0])
        else:
            assert (
                dims_mapping[0] == -1
@@ -1505,7 +1507,7 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
        if serial_tensor.is_parameter:
            continue
        dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
-        if compatible_dim_mapping != dims_mapping[0]:
+        if len(dims_mapping) >= 1 and compatible_dim_mapping != dims_mapping[0]:
            dims_mapping[0] = compatible_dim_mapping
            changed = True
    for arg_name in op_desc.output_arg_names():
@@ -1514,7 +1516,10 @@ def update_op_dims_mapping_by_default_dist_impl(dist_op):
            continue
        dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
        if arg_name not in xshape_arg_names:
-            if compatible_dim_mapping != dims_mapping[0]:
+            if (
+                len(dims_mapping) >= 1
+                and compatible_dim_mapping != dims_mapping[0]
+            ):
                dims_mapping[0] = compatible_dim_mapping
                changed = True
        else:
...
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import numpy as np
 import paddle
 from paddle import framework
@@ -95,11 +96,10 @@ class HybridParallelClipGrad:
        # global norm of distributed FP16 params_and_grads
        if len(sum_square_dist_fp16) == 0:
            global_norm_dist_fp16 = paddle.to_tensor(
-                [0.0], dtype=paddle.float32
+                np.array(0.0), dtype=paddle.float32
            )
        else:
-            global_norm_dist_fp16 = paddle.concat(sum_square_dist_fp16)
-            global_norm_dist_fp16 = paddle.sum(global_norm_dist_fp16)
+            global_norm_dist_fp16 = paddle.add_n(sum_square_dist_fp16)
            global_norm_dist_fp16 = paddle.cast(
                global_norm_dist_fp16, dtype=paddle.float32
            )
@@ -107,11 +107,10 @@ class HybridParallelClipGrad:
        # global norm of non-distributed FP16 params_and_grads
        if len(sum_square_not_dist_fp16) == 0:
            global_norm_not_dist_fp16 = paddle.to_tensor(
-                [0.0], dtype=paddle.float32
+                np.array(0.0), dtype=paddle.float32
            )
        else:
-            global_norm_not_dist_fp16 = paddle.concat(sum_square_not_dist_fp16)
-            global_norm_not_dist_fp16 = paddle.sum(global_norm_not_dist_fp16)
+            global_norm_not_dist_fp16 = paddle.add_n(sum_square_not_dist_fp16)
            global_norm_not_dist_fp16 = paddle.cast(
                global_norm_not_dist_fp16, dtype=paddle.float32
            )
@@ -119,11 +118,10 @@ class HybridParallelClipGrad:
        # global norm of distributed BF16 params_and_grads
        if len(sum_square_dist_bf16) == 0:
            global_norm_dist_bf16 = paddle.to_tensor(
-                [0.0], dtype=paddle.float32
+                np.array(0.0), dtype=paddle.float32
            )
        else:
-            global_norm_dist_bf16 = paddle.concat(sum_square_dist_bf16)
-            global_norm_dist_bf16 = paddle.sum(global_norm_dist_bf16)
+            global_norm_dist_bf16 = paddle.add_n(sum_square_dist_bf16)
            global_norm_dist_bf16 = paddle.cast(
                global_norm_dist_bf16, dtype=paddle.float32
            )
@@ -131,30 +129,29 @@ class HybridParallelClipGrad:
        # global norm of non-distributed FP16 params_and_grads
        if len(sum_square_not_dist_bf16) == 0:
            global_norm_not_dist_bf16 = paddle.to_tensor(
-                [0.0], dtype=paddle.float32
+                np.array(0.0), dtype=paddle.float32
            )
        else:
-            global_norm_not_dist_bf16 = paddle.concat(sum_square_not_dist_bf16)
-            global_norm_not_dist_bf16 = paddle.sum(global_norm_not_dist_bf16)
+            global_norm_not_dist_bf16 = paddle.add_n(sum_square_not_dist_bf16)
            global_norm_not_dist_bf16 = paddle.cast(
                global_norm_not_dist_bf16, dtype=paddle.float32
            )
        # global norm of distributed FP32 params_and_grads
-        global_norm_dist_fp32 = (
-            paddle.concat(sum_square_dist_fp32)
-            if len(sum_square_dist_fp32) != 0
-            else paddle.to_tensor([0.0], dtype=paddle.float32)
-        )
-        global_norm_dist_fp32 = paddle.sum(global_norm_dist_fp32)
+        if len(sum_square_dist_fp32) == 0:
+            global_norm_dist_fp32 = paddle.to_tensor(
+                np.array(0.0), dtype=paddle.float32
+            )
+        else:
+            global_norm_dist_fp32 = paddle.add_n(sum_square_dist_fp32)
        # global norm of non-distributed FP32 params_and_grads
-        global_norm_not_dist_fp32 = (
-            paddle.concat(sum_square_not_dist_fp32)
-            if len(sum_square_not_dist_fp32) != 0
-            else paddle.to_tensor([0.0], dtype=paddle.float32)
-        )
-        global_norm_not_dist_fp32 = paddle.sum(global_norm_not_dist_fp32)
+        if len(sum_square_not_dist_fp32) == 0:
+            global_norm_not_dist_fp32 = paddle.to_tensor(
+                np.array(0.0), dtype=paddle.float32
+            )
+        else:
+            global_norm_not_dist_fp32 = paddle.add_n(sum_square_not_dist_fp32)
        global_norm_var_dist = (
            global_norm_dist_fp16
@@ -193,14 +190,14 @@ class HybridParallelClipGrad:
        )
        max_global_norm = paddle.full(
-            shape=[1],
+            shape=[],
            dtype=global_norm_var_fp32.dtype,
            fill_value=self.clip_norm,
        )
        clip_var = paddle.divide(
            x=max_global_norm,
            y=paddle.maximum(x=global_norm_var_fp32, y=max_global_norm)
-            + paddle.to_tensor([1.0e-6], dtype=paddle.float32),
+            + paddle.to_tensor(np.array(1.0e-6), dtype=paddle.float32),
        )
        clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
...
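With each per-parameter squared norm now 0-D, the clip-grad code above no longer concatenates and sums them; paddle.add_n sums the list of same-shape tensors directly, and the result stays 0-D. A hedged sketch of the pattern, with made-up values:

    import numpy as np
    import paddle

    # Per-parameter squared norms; after this change each entry is a 0-D tensor.
    sum_square_list = [
        paddle.to_tensor(np.array(v), dtype=paddle.float32) for v in (1.0, 4.0, 9.0)
    ]

    if len(sum_square_list) == 0:
        global_norm = paddle.to_tensor(np.array(0.0), dtype=paddle.float32)
    else:
        # add_n sums the 0-D tensors, replacing the old concat + sum pair.
        global_norm = paddle.add_n(sum_square_list)

    print(global_norm.shape)  # [] in dygraph: still a 0-D scalar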
@@ -94,59 +94,64 @@ class GroupShardedClipGrad:
        # global norm of non-distributed FP16 params_and_grads
        if len(sum_square_fp16) == 0:
-            global_norm_fp16 = paddle.to_tensor([0.0], dtype=paddle.float32)
+            global_norm_fp16 = paddle.to_tensor(
+                np.array(0.0), dtype=paddle.float32
+            )
        else:
-            global_norm_fp16 = paddle.concat(sum_square_fp16)
-            global_norm_fp16 = paddle.sum(global_norm_fp16)
+            global_norm_fp16 = paddle.add_n(sum_square_fp16)
            global_norm_fp16 = paddle.cast(
                global_norm_fp16, dtype=paddle.float32
            )
        # global norm of non-distributed BFP16 params_and_grads
        if len(sum_square_bfp16) == 0:
-            global_norm_bfp16 = paddle.to_tensor([0.0], dtype=paddle.float32)
+            global_norm_bfp16 = paddle.to_tensor(
+                np.array(0.0), dtype=paddle.float32
+            )
        else:
-            global_norm_bfp16 = paddle.concat(sum_square_bfp16)
-            global_norm_bfp16 = paddle.sum(global_norm_bfp16)
+            global_norm_bfp16 = paddle.add_n(sum_square_bfp16)
            global_norm_bfp16 = paddle.cast(
                global_norm_bfp16, dtype=paddle.float32
            )
        # global norm of non-distributed FP16 params_and_grads for unslice parameters
        if len(unslice_params_fp16) == 0:
-            global_unslice_fp16 = paddle.to_tensor([0.0], dtype=paddle.float32)
+            global_unslice_fp16 = paddle.to_tensor(
+                np.array(0.0), dtype=paddle.float32
+            )
        else:
-            global_unslice_fp16 = paddle.concat(unslice_params_fp16)
-            global_unslice_fp16 = paddle.sum(global_unslice_fp16)
+            global_unslice_fp16 = paddle.add_n(unslice_params_fp16)
            global_unslice_fp16 = paddle.cast(
                global_unslice_fp16, dtype=paddle.float32
            )
        # global norm of non-distributed BFP16 params_and_grads for unslice parameters
        if len(unslice_params_bfp16) == 0:
-            global_unslice_bfp16 = paddle.to_tensor([0.0], dtype=paddle.float32)
+            global_unslice_bfp16 = paddle.to_tensor(
+                np.array(0.0), dtype=paddle.float32
+            )
        else:
-            global_unslice_bfp16 = paddle.concat(unslice_params_bfp16)
-            global_unslice_bfp16 = paddle.sum(global_unslice_bfp16)
+            global_unslice_bfp16 = paddle.add_n(unslice_params_bfp16)
            global_unslice_bfp16 = paddle.cast(
                global_unslice_bfp16, dtype=paddle.float32
            )
        # global norm of non-distributed FP32 params_and_grads
-        global_norm_fp32 = (
-            paddle.concat(sum_square_fp32)
-            if len(sum_square_fp32) != 0
-            else paddle.to_tensor([0.0], dtype=paddle.float32)
-        )
-        global_norm_fp32 = paddle.sum(global_norm_fp32)
+        if len(sum_square_fp32) == 0:
+            global_norm_fp32 = paddle.to_tensor(
+                np.array(0.0), dtype=paddle.float32
+            )
+        else:
+            global_norm_fp32 = paddle.add_n(sum_square_fp32)
        # global norm of non-distributed FP32 params_and_grads for unslice parameters
-        global_unslice_fp32 = (
-            paddle.concat(unslice_params_fp32)
-            if len(unslice_params_fp32) != 0
-            else paddle.to_tensor([0.0], dtype=paddle.float32)
-        )
-        global_unslice_fp32 = paddle.sum(global_unslice_fp32)
+        if len(unslice_params_fp32) == 0:
+            global_unslice_fp32 = paddle.to_tensor(
+                np.array(0.0), dtype=paddle.float32
+            )
+        else:
+            global_unslice_fp32 = paddle.add_n(unslice_params_fp32)
        global_unslice_var = (
            global_unslice_fp16 + global_unslice_fp32 + global_unslice_bfp16
        )
@@ -165,7 +170,7 @@ class GroupShardedClipGrad:
        global_norm_var = paddle.sqrt(global_norm_var + global_unslice_var)
        max_global_norm = paddle.full(
-            shape=[1], dtype=global_norm_var.dtype, fill_value=self.clip_norm
+            shape=[], dtype=global_norm_var.dtype, fill_value=self.clip_norm
        )
        clip_var = paddle.divide(
...
@@ -40,7 +40,7 @@ def sum(input, scope=None, util=None):
        # in model.py
        input = paddle.cast(some_input, dtype='float32')
        cnt = paddle.sum(input)
-        global_cnt = paddle.static.create_global_var(persistable=True, dtype='float32', shape=[1], value=0)
+        global_cnt = paddle.static.create_global_var(persistable=True, dtype='float32', shape=[], value=0)
        tmp = paddle.add(cnt, global_cnt)
        paddle.assign(tmp, global_cnt)
@@ -80,7 +80,7 @@ def max(input, scope=None, util=None):
        # in model.py
        input = paddle.cast(some_input, dtype='float32')
        cnt = paddle.sum(input)
-        global_cnt = paddle.static.create_global_var(persistable=True, dtype='float32', shape=[1], value=0)
+        global_cnt = paddle.static.create_global_var(persistable=True, dtype='float32', shape=[], value=0)
        tmp = paddle.maximum(cnt, global_cnt)
        paddle.assign(tmp, global_cnt)
@@ -120,7 +120,7 @@ def min(input, scope=None, util=None):
        # in model.py
        input = paddle.cast(some_input, dtype='float32')
        cnt = paddle.sum(input)
-        global_cnt = paddle.static.create_global_var(persistable=True, dtype='float32', shape=[1], value=0)
+        global_cnt = paddle.static.create_global_var(persistable=True, dtype='float32', shape=[], value=0)
        tmp = paddle.minimum(cnt, global_cnt)
        paddle.assign(tmp, global_cnt)
...
@@ -955,7 +955,7 @@ class AMPPass(PassBase):
            loss_op._set_attr(OP_ROLE_KEY, OpRole.Forward)
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
-                cast_op, ref_mesh, [-1], self.dist_context
+                cast_op, ref_mesh, [-1 for i in loss.shape], self.dist_context
            )
            # backward
@@ -970,12 +970,20 @@ class AMPPass(PassBase):
                    dtype=core.VarDesc.VarType.FP32,
                    persistable=loss.persistable,
                )
-                set_var_dist_attr(self.dist_context, cast_loss_grad, [-1], ref_mesh)
+                set_var_dist_attr(
+                    self.dist_context,
+                    cast_loss_grad,
+                    [-1 for i in loss.shape],
+                    ref_mesh,
+                )
                pre_grad_name = first_backward_op.output_arg_names[0]
                first_backward_op._rename_output(pre_grad_name, cast_loss_grad.name)
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
-                    first_backward_op, ref_mesh, [-1], self.dist_context
+                    first_backward_op,
+                    ref_mesh,
+                    [-1 for i in loss.shape],
+                    self.dist_context,
                )
                cast_grad_op = main_block._insert_op(
                    loss_op_idx + 3,
@@ -989,7 +997,10 @@ class AMPPass(PassBase):
                    },
                )
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
-                    cast_grad_op, ref_mesh, [-1], self.dist_context
+                    cast_grad_op,
+                    ref_mesh,
+                    [-1 for i in loss.shape],
+                    self.dist_context,
                )
                loss_op = cast_op
                loss = cast_loss
@@ -1021,7 +1032,12 @@ class AMPPass(PassBase):
                dtype=loss.dtype,
                persistable=loss.persistable,
            )
-            set_var_dist_attr(self.dist_context, scaled_loss, [-1], ref_mesh)
+            set_var_dist_attr(
+                self.dist_context,
+                scaled_loss,
+                [-1 for i in loss.shape],
+                ref_mesh,
+            )
            elementwise_mul_op = main_block._insert_op(
                loss_op_idx + 1,
@@ -1034,7 +1050,10 @@ class AMPPass(PassBase):
            )
            loss_op._set_attr(OP_ROLE_KEY, OpRole.Forward)
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
-                elementwise_mul_op, ref_mesh, [-1], self.dist_context
+                elementwise_mul_op,
+                ref_mesh,
+                [-1 for i in loss.shape],
+                self.dist_context,
            )
            # backward
@@ -1050,14 +1069,20 @@ class AMPPass(PassBase):
                    persistable=loss.persistable,
                )
                set_var_dist_attr(
-                    self.dist_context, scaled_loss_grad, [-1], ref_mesh
+                    self.dist_context,
+                    scaled_loss_grad,
+                    [-1 for i in loss.shape],
+                    ref_mesh,
                )
                pre_grad_name = first_backward_op.output_arg_names[0]
                first_backward_op._rename_output(
                    pre_grad_name, scaled_loss_grad.name
                )
                naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
-                    first_backward_op, ref_mesh, [-1], self.dist_context
+                    first_backward_op,
+                    ref_mesh,
+                    [-1 for i in loss.shape],
+                    self.dist_context,
                )
                scaled_loss_grad.op = first_backward_op
                # FIXME(JZ-LIANG) a trick to insert backward op
@@ -1085,7 +1110,10 @@ class AMPPass(PassBase):
            elementwise_mul_grad_op = main_block.ops[loss_op_idx + 3]
            assert elementwise_mul_grad_op.type == "elementwise_mul_grad"
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
-                elementwise_mul_grad_op, ref_mesh, [-1], self.dist_context
+                elementwise_mul_grad_op,
+                ref_mesh,
+                [-1 for i in loss.shape],
+                self.dist_context,
            )
        else:
            scaled_loss = loss
...
@@ -678,7 +678,12 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
        stop_gradient=src_var.stop_gradient,
    )
-    set_var_dist_attr(dist_context, output_var, [-1], world_process_group.ranks)
+    set_var_dist_attr(
+        dist_context,
+        output_var,
+        [-1 for i in src_var.shape],
+        world_process_group.ranks,
+    )
    # TODO to support CUDAPinned/NPU/XPU Places
    if direction == "D2H":
@@ -894,7 +899,7 @@ class FP16Pass(AMPPass):
                    set_var_dist_attr(
                        self.dist_context,
                        found_inf,
-                        [-1],
+                        [-1 for i in found_inf.shape],
                        world_process_group.ranks,
                    )
                    _set_op_dist_attr_with_ranks(
...
@@ -221,7 +221,7 @@ class ClipHelper:
                in_var = self.block.vars[in_name]
                in_dist_attr = TensorDistAttr()
                in_dist_attr.process_mesh = ProcessMesh(self.world_ranks)
-                in_dist_attr.dims_mapping = [-1]
+                in_dist_attr.dims_mapping = [-1 for i in in_var.shape]
                self.dist_context.set_tensor_dist_attr_for_program(
                    in_var, in_dist_attr
                )
@@ -230,7 +230,7 @@ class ClipHelper:
                out_var = self.block.vars[out_name]
                out_dist_attr = TensorDistAttr()
                out_dist_attr.process_mesh = ProcessMesh(self.world_ranks)
-                out_dist_attr.dims_mapping = [-1]
+                out_dist_attr.dims_mapping = [-1 for i in out_var.shape]
                self.dist_context.set_tensor_dist_attr_for_program(
                    out_var, out_dist_attr
                )
...
@@ -300,7 +300,7 @@ class QuantizationPass(PassBase):
            for slot_name in quant_op.desc.input_names():
                in_name = quant_op.desc.input(slot_name)[0]
                input_var = block._var_recursive(in_name)
-                ref_dims_mapping = [-1]
+                ref_dims_mapping = [-1 for i in input_var.shape]
                if slot_name == "X":
                    continue
                elif slot_name in ['Scale', 'ZeroPoint']:
@@ -333,7 +333,7 @@ class QuantizationPass(PassBase):
            for slot_name in quant_op.desc.output_names():
                output_name = quant_op.desc.output(slot_name)[0]
                output_var = block._var_recursive(output_name)
-                ref_dims_mapping = [-1]
+                ref_dims_mapping = [-1 for i in output_var.shape]
                if slot_name == "Y":
                    dist_context.set_tensor_dist_attr_for_program(
                        output_var, consume_input_dist_attr
...
@@ -95,11 +95,7 @@ def check(use_cuda):
                fetch_list=[y_predict.name, avg_cost.name, acc_top1.name],
            )
            step += 1
-            print(
-                'iter={:.0f},cost={},acc1={}'.format(
-                    step, outs[1][0], outs[2]
-                )
-            )
+            print(f'iter={step:.0f},cost={outs[1]},acc1={outs[2]}')
 if __name__ == '__main__':
...
@@ -49,17 +49,19 @@ class TestResnetBase(TestParallelExecutorBase):
        )
        if compare_separately:
-            for loss in zip(func_1_first_loss, func_2_first_loss):
-                self.assertAlmostEqual(loss[0], loss[1], delta=1e-5)
-            for loss in zip(func_1_last_loss, func_2_last_loss):
-                self.assertAlmostEqual(loss[0], loss[1], delta=delta2)
+            self.assertAlmostEqual(
+                func_1_first_loss, func_2_first_loss, delta=1e-5
+            )
+            self.assertAlmostEqual(
+                func_1_last_loss, func_2_last_loss, delta=delta2
+            )
        else:
            np.testing.assert_allclose(
                func_1_loss_area, func_2_loss_area, rtol=delta2
            )
            self.assertAlmostEqual(
-                np.mean(func_1_first_loss), func_2_first_loss[0], delta=1e-5
+                func_1_first_loss, func_2_first_loss, delta=1e-5
            )
            self.assertAlmostEqual(
-                np.mean(func_1_last_loss), func_2_last_loss[0], delta=delta2
+                func_1_last_loss, func_2_last_loss, delta=delta2
            )
@@ -24,6 +24,7 @@ from paddle.fluid.executor import Executor
 from paddle.fluid.framework import Program, grad_var_name
 np.random.seed(123)
+paddle.enable_static()
 class PyArgsort:
@@ -52,7 +53,7 @@ class PyArgsort:
        out = (
            np.array(self.indices, dtype=self.indices.dtype),
            np.array(self.sorted_x, dtype=self.sorted_x.dtype),
-            np.array([self.loss], dtype=self.loss.dtype),
+            np.array(self.loss, dtype=self.loss.dtype),
        )
        return out
@@ -178,7 +179,7 @@ class TestArgsortOpCPU(unittest.TestCase):
                f[...] = o
            dout_dfeed = (y_pos - y_neg) / (delta * 2)
-            g[...] = dout_dfeed[0]
+            g[...] = dout_dfeed
        return grad_list
...
@@ -674,7 +674,7 @@ class TestCondBackward(unittest.TestCase):
                },
                fetch_list=[loss.name],
            )
-            numerical_grad[0][j] = (loss_delta[0] - loss_value[0]) / delta
+            numerical_grad[0][j] = (loss_delta - loss_value) / delta
            feed_img_delta[0][j] = feed_img[0][j]
        np.testing.assert_allclose(
            img_grad, numerical_grad, rtol=0.05, atol=0.05
...
@@ -64,7 +64,7 @@ class TestFunctionCosineEmbeddingLoss(unittest.TestCase):
            reduction='mean',
        )
        np.testing.assert_allclose(dy_result.numpy(), expected1, rtol=1e-05)
-        self.assertTrue(dy_result.shape, [1])
+        self.assertEqual(dy_result.shape, [])
        dy_result = paddle.nn.functional.cosine_embedding_loss(
            input1, input2, label, margin=0.5, reduction='sum'
@@ -78,7 +78,7 @@ class TestFunctionCosineEmbeddingLoss(unittest.TestCase):
        )
        np.testing.assert_allclose(dy_result.numpy(), expected2, rtol=1e-05)
-        self.assertTrue(dy_result.shape, [1])
+        self.assertEqual(dy_result.shape, [])
        dy_result = paddle.nn.functional.cosine_embedding_loss(
            input1, input2, label, margin=0.5, reduction='none'
@@ -92,7 +92,7 @@ class TestFunctionCosineEmbeddingLoss(unittest.TestCase):
        )
        np.testing.assert_allclose(dy_result.numpy(), expected3, rtol=1e-05)
-        self.assertTrue(dy_result.shape, [5])
+        self.assertEqual(dy_result.shape, [5])
    def run_static(self, use_gpu=False):
        input1 = static.data(name='input1', shape=[5, 3], dtype='float64')
@@ -257,7 +257,7 @@ class TestClassCosineEmbeddingLoss(unittest.TestCase):
            reduction='mean',
        )
        np.testing.assert_allclose(dy_result.numpy(), expected1, rtol=1e-05)
-        self.assertTrue(dy_result.shape, [1])
+        self.assertEqual(dy_result.shape, [])
        input1_1D = paddle.to_tensor(self.input1_np_1D)
        input2_1D = paddle.to_tensor(self.input2_np_1D)
...
@@ -42,7 +42,7 @@ class PyRNNBase:
    def forward(self):
        for step_id in range(self.x.shape[0]):
            self.step(step_id, self.x[step_id])
-        return np.array([np.mean(self.y)])
+        return np.mean(self.y)
    def segment_inputs(self):
        return [self.x[i] for i in range(self.x.shape[0])]
@@ -251,7 +251,7 @@ class EagerDeletionRecurrentOpTest1(unittest.TestCase):
                f[...] = o
            dout_dfeed = (y_pos - y_neg) / (delta * 2)
-            g[...] = dout_dfeed[0]
+            g[...] = dout_dfeed
        return grad_list
...
@@ -69,9 +69,10 @@ class TestFetchLoDTensorArray(unittest.TestCase):
            loss_v, array_v = exe.run(
                binary, feed=feed_dict, fetch_list=[loss, array]
            )
-            self.assertEqual(np.array(loss_v).shape, (1,))
-            self.assertEqual(np.array(array_v[0]).shape, (batch_size, 784))
-            self.assertEqual(np.array(array_v[1]).shape, (batch_size, 1))
+            self.assertEqual(loss_v.shape, ())
+            self.assertEqual(array_v[0].shape, (batch_size, 784))
+            self.assertEqual(array_v[1].shape, (batch_size, 1))
+            self.assertEqual(array_v[2].shape, ())
            np.testing.assert_allclose(loss_v, array_v[2], rtol=1e-05)
    def test_fetch_lod_tensor_array(self):
@@ -81,4 +82,5 @@ class TestFetchLoDTensorArray(unittest.TestCase):
 if __name__ == '__main__':
+    paddle.enable_static()
    unittest.main()
@@ -78,10 +78,12 @@ class TestFuseAllReduceOpsBase(TestParallelExecutorBase):
            optimizer=optimizer,
        )
-        for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
-            self.assertAlmostEqual(loss[0], loss[1], delta=1e-6)
-        for loss in zip(not_fuse_op_last_loss, fuse_op_last_loss):
-            self.assertAlmostEqual(loss[0], loss[1], delta=1e-6)
+        self.assertAlmostEqual(
+            not_fuse_op_first_loss, fuse_op_first_loss, delta=1e-6
+        )
+        self.assertAlmostEqual(
+            not_fuse_op_last_loss, fuse_op_last_loss, delta=1e-6
+        )
    def optimizer(self, learning_rate=1e-3):
        optimizer = fluid.optimizer.SGD(
...
@@ -98,7 +98,7 @@ class TestFuseBatchNormActPass(unittest.TestCase):
                loss_v = exe.run(
                    binary, feed=feeder.feed(data), fetch_list=[loss]
                )
-                loss_vals.append(loss_v[0][0])
+                loss_vals.append(loss_v[0])
        # open fused_bn_act_ops
        build_strategy_fused = fluid.BuildStrategy()
@@ -118,7 +118,7 @@ class TestFuseBatchNormActPass(unittest.TestCase):
                loss_v = exe.run(
                    binary_fused, feed=feeder.feed(data), fetch_list=[loss]
                )
-                loss_vals_fused.append(loss_v[0][0])
+                loss_vals_fused.append(loss_v[0])
        # check loss
        for i in range(iters):
...
@@ -216,7 +216,7 @@ class TestFusedBnAddActAPI(unittest.TestCase):
                loss_v = exe.run(
                    binary_fused, feed={"x": x, "y": y}, fetch_list=[loss]
                )
-                loss_vals_fused.append(loss_v[0][0])
+                loss_vals_fused.append(loss_v[0])
        # build_origin_program: turn off fused_bn_act_ops
        build_strategy = fluid.BuildStrategy()
@@ -234,7 +234,7 @@ class TestFusedBnAddActAPI(unittest.TestCase):
                    feed={"x": x_data[i], "y": y_data[i]},
                    fetch_list=[loss],
                )
-                loss_vals.append(loss_v[0][0])
+                loss_vals.append(loss_v[0])
        # check loss
        for i in range(iters):
...
@@ -74,10 +74,12 @@ class TestMNIST(TestParallelExecutorBase):
            optimizer=_optimizer,
        )
-        for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
-            self.assertAlmostEqual(loss[0], loss[1], delta=1e-6)
-        for loss in zip(not_fuse_op_last_loss, fuse_op_last_loss):
-            self.assertAlmostEqual(loss[0], loss[1], delta=1e-6)
+        self.assertAlmostEqual(
+            not_fuse_op_first_loss, fuse_op_first_loss, delta=1e-6
+        )
+        self.assertAlmostEqual(
+            not_fuse_op_last_loss, fuse_op_last_loss, delta=1e-6
+        )
    def test_simple_fc_with_fuse_op(self):
        self._compare_fuse_elewise_add_act_ops(simple_fc_net, DeviceType.CUDA)
...
@@ -70,10 +70,12 @@ class TestFuseOptimizationOps(TestParallelExecutorBase):
            optimizer=optimizer,
        )
-        for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
-            self.assertAlmostEqual(loss[0], loss[1], delta=1e-6)
-        for loss in zip(not_fuse_op_last_loss, fuse_op_last_loss):
-            self.assertAlmostEqual(loss[0], loss[1], delta=1e-6)
+        self.assertAlmostEqual(
+            not_fuse_op_first_loss, fuse_op_first_loss, delta=1e-6
+        )
+        self.assertAlmostEqual(
+            not_fuse_op_last_loss, fuse_op_last_loss, delta=1e-6
+        )
    def _decorate_compare_fused_optimizer_ops(
        self, model, use_device, optimizer
...
@@ -118,10 +118,12 @@ class TestMNIST(TestParallelExecutorBase):
            optimizer=_optimizer,
        )
-        for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
-            self.assertAlmostEqual(loss[0], loss[1], delta=1e-6)
-        for loss in zip(not_fuse_op_last_loss, fuse_op_last_loss):
-            self.assertAlmostEqual(loss[0], loss[1], delta=1e-6)
+        self.assertAlmostEqual(
+            not_fuse_op_first_loss, fuse_op_first_loss, delta=1e-6
+        )
+        self.assertAlmostEqual(
+            not_fuse_op_last_loss, fuse_op_last_loss, delta=1e-6
+        )
    def test_simple_depthwise_with_fuse_op(self):
        self._compare(simple_depthwise_net, DeviceType.CUDA)
...
@@ -152,7 +152,7 @@ class TestGradientClip(unittest.TestCase):
            data = next(self.train_data())
            val = exe.run(prog, feed=feeder.feed(data), fetch_list=[cost])[0]
-            self.assertEqual((1,), val.shape)
+            self.assertEqual(val.shape, ())
            self.assertFalse(np.isnan(val))
    def backward_and_optimize(self, cost):
...
...@@ -50,7 +50,7 @@ class TestFunctionalHingeEmbeddingLoss(unittest.TestCase): ...@@ -50,7 +50,7 @@ class TestFunctionalHingeEmbeddingLoss(unittest.TestCase):
dy_result = paddle.nn.functional.hinge_embedding_loss(input, label) dy_result = paddle.nn.functional.hinge_embedding_loss(input, label)
expected = calc_hinge_embedding_loss(self.input_np, self.label_np) expected = calc_hinge_embedding_loss(self.input_np, self.label_np)
np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05) np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [1]) self.assertEqual(dy_result.shape, [])
dy_result = paddle.nn.functional.hinge_embedding_loss( dy_result = paddle.nn.functional.hinge_embedding_loss(
input, label, reduction='sum' input, label, reduction='sum'
...@@ -59,7 +59,7 @@ class TestFunctionalHingeEmbeddingLoss(unittest.TestCase): ...@@ -59,7 +59,7 @@ class TestFunctionalHingeEmbeddingLoss(unittest.TestCase):
self.input_np, self.label_np, reduction='sum' self.input_np, self.label_np, reduction='sum'
) )
np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05) np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [1]) self.assertEqual(dy_result.shape, [])
dy_result = paddle.nn.functional.hinge_embedding_loss( dy_result = paddle.nn.functional.hinge_embedding_loss(
input, label, reduction='none' input, label, reduction='none'
...@@ -68,7 +68,7 @@ class TestFunctionalHingeEmbeddingLoss(unittest.TestCase): ...@@ -68,7 +68,7 @@ class TestFunctionalHingeEmbeddingLoss(unittest.TestCase):
self.input_np, self.label_np, reduction='none' self.input_np, self.label_np, reduction='none'
) )
np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05) np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
self.assertTrue(dy_result.shape, self.shape) self.assertEqual(dy_result.shape, list(self.shape))
def run_static_check(self, place=paddle.CPUPlace): def run_static_check(self, place=paddle.CPUPlace):
paddle.enable_static() paddle.enable_static()
...@@ -129,7 +129,7 @@ class TestClassHingeEmbeddingLoss(unittest.TestCase): ...@@ -129,7 +129,7 @@ class TestClassHingeEmbeddingLoss(unittest.TestCase):
dy_result = hinge_embedding_loss(input, label) dy_result = hinge_embedding_loss(input, label)
expected = calc_hinge_embedding_loss(self.input_np, self.label_np) expected = calc_hinge_embedding_loss(self.input_np, self.label_np)
np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05) np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [1]) self.assertEqual(dy_result.shape, [])
hinge_embedding_loss = paddle.nn.loss.HingeEmbeddingLoss( hinge_embedding_loss = paddle.nn.loss.HingeEmbeddingLoss(
reduction='sum' reduction='sum'
...@@ -139,7 +139,7 @@ class TestClassHingeEmbeddingLoss(unittest.TestCase): ...@@ -139,7 +139,7 @@ class TestClassHingeEmbeddingLoss(unittest.TestCase):
self.input_np, self.label_np, reduction='sum' self.input_np, self.label_np, reduction='sum'
) )
np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05) np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [1]) self.assertEqual(dy_result.shape, [])
hinge_embedding_loss = paddle.nn.loss.HingeEmbeddingLoss( hinge_embedding_loss = paddle.nn.loss.HingeEmbeddingLoss(
reduction='none' reduction='none'
...@@ -149,7 +149,7 @@ class TestClassHingeEmbeddingLoss(unittest.TestCase): ...@@ -149,7 +149,7 @@ class TestClassHingeEmbeddingLoss(unittest.TestCase):
self.input_np, self.label_np, reduction='none' self.input_np, self.label_np, reduction='none'
) )
np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05) np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
self.assertTrue(dy_result.shape, self.shape) self.assertEqual(dy_result.shape, list(self.shape))
def run_static_check(self, place=paddle.CPUPlace): def run_static_check(self, place=paddle.CPUPlace):
paddle.enable_static() paddle.enable_static()
......
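A minimal sketch of the behaviour the updated assertions expect: with reduction='mean' (the default) or 'sum', the hinge embedding loss is now a 0D Tensor whose dygraph shape is the empty list. The snippet is illustrative only and assumes the post-change semantics.

import paddle

input = paddle.rand([5, 3])
label = paddle.ones([5, 3])  # hinge embedding labels are expected to be 1 or -1
loss = paddle.nn.functional.hinge_embedding_loss(input, label)  # reduction='mean'
assert loss.shape == []          # 0D Tensor, previously shape [1]
assert loss.numpy().shape == ()  # the corresponding numpy array is 0-d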
...@@ -80,10 +80,9 @@ class TestMNIST(TestParallelExecutorBase): ...@@ -80,10 +80,9 @@ class TestMNIST(TestParallelExecutorBase):
use_device=use_device, use_device=use_device,
use_ir_memory_optimize=True, use_ir_memory_optimize=True,
) )
for loss in zip(first_loss0, first_loss1):
self.assertAlmostEqual(loss[0], loss[1], delta=1e-6) self.assertAlmostEqual(first_loss0, first_loss1, delta=1e-6)
for loss in zip(last_loss0, last_loss1): self.assertAlmostEqual(last_loss0, last_loss1, delta=1e-6)
self.assertAlmostEqual(loss[0], loss[1], delta=1e-6)
def test_simple_fc_net(self): def test_simple_fc_net(self):
self._compare_ir_memory_optimize(simple_fc_net, DeviceType.CPU) self._compare_ir_memory_optimize(simple_fc_net, DeviceType.CPU)
......
...@@ -36,7 +36,7 @@ class TestFunctionalL1Loss(unittest.TestCase): ...@@ -36,7 +36,7 @@ class TestFunctionalL1Loss(unittest.TestCase):
dy_result = paddle.nn.functional.l1_loss(input, label, reduction='sum') dy_result = paddle.nn.functional.l1_loss(input, label, reduction='sum')
expected = np.sum(np.abs(self.input_np - self.label_np)) expected = np.sum(np.abs(self.input_np - self.label_np))
np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05) np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
self.assertEqual(dy_result.shape, [1]) self.assertEqual(dy_result.shape, [])
dy_result = paddle.nn.functional.l1_loss(input, label, reduction='none') dy_result = paddle.nn.functional.l1_loss(input, label, reduction='none')
expected = np.abs(self.input_np - self.label_np) expected = np.abs(self.input_np - self.label_np)
...@@ -125,7 +125,7 @@ class TestClassL1Loss(unittest.TestCase): ...@@ -125,7 +125,7 @@ class TestClassL1Loss(unittest.TestCase):
dy_result = l1_loss(input, label) dy_result = l1_loss(input, label)
expected = np.sum(np.abs(self.input_np - self.label_np)) expected = np.sum(np.abs(self.input_np - self.label_np))
np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05) np.testing.assert_allclose(dy_result.numpy(), expected, rtol=1e-05)
self.assertEqual(dy_result.shape, [1]) self.assertEqual(dy_result.shape, [])
l1_loss = paddle.nn.loss.L1Loss(reduction='none') l1_loss = paddle.nn.loss.L1Loss(reduction='none')
dy_result = l1_loss(input, label) dy_result = l1_loss(input, label)
......
...@@ -118,7 +118,7 @@ class TestNNMseLoss(unittest.TestCase): ...@@ -118,7 +118,7 @@ class TestNNMseLoss(unittest.TestCase):
np.testing.assert_allclose(static_result, expected, rtol=1e-05) np.testing.assert_allclose(static_result, expected, rtol=1e-05)
np.testing.assert_allclose(static_result, dy_result, rtol=1e-05) np.testing.assert_allclose(static_result, dy_result, rtol=1e-05)
np.testing.assert_allclose(dy_result, expected, rtol=1e-05) np.testing.assert_allclose(dy_result, expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [1]) self.assertEqual(dy_result.shape, ())
def test_NNMseLoss_sum(self): def test_NNMseLoss_sum(self):
for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]:
...@@ -164,7 +164,7 @@ class TestNNMseLoss(unittest.TestCase): ...@@ -164,7 +164,7 @@ class TestNNMseLoss(unittest.TestCase):
np.testing.assert_allclose(static_result, expected, rtol=1e-05) np.testing.assert_allclose(static_result, expected, rtol=1e-05)
np.testing.assert_allclose(static_result, dy_result, rtol=1e-05) np.testing.assert_allclose(static_result, dy_result, rtol=1e-05)
np.testing.assert_allclose(dy_result, expected, rtol=1e-05) np.testing.assert_allclose(dy_result, expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [1]) self.assertEqual(dy_result.shape, ())
def test_NNMseLoss_none(self): def test_NNMseLoss_none(self):
for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]:
...@@ -210,7 +210,7 @@ class TestNNMseLoss(unittest.TestCase): ...@@ -210,7 +210,7 @@ class TestNNMseLoss(unittest.TestCase):
np.testing.assert_allclose(static_result, expected, rtol=1e-05) np.testing.assert_allclose(static_result, expected, rtol=1e-05)
np.testing.assert_allclose(static_result, dy_result, rtol=1e-05) np.testing.assert_allclose(static_result, dy_result, rtol=1e-05)
np.testing.assert_allclose(dy_result, expected, rtol=1e-05) np.testing.assert_allclose(dy_result, expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [1]) self.assertEqual(dy_result.shape, tuple(dim))
class TestNNFunctionalMseLoss(unittest.TestCase): class TestNNFunctionalMseLoss(unittest.TestCase):
...@@ -254,7 +254,7 @@ class TestNNFunctionalMseLoss(unittest.TestCase): ...@@ -254,7 +254,7 @@ class TestNNFunctionalMseLoss(unittest.TestCase):
np.testing.assert_allclose(static_result, expected, rtol=1e-05) np.testing.assert_allclose(static_result, expected, rtol=1e-05)
np.testing.assert_allclose(static_result, dy_result, rtol=1e-05) np.testing.assert_allclose(static_result, dy_result, rtol=1e-05)
np.testing.assert_allclose(dy_result, expected, rtol=1e-05) np.testing.assert_allclose(dy_result, expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [1]) self.assertEqual(dy_result.shape, ())
def test_NNFunctionalMseLoss_sum(self): def test_NNFunctionalMseLoss_sum(self):
for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]:
...@@ -296,7 +296,7 @@ class TestNNFunctionalMseLoss(unittest.TestCase): ...@@ -296,7 +296,7 @@ class TestNNFunctionalMseLoss(unittest.TestCase):
np.testing.assert_allclose(static_result, expected, rtol=1e-05) np.testing.assert_allclose(static_result, expected, rtol=1e-05)
np.testing.assert_allclose(static_result, dy_result, rtol=1e-05) np.testing.assert_allclose(static_result, dy_result, rtol=1e-05)
np.testing.assert_allclose(dy_result, expected, rtol=1e-05) np.testing.assert_allclose(dy_result, expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [1]) self.assertEqual(dy_result.shape, ())
def test_NNFunctionalMseLoss_none(self): def test_NNFunctionalMseLoss_none(self):
for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]: for dim in [[10, 10], [2, 10, 10], [3, 3, 10, 10]]:
...@@ -338,7 +338,7 @@ class TestNNFunctionalMseLoss(unittest.TestCase): ...@@ -338,7 +338,7 @@ class TestNNFunctionalMseLoss(unittest.TestCase):
np.testing.assert_allclose(static_result, expected, rtol=1e-05) np.testing.assert_allclose(static_result, expected, rtol=1e-05)
np.testing.assert_allclose(static_result, dy_result, rtol=1e-05) np.testing.assert_allclose(static_result, dy_result, rtol=1e-05)
np.testing.assert_allclose(dy_result, expected, rtol=1e-05) np.testing.assert_allclose(dy_result, expected, rtol=1e-05)
self.assertTrue(dy_result.shape, [1]) self.assertEqual(dy_result.shape, tuple(dim))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -42,10 +42,6 @@ class TestNanInf(unittest.TestCase): ...@@ -42,10 +42,6 @@ class TestNanInf(unittest.TestCase):
out, err = proc.communicate() out, err = proc.communicate()
returncode = proc.returncode returncode = proc.returncode
print(out)
print(err)
# in python3, type(out+err) is 'bytes', need use encode # in python3, type(out+err) is 'bytes', need use encode
assert (out + err).find(b'There are NAN or INF') != -1 assert (out + err).find(b'There are NAN or INF') != -1
......
...@@ -110,7 +110,7 @@ def train(dot_save_dir, prefix, seed=1234): ...@@ -110,7 +110,7 @@ def train(dot_save_dir, prefix, seed=1234):
loss_values = [] loss_values = []
for step in range(iters): for step in range(iters):
loss_v = exe.run(compiled_program, feed=feed[step], fetch_list=[loss]) loss_v = exe.run(compiled_program, feed=feed[step], fetch_list=[loss])
loss_values.append(loss_v[0][0]) loss_values.append(loss_v[0])
return loss_values return loss_values
......
...@@ -48,10 +48,14 @@ class TestResnetWithReduceBase(TestParallelExecutorBase): ...@@ -48,10 +48,14 @@ class TestResnetWithReduceBase(TestParallelExecutorBase):
optimizer=seresnext_net.optimizer, optimizer=seresnext_net.optimizer,
) )
for loss in zip(all_reduce_first_loss, reduce_first_loss): self.assertAlmostEqual(
self.assertAlmostEqual(loss[0], loss[1], delta=1e-5) all_reduce_first_loss, reduce_first_loss, delta=1e-5
for loss in zip(all_reduce_last_loss, reduce_last_loss): )
self.assertAlmostEqual(loss[0], loss[1], delta=loss[0] * delta2) self.assertAlmostEqual(
all_reduce_last_loss,
reduce_last_loss,
delta=all_reduce_last_loss * delta2,
)
if not use_device: if not use_device:
return return
...@@ -86,20 +90,32 @@ class TestResnetWithReduceBase(TestParallelExecutorBase): ...@@ -86,20 +90,32 @@ class TestResnetWithReduceBase(TestParallelExecutorBase):
enable_sequential_execution=True, enable_sequential_execution=True,
) )
for loss in zip(all_reduce_first_loss, all_reduce_first_loss_seq): self.assertAlmostEqual(
self.assertAlmostEqual(loss[0], loss[1], delta=1e-5) all_reduce_first_loss, all_reduce_first_loss_seq, delta=1e-5
for loss in zip(all_reduce_last_loss, all_reduce_last_loss_seq): )
self.assertAlmostEqual(loss[0], loss[1], delta=loss[0] * delta2) self.assertAlmostEqual(
all_reduce_last_loss,
all_reduce_last_loss_seq,
delta=all_reduce_last_loss * delta2,
)
for loss in zip(reduce_first_loss, reduce_first_loss_seq): self.assertAlmostEqual(
self.assertAlmostEqual(loss[0], loss[1], delta=1e-5) reduce_first_loss, reduce_first_loss_seq, delta=1e-5
for loss in zip(reduce_last_loss, reduce_last_loss_seq): )
self.assertAlmostEqual(loss[0], loss[1], delta=loss[0] * delta2) self.assertAlmostEqual(
reduce_last_loss,
reduce_last_loss_seq,
delta=reduce_last_loss * delta2,
)
for loss in zip(all_reduce_first_loss_seq, reduce_first_loss_seq): self.assertAlmostEqual(
self.assertAlmostEqual(loss[0], loss[1], delta=1e-5) all_reduce_first_loss_seq, reduce_first_loss_seq, delta=1e-5
for loss in zip(all_reduce_last_loss_seq, reduce_last_loss_seq): )
self.assertAlmostEqual(loss[0], loss[1], delta=loss[0] * delta2) self.assertAlmostEqual(
all_reduce_last_loss_seq,
reduce_last_loss_seq,
delta=all_reduce_last_loss_seq * delta2,
)
class TestResnetWithReduceCPU(TestResnetWithReduceBase): class TestResnetWithReduceCPU(TestResnetWithReduceBase):
......
...@@ -37,7 +37,7 @@ class PyRNNBase: ...@@ -37,7 +37,7 @@ class PyRNNBase:
def forward(self): def forward(self):
for step_id in range(self.x.shape[0]): for step_id in range(self.x.shape[0]):
self.step(step_id, self.x[step_id]) self.step(step_id, self.x[step_id])
return np.array([np.mean(self.y)]) return np.mean(self.y)
def segment_inputs(self): def segment_inputs(self):
return [self.x[i] for i in range(self.x.shape[0])] return [self.x[i] for i in range(self.x.shape[0])]
...@@ -239,7 +239,7 @@ class RecurrentOpTest1(unittest.TestCase): ...@@ -239,7 +239,7 @@ class RecurrentOpTest1(unittest.TestCase):
f[...] = o f[...] = o
dout_dfeed = (y_pos - y_neg) / (delta * 2) dout_dfeed = (y_pos - y_neg) / (delta * 2)
g[...] = dout_dfeed[0] g[...] = dout_dfeed
return grad_list return grad_list
......
...@@ -103,7 +103,7 @@ class TestResnet50Accuracy(unittest.TestCase): ...@@ -103,7 +103,7 @@ class TestResnet50Accuracy(unittest.TestCase):
fetch_list=[loss], fetch_list=[loss],
return_numpy=True, return_numpy=True,
) )
loss_vals.append(loss_v[0][0]) loss_vals.append(loss_v[0])
return loss_vals return loss_vals
def test_check_resnet50_accuracy(self): def test_check_resnet50_accuracy(self):
......
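The dropped [0] index above follows from the reduced loss now being fetched as a 0-d numpy array; a hypothetical illustration with invented values:

import numpy as np

loss_v = [np.array(0.25)]     # roughly what fetch_list=[loss] returns now
val = loss_v[0]               # 0-d ndarray with shape ()
print(val.shape, float(val))  # () 0.25
# loss_v[0][0] would raise IndexError, since a 0-d array cannot be indexed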
...@@ -514,7 +514,7 @@ class TestParametersWithStopGradient(unittest.TestCase): ...@@ -514,7 +514,7 @@ class TestParametersWithStopGradient(unittest.TestCase):
dy_loss = self.train(to_static=False) dy_loss = self.train(to_static=False)
st_loss = self.train(to_static=True) st_loss = self.train(to_static=True)
self.assertEqual(dy_loss[0], st_loss[0]) self.assertEqual(dy_loss, st_loss)
paddle.enable_static() paddle.enable_static()
......
...@@ -220,14 +220,6 @@ class TestReduceAPI(unittest.TestCase): ...@@ -220,14 +220,6 @@ class TestReduceAPI(unittest.TestCase):
self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad.numpy(), np.array(3.0)) np.testing.assert_allclose(x.grad.numpy(), np.array(3.0))
if api in [
paddle.sum,
paddle.mean,
paddle.nanmean,
paddle.nansum,
]:
return
# 2) x is ND, reduce to 0D # 2) x is ND, reduce to 0D
if api in [paddle.all, paddle.any]: if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [3, 5]).astype('bool') x = paddle.randint(0, 2, [3, 5]).astype('bool')
...@@ -302,20 +294,11 @@ class TestReduceAPI(unittest.TestCase): ...@@ -302,20 +294,11 @@ class TestReduceAPI(unittest.TestCase):
np.testing.assert_allclose(res[2], np.array(1.0)) np.testing.assert_allclose(res[2], np.array(1.0))
np.testing.assert_allclose(res[3], np.array(1.0)) np.testing.assert_allclose(res[3], np.array(1.0))
if api in [
paddle.sum,
paddle.mean,
paddle.nanmean,
paddle.nansum,
]:
return
# 2) x is ND, reduce to 0D # 2) x is ND, reduce to 0D
if api in [paddle.all, paddle.any]: if api in [paddle.all, paddle.any]:
x = paddle.randint(0, 2, [3, 5]).astype('bool') x = paddle.randint(0, 2, [3, 5]).astype('bool')
else: else:
x = paddle.rand([3, 5]) x = paddle.rand([3, 5])
x = paddle.rand([3, 5])
x.stop_gradient = False x.stop_gradient = False
out = api(x, None) out = api(x, None)
paddle.static.append_backward(out) paddle.static.append_backward(out)
...@@ -1365,6 +1348,7 @@ class TestSundryAPI(unittest.TestCase): ...@@ -1365,6 +1348,7 @@ class TestSundryAPI(unittest.TestCase):
self.assertEqual(out.shape, []) self.assertEqual(out.shape, [])
def test_std(self): def test_std(self):
# 1) x is 0D
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
out1 = paddle.std(x) out1 = paddle.std(x)
...@@ -1372,18 +1356,24 @@ class TestSundryAPI(unittest.TestCase): ...@@ -1372,18 +1356,24 @@ class TestSundryAPI(unittest.TestCase):
out1.backward() out1.backward()
out2.backward() out2.backward()
# check shape of out
self.assertEqual(out1.shape, []) self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, []) self.assertEqual(out2.shape, [])
# check value of out
self.assertEqual(out1, 0) self.assertEqual(out1, 0)
self.assertEqual(out2, 0) self.assertEqual(out2, 0)
# check backward
self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.shape, [])
# 2) x is ND
x = paddle.rand([3, 5])
x.stop_gradient = False
out = paddle.std(x)
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [3, 5])
def test_var(self): def test_var(self):
# 1) x is 0D
x = paddle.rand([]) x = paddle.rand([])
x.stop_gradient = False x.stop_gradient = False
out1 = paddle.var(x) out1 = paddle.var(x)
...@@ -1391,18 +1381,23 @@ class TestSundryAPI(unittest.TestCase): ...@@ -1391,18 +1381,23 @@ class TestSundryAPI(unittest.TestCase):
out1.backward() out1.backward()
out2.backward() out2.backward()
# check shape of out
self.assertEqual(out1.shape, []) self.assertEqual(out1.shape, [])
self.assertEqual(out2.shape, []) self.assertEqual(out2.shape, [])
# check value of out
self.assertEqual(out1, 0) self.assertEqual(out1, 0)
self.assertEqual(out2, 0) self.assertEqual(out2, 0)
# check backward
self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.shape, [])
np.testing.assert_allclose(x.grad, 0) np.testing.assert_allclose(x.grad, 0)
# 2) x is ND
x = paddle.rand([3, 5])
x.stop_gradient = False
out = paddle.var(x)
out.backward()
self.assertEqual(out.shape, [])
self.assertEqual(x.grad.shape, [3, 5])
def test_quantile(self): def test_quantile(self):
# 1) x is 0D # 1) x is 0D
x = paddle.rand([]) x = paddle.rand([])
...@@ -1598,7 +1593,6 @@ class TestSundryAPI(unittest.TestCase): ...@@ -1598,7 +1593,6 @@ class TestSundryAPI(unittest.TestCase):
out = paddle.clip(x, -5, 5) out = paddle.clip(x, -5, 5)
out.retain_grads() out.retain_grads()
out.backward() out.backward()
self.assertEqual(out.shape, []) self.assertEqual(out.shape, [])
self.assertEqual(out.grad.shape, []) self.assertEqual(out.grad.shape, [])
self.assertEqual(x.grad.shape, []) self.assertEqual(x.grad.shape, [])
...@@ -1608,7 +1602,6 @@ class TestSundryAPI(unittest.TestCase): ...@@ -1608,7 +1602,6 @@ class TestSundryAPI(unittest.TestCase):
out1 = paddle.clip(x1, paddle.full([], 5.0), paddle.full([], 5.0)) out1 = paddle.clip(x1, paddle.full([], 5.0), paddle.full([], 5.0))
out1.retain_grads() out1.retain_grads()
out1.backward() out1.backward()
self.assertEqual(out1.shape, []) self.assertEqual(out1.shape, [])
self.assertEqual(out1.grad.shape, []) self.assertEqual(out1.grad.shape, [])
self.assertEqual(x1.grad.shape, []) self.assertEqual(x1.grad.shape, [])
...@@ -5643,8 +5636,7 @@ class TestDistribution(unittest.TestCase): ...@@ -5643,8 +5636,7 @@ class TestDistribution(unittest.TestCase):
self.assertEqual( self.assertEqual(
d.log_prob(paddle.full([], 2, dtype='int64')).shape, [] d.log_prob(paddle.full([], 2, dtype='int64')).shape, []
) )
# because use paddle.sum self.assertEqual(d.entropy().shape, [])
# self.assertEqual(d.entropy().shape, [])
def test_Normal(self): def test_Normal(self):
normal = paddle.distribution.Normal(0.0, 3.0) normal = paddle.distribution.Normal(0.0, 3.0)
...@@ -5687,10 +5679,9 @@ class TestDistribution(unittest.TestCase): ...@@ -5687,10 +5679,9 @@ class TestDistribution(unittest.TestCase):
self.assertEqual(beta.sample([]).shape, []) self.assertEqual(beta.sample([]).shape, [])
self.assertEqual(beta.mean.shape, []) self.assertEqual(beta.mean.shape, [])
self.assertEqual(beta.variance.shape, []) self.assertEqual(beta.variance.shape, [])
# because use paddle.sum self.assertEqual(beta.prob(self.x).shape, [])
# self.assertEqual(beta.prob(self.x).shape, []) self.assertEqual(beta.log_prob(self.x).shape, [])
# self.assertEqual(beta.log_prob(self.x).shape, []) self.assertEqual(beta.entropy().shape, [])
# self.assertEqual(beta.entropy().shape, [])
def test_kl_divergence(self): def test_kl_divergence(self):
p = paddle.distribution.Beta(alpha=0.5, beta=0.5) p = paddle.distribution.Beta(alpha=0.5, beta=0.5)
...@@ -5749,10 +5740,9 @@ class TestDistribution(unittest.TestCase): ...@@ -5749,10 +5740,9 @@ class TestDistribution(unittest.TestCase):
d = paddle.distribution.Multinomial( d = paddle.distribution.Multinomial(
10, paddle.to_tensor([0.2, 0.3, 0.5]) 10, paddle.to_tensor([0.2, 0.3, 0.5])
) )
# because use paddle.sum self.assertEqual(d.prob(self.x).shape, [])
# self.assertEqual(d.prob(self.x).shape, []) self.assertEqual(d.log_prob(self.x).shape, [])
# self.assertEqual(d.log_prob(self.x).shape, []) self.assertEqual(d.entropy().shape, [])
# self.assertEqual(d.entropy().shape, [])
class TestLossAPI(unittest.TestCase): class TestLossAPI(unittest.TestCase):
...@@ -5770,10 +5760,10 @@ class TestLossAPI(unittest.TestCase): ...@@ -5770,10 +5760,10 @@ class TestLossAPI(unittest.TestCase):
fg_num_1 = paddle.full([1], 2.0) fg_num_1 = paddle.full([1], 2.0)
out0 = F.sigmoid_focal_loss( out0 = F.sigmoid_focal_loss(
logit, label, normalizer=fg_num_0, reduction='mean' logit, label, normalizer=fg_num_0, reduction='sum'
) )
out1 = F.sigmoid_focal_loss( out1 = F.sigmoid_focal_loss(
logit, label, normalizer=fg_num_1, reduction='mean' logit, label, normalizer=fg_num_1, reduction='sum'
) )
out0.retain_grads() out0.retain_grads()
...@@ -5788,6 +5778,28 @@ class TestLossAPI(unittest.TestCase): ...@@ -5788,6 +5778,28 @@ class TestLossAPI(unittest.TestCase):
self.assertEqual(out0.grad.shape, []) self.assertEqual(out0.grad.shape, [])
self.assertEqual(logit.grad.shape, [2, 3]) self.assertEqual(logit.grad.shape, [2, 3])
def test_cross_entropy(self):
input = paddle.rand([3, 5])
input.stop_gradient = False
label = paddle.randint(0, 5, shape=[3])
loss = paddle.nn.functional.cross_entropy(input, label, reduction='sum')
loss.backward()
self.assertEqual(loss.shape, [])
self.assertEqual(input.grad.shape, [3, 5])
def test_l1_loss(self):
input = paddle.rand([3, 5])
input.stop_gradient = False
label = paddle.rand([3, 5])
loss = paddle.nn.functional.l1_loss(input, label, reduction='mean')
loss.backward()
self.assertEqual(loss.shape, [])
self.assertEqual(input.grad.shape, [3, 5])
class TestLossAPIStatic(unittest.TestCase): class TestLossAPIStatic(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -5818,12 +5830,42 @@ class TestLossAPIStatic(unittest.TestCase): ...@@ -5818,12 +5830,42 @@ class TestLossAPIStatic(unittest.TestCase):
prog, fetch_list=[out0, out1, out0.grad_name, logit.grad_name] prog, fetch_list=[out0, out1, out0.grad_name, logit.grad_name]
) )
np.testing.assert_allclose(res[0], res[1]) np.testing.assert_allclose(res[0], res[1])
# because static use paddle.mean self.assertEqual(res[0].shape, ())
# self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, ())
# self.assertEqual(res[1].shape, ()) self.assertEqual(res[2].shape, ())
# self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, (2, 3)) self.assertEqual(res[3].shape, (2, 3))
@prog_scope()
def test_cross_entropy(self):
input = paddle.rand([3, 5])
input.stop_gradient = False
label = paddle.randint(0, 5, shape=[3])
label.stop_gradient = False
loss = paddle.nn.functional.cross_entropy(
input, label, reduction='mean'
)
paddle.static.append_backward(loss)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[loss, input.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (3, 5))
@prog_scope()
def test_l1_loss(self):
input = paddle.rand([3, 5])
input.stop_gradient = False
label = paddle.rand([3, 5])
loss = paddle.nn.functional.l1_loss(input, label, reduction='sum')
paddle.static.append_backward(loss)
prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[loss, input.grad_name])
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (3, 5))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
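A minimal dygraph sketch of the reduce-all behaviour these tests now cover, assuming the post-change semantics of paddle.sum/paddle.mean:

import paddle

x = paddle.rand([3, 5])
out = paddle.sum(x)               # reduce over all axes
assert out.shape == []            # 0D Tensor instead of shape [1]

m = paddle.mean(paddle.rand([]))  # 0D input, 0D output
assert m.shape == []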
...@@ -252,7 +252,7 @@ def mean_composite(x, axis, keepdim): ...@@ -252,7 +252,7 @@ def mean_composite(x, axis, keepdim):
operator.mul, [x.shape[axis] for axis in axes] operator.mul, [x.shape[axis] for axis in axes]
) )
norm = fill_constant( norm = fill_constant(
shape=x.shape if len(x.shape) == 0 else [1], shape=[],
value=value_to_fill, value=value_to_fill,
dtype=sum_x.dtype, dtype=sum_x.dtype,
) )
......
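The normalizer built with shape=[] above is a 0D constant; a hypothetical equivalent with the public API, again assuming the post-change 0D semantics:

import paddle

norm = paddle.full(shape=[], fill_value=15.0, dtype='float32')
assert norm.shape == []
assert float(norm) == 15.0  # usable directly as a scalar divisor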
...@@ -142,22 +142,18 @@ class ClipGradForMOEByGlobalNorm(ClipGradBase): ...@@ -142,22 +142,18 @@ class ClipGradForMOEByGlobalNorm(ClipGradBase):
global_norm_var = [] global_norm_var = []
if len(sum_square_list_fp16) > 0: if len(sum_square_list_fp16) > 0:
global_norm_var_fp16 = paddle.concat(sum_square_list_fp16) global_norm_var_fp16 = paddle.add_n(sum_square_list_fp16)
global_norm_var_fp16 = paddle.sum(global_norm_var_fp16)
global_norm_var.append(global_norm_var_fp16.astype(sum_dtype)) global_norm_var.append(global_norm_var_fp16.astype(sum_dtype))
if len(sum_square_list_fp32) > 0: if len(sum_square_list_fp32) > 0:
global_norm_var_fp32 = paddle.concat(sum_square_list_fp32) global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32)
global_norm_var_fp32 = paddle.sum(global_norm_var_fp32)
if sum_dtype == 'float32': if sum_dtype == 'float32':
global_norm_var.append(global_norm_var_fp32) global_norm_var.append(global_norm_var_fp32)
else: else:
global_norm_var.append(global_norm_var_fp32.astype(sum_dtype)) global_norm_var.append(global_norm_var_fp32.astype(sum_dtype))
if len(sum_square_list) > 0: if len(sum_square_list) > 0:
global_norm_var_fp64 = paddle.concat(sum_square_list) global_norm_var_fp64 = paddle.add_n(sum_square_list)
global_norm_var_fp64 = paddle.sum(global_norm_var_fp64)
global_norm_var.append(global_norm_var_fp64) global_norm_var.append(global_norm_var_fp64)
global_norm_var = paddle.concat(global_norm_var) global_norm_var = paddle.add_n(global_norm_var)
global_norm_var = paddle.sum(global_norm_var)
return global_norm_var, sum_dtype return global_norm_var, sum_dtype
@no_grad() @no_grad()
......
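Each squared-norm entry is now a 0D Tensor, so the concat-then-sum pattern is replaced by add_n, which sums the list directly; a hypothetical sketch:

import paddle

sq = [paddle.sum(paddle.rand([4]) ** 2) for _ in range(3)]  # each entry is 0D
global_norm_sq = paddle.add_n(sq)   # sums the list of 0D Tensors
assert global_norm_sq.shape == []
global_norm = paddle.sqrt(global_norm_sq)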
...@@ -206,7 +206,7 @@ def get_program(): ...@@ -206,7 +206,7 @@ def get_program():
auto.shard_tensor(error_cost, _g_process_mesh, [None, None, None]) auto.shard_tensor(error_cost, _g_process_mesh, [None, None, None])
loss = paddle.mean(error_cost) loss = paddle.mean(error_cost)
auto.shard_tensor(loss, _g_process_mesh, [None]) auto.shard_tensor(loss, _g_process_mesh, [])
return train_program, start_program, dataloader, i, loss return train_program, start_program, dataloader, i, loss
......
...@@ -41,14 +41,14 @@ paddle.enable_static() ...@@ -41,14 +41,14 @@ paddle.enable_static()
'v_not_none', 'v_not_none',
utils.reduce, utils.reduce,
np.random.rand(2, 3), np.random.rand(2, 3),
np.random.rand(1), np.array(np.random.rand()),
False, False,
), ),
( (
'xs_stop_gradient', 'xs_stop_gradient',
utils.reduce, utils.reduce,
np.random.rand(2, 3), np.random.rand(2, 3),
np.random.rand(1), np.array(np.random.rand()),
True, True,
), ),
( (
......
...@@ -178,7 +178,7 @@ def train(use_pure_fp16=True, use_nesterov=False, optimizer=""): ...@@ -178,7 +178,7 @@ def train(use_pure_fp16=True, use_nesterov=False, optimizer=""):
(loss,) = exe.run( (loss,) = exe.run(
train_program, feed=feeder.feed(data), fetch_list=[sum_cost] train_program, feed=feeder.feed(data), fetch_list=[sum_cost]
) )
loss_v = loss[0] if isinstance(loss, np.ndarray) else loss loss_v = float(loss) if isinstance(loss, np.ndarray) else loss
print( print(
'PassID {:1}, Train Batch ID {:04}, train loss {:2.4}'.format( 'PassID {:1}, Train Batch ID {:04}, train loss {:2.4}'.format(
pass_id, batch_id + 1, float(loss_v) pass_id, batch_id + 1, float(loss_v)
......
...@@ -1205,7 +1205,7 @@ class TestStickBreakingTransform(unittest.TestCase): ...@@ -1205,7 +1205,7 @@ class TestStickBreakingTransform(unittest.TestCase):
@param.param_func(((np.random.random(10),),)) @param.param_func(((np.random.random(10),),))
def test_forward_log_det_jacobian(self, x): def test_forward_log_det_jacobian(self, x):
self.assertEqual( self.assertEqual(
self._t.forward_log_det_jacobian(paddle.to_tensor(x)).shape, [1] self._t.forward_log_det_jacobian(paddle.to_tensor(x)).shape, []
) )
......
...@@ -65,7 +65,7 @@ class TestAsyncRead(unittest.TestCase): ...@@ -65,7 +65,7 @@ class TestAsyncRead(unittest.TestCase):
) )
# index data # index data
index_array1 = paddle.gather(self.src, self.index) index_array1 = paddle.gather(self.src, self.index)
count_numel = paddle.sum(count).numpy()[0] count_numel = paddle.sum(count).item()
index_array2 = self.dst[count_numel : count_numel + len(self.index)] index_array2 = self.dst[count_numel : count_numel + len(self.index)]
np.testing.assert_allclose( np.testing.assert_allclose(
index_array1.numpy(), index_array2.numpy(), rtol=1e-05 index_array1.numpy(), index_array2.numpy(), rtol=1e-05
......
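The switch to .item() above reflects that paddle.sum now returns a 0D Tensor; a hypothetical illustration:

import paddle

count = paddle.to_tensor([1, 0, 1, 1])
total = paddle.sum(count)   # 0D integer Tensor
n = total.item()            # plain Python int
# total.numpy()[0] would fail, because total.numpy() is a 0-d array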
...@@ -41,7 +41,7 @@ def desired(primal, cotangent, axis, keep_dim): ...@@ -41,7 +41,7 @@ def desired(primal, cotangent, axis, keep_dim):
class TestSumGradComp(unittest.TestCase): class TestSumGradComp(unittest.TestCase):
def test_sum_grad_comp_1(self): def test_sum_grad_comp_1(self):
self.primal = np.random.rand(10, 10) self.primal = np.random.rand(10, 10)
self.cotangent = np.random.rand(1) self.cotangent = np.array(np.random.rand())
paddle.disable_static() paddle.disable_static()
np.testing.assert_allclose( np.testing.assert_allclose(
......
...@@ -126,7 +126,7 @@ class TestCustomStream(unittest.TestCase): ...@@ -126,7 +126,7 @@ class TestCustomStream(unittest.TestCase):
for out in outs: for out in outs:
for baseline, result in zip(outs[0], out): for baseline, result in zip(outs[0], out):
self.assertEqual(baseline[0], result[0]) self.assertEqual(baseline, result)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -1269,10 +1269,11 @@ class TestSundryAPI(unittest.TestCase): ...@@ -1269,10 +1269,11 @@ class TestSundryAPI(unittest.TestCase):
out0.numpy(), out0.numpy(),
out1.numpy(), out1.numpy(),
) )
self.assertEqual(out0.shape, [])
out0.retain_grads() out0.retain_grads()
out0.backward() out0.backward()
self.assertEqual(out0.grad.shape, [1]) self.assertEqual(out0.grad.shape, [])
self.assertEqual(logit.grad.shape, [2, 3]) self.assertEqual(logit.grad.shape, [2, 3])
def test_allclose(self): def test_allclose(self):
......