未验证 提交 a057df50 编写于 作者: W wanghuancoder 提交者: GitHub

fix split and concat out (#41419)

上级 91212104
......@@ -227,7 +227,6 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"c_reduce", {"Out"}},
{"c_scatter", {"Out"}},
{"barrier", {"Out"}},
{"assign", {"Out"}},
{"fake_quantize_dequantize_moving_average_abs_max",
{"Out", "OutScale", "OutAccum", "OutState"}},
{"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}},
......@@ -243,6 +242,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"get_float_status", {"FloatStatusOut"}},
{"assign", {"Out"}},
{"assign_value", {"Out"}},
{"split", {"Out"}},
{"concat", {"Out"}},
};
// NOTE(pangyoki): Tensor View Strategy.
......
......@@ -5026,7 +5026,9 @@ def split(input, num_or_sections, dim=-1, name=None):
raise TypeError(
"The type of 'num_or_sections' in split must be int, list or tuple in imperative mode, but "
"received %s." % (type(num_or_sections)))
return _C_ops.split(input, num, *attrs)
out = [_varbase_creator() for n in range(num)]
_C_ops.split(input, out, *attrs)
return out
check_variable_and_dtype(
input, 'input',
......
......@@ -337,7 +337,9 @@ def concat(input, axis=0, name=None):
axis = axis.item(0)
if not isinstance(input, Variable):
input = [t for t in input if t.shape.count(0) == 0]
return _C_ops.concat(input, 'axis', axis)
out = _varbase_creator()
_C_ops.concat(input, out, 'axis', axis)
return out
check_type(input, 'input', (list, tuple, Variable), 'concat')
if not isinstance(input, Variable):
......
......@@ -69,7 +69,9 @@ def parameters_to_vector(parameters, name=None):
out = _varbase_creator(dtype=dtype)
if in_dygraph_mode():
with paddle.fluid.dygraph.no_grad():
_C_ops.concat(parameters, 'axis', 0)._share_underline_tensor_to(out)
tmp = _varbase_creator()
_C_ops.concat(parameters, tmp, 'axis', 0)
tmp._share_underline_tensor_to(out)
else:
_dygraph_tracer().trace_op(
type='concat',
......@@ -120,8 +122,8 @@ def vector_to_parameters(vec, parameters, name=None):
if in_dygraph_mode():
with paddle.fluid.dygraph.no_grad():
res = _C_ops.split(vec,
len(parameters), 'axis', 0, 'sections', sections)
res = [_varbase_creator() for n in range(len(parameters))]
_C_ops.split(vec, res, 'axis', 0, 'sections', sections)
for i in range(0, len(res)):
res[i]._share_underline_tensor_to(parameters[i])
else:
......
......@@ -3911,7 +3911,8 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None):
input_list = [x, append]
has_pend = True
if has_pend:
new_input = _C_ops.concat(input_list, 'axis', axis)
new_input = _varbase_creator()
_C_ops.concat(input_list, new_input, 'axis', axis)
else:
new_input = x
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册