未验证 提交 a057df50 编写于 作者: W wanghuancoder 提交者: GitHub

fix split and concat out (#41419)

上级 91212104
...@@ -227,7 +227,6 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = { ...@@ -227,7 +227,6 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"c_reduce", {"Out"}}, {"c_reduce", {"Out"}},
{"c_scatter", {"Out"}}, {"c_scatter", {"Out"}},
{"barrier", {"Out"}}, {"barrier", {"Out"}},
{"assign", {"Out"}},
{"fake_quantize_dequantize_moving_average_abs_max", {"fake_quantize_dequantize_moving_average_abs_max",
{"Out", "OutScale", "OutAccum", "OutState"}}, {"Out", "OutScale", "OutAccum", "OutState"}},
{"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}}, {"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}},
...@@ -243,6 +242,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = { ...@@ -243,6 +242,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"get_float_status", {"FloatStatusOut"}}, {"get_float_status", {"FloatStatusOut"}},
{"assign", {"Out"}}, {"assign", {"Out"}},
{"assign_value", {"Out"}}, {"assign_value", {"Out"}},
{"split", {"Out"}},
{"concat", {"Out"}},
}; };
// NOTE(pangyoki): Tensor View Strategy. // NOTE(pangyoki): Tensor View Strategy.
......
...@@ -5026,7 +5026,9 @@ def split(input, num_or_sections, dim=-1, name=None): ...@@ -5026,7 +5026,9 @@ def split(input, num_or_sections, dim=-1, name=None):
raise TypeError( raise TypeError(
"The type of 'num_or_sections' in split must be int, list or tuple in imperative mode, but " "The type of 'num_or_sections' in split must be int, list or tuple in imperative mode, but "
"received %s." % (type(num_or_sections))) "received %s." % (type(num_or_sections)))
return _C_ops.split(input, num, *attrs) out = [_varbase_creator() for n in range(num)]
_C_ops.split(input, out, *attrs)
return out
check_variable_and_dtype( check_variable_and_dtype(
input, 'input', input, 'input',
......
...@@ -337,7 +337,9 @@ def concat(input, axis=0, name=None): ...@@ -337,7 +337,9 @@ def concat(input, axis=0, name=None):
axis = axis.item(0) axis = axis.item(0)
if not isinstance(input, Variable): if not isinstance(input, Variable):
input = [t for t in input if t.shape.count(0) == 0] input = [t for t in input if t.shape.count(0) == 0]
return _C_ops.concat(input, 'axis', axis) out = _varbase_creator()
_C_ops.concat(input, out, 'axis', axis)
return out
check_type(input, 'input', (list, tuple, Variable), 'concat') check_type(input, 'input', (list, tuple, Variable), 'concat')
if not isinstance(input, Variable): if not isinstance(input, Variable):
......
...@@ -69,7 +69,9 @@ def parameters_to_vector(parameters, name=None): ...@@ -69,7 +69,9 @@ def parameters_to_vector(parameters, name=None):
out = _varbase_creator(dtype=dtype) out = _varbase_creator(dtype=dtype)
if in_dygraph_mode(): if in_dygraph_mode():
with paddle.fluid.dygraph.no_grad(): with paddle.fluid.dygraph.no_grad():
_C_ops.concat(parameters, 'axis', 0)._share_underline_tensor_to(out) tmp = _varbase_creator()
_C_ops.concat(parameters, tmp, 'axis', 0)
tmp._share_underline_tensor_to(out)
else: else:
_dygraph_tracer().trace_op( _dygraph_tracer().trace_op(
type='concat', type='concat',
...@@ -120,8 +122,8 @@ def vector_to_parameters(vec, parameters, name=None): ...@@ -120,8 +122,8 @@ def vector_to_parameters(vec, parameters, name=None):
if in_dygraph_mode(): if in_dygraph_mode():
with paddle.fluid.dygraph.no_grad(): with paddle.fluid.dygraph.no_grad():
res = _C_ops.split(vec, res = [_varbase_creator() for n in range(len(parameters))]
len(parameters), 'axis', 0, 'sections', sections) _C_ops.split(vec, res, 'axis', 0, 'sections', sections)
for i in range(0, len(res)): for i in range(0, len(res)):
res[i]._share_underline_tensor_to(parameters[i]) res[i]._share_underline_tensor_to(parameters[i])
else: else:
......
...@@ -3911,7 +3911,8 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None): ...@@ -3911,7 +3911,8 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None):
input_list = [x, append] input_list = [x, append]
has_pend = True has_pend = True
if has_pend: if has_pend:
new_input = _C_ops.concat(input_list, 'axis', axis) new_input = _varbase_creator()
_C_ops.concat(input_list, new_input, 'axis', axis)
else: else:
new_input = x new_input = x
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册