未验证 提交 2c12e94b 编写于 作者: J Jason 提交者: GitHub

Merge pull request #403 from Channingss/fix_shape_infer

Fix shape infer
......@@ -346,8 +346,12 @@ class ONNXGraph(Graph):
#if len(value_info['shape']) == 0 or value_info[
# 'dtype'] is None or 0 in value_info['shape']:
# #TODO add node shape inference
shape = value_info['shape']
for idx in range(len(shape)):
if shape[idx] == 0:
shape[idx] = -1
node.out_shapes.append(shape)
node.dtype = value_info['dtype']
node.out_shapes.append(value_info['shape'])
else:
node.out_shapes.append([])
......
......@@ -14,7 +14,6 @@
# Reference Code from https://github.com/microsoft/onnxruntime, Licensed under the MIT License.
# -*- coding: UTF-8 -*-
import argparse
import numpy as np
import onnx
......@@ -23,7 +22,6 @@ from onnx import helper, numpy_helper, shape_inference
import sympy
from packaging import version
assert version.parse(onnx.__version__) >= version.parse("1.5.0")
def get_attribute(node, attr_name, default_value=None):
......@@ -45,17 +43,15 @@ def get_shape_from_type_proto(type_proto):
def get_shape_from_sympy_shape(sympy_shape):
sympy_shape = [
return [
None if i is None else (int(i) if is_literal(i) else str(i))
for i in sympy_shape
]
return sympy_shape
def is_literal(dim):
return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (
hasattr(dim, 'is_number') and
dim.is_number) # or (hasattr(dim, 'is_integer') and dim.is_integer)
return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(
dim, 'is_number') and dim.is_number)
def handle_negative_axis(axis, rank):
......@@ -119,6 +115,7 @@ class SymbolicShapeInference:
'Div': self._infer_symbolic_compute_ops,
'Expand': self._infer_Expand,
'Equal': self._infer_symbolic_compute_ops,
'Floor': self._infer_symbolic_compute_ops,
'Gather': self._infer_Gather,
'GatherElements': self._infer_GatherElements,
'GatherND': self._infer_GatherND,
......@@ -145,13 +142,13 @@ class SymbolicShapeInference:
'Size': self._infer_Size,
'Slice': self._infer_Slice,
'Split': self._infer_Split,
'SplitToSequence': self._infer_SplitToSequence,
'Squeeze': self._infer_Squeeze,
'Sub': self._infer_symbolic_compute_ops,
'Tile': self._infer_Tile,
'TopK': self._infer_TopK,
'Unsqueeze': self._infer_Unsqueeze,
'Where': self._infer_symbolic_compute_ops,
'Transpose': self._infer_Transpose,
'ZipMap': self._infer_ZipMap
}
self.run_ = True
......@@ -267,8 +264,9 @@ class SymbolicShapeInference:
if pending_nodes and self.verbose_ > 0:
print('SymbolicShapeInference: orphaned nodes discarded: ')
for n in pending_nodes:
print(n.op_type + ': ' + n.output[0])
print(
*[n.op_type + ': ' + n.output[0] for n in pending_nodes],
sep='\n')
if input_shapes is not None:
for input_name, shape in input_shapes.items():
......@@ -280,6 +278,7 @@ class SymbolicShapeInference:
helper.make_tensor_value_info(
value_info.name,
value_info.type.tensor_type.elem_type, shape))
self.initializers_ = dict(
[(i.name, i) for i in self.out_mp_.graph.initializer])
self.known_vi_ = dict(
......@@ -351,21 +350,11 @@ class SymbolicShapeInference:
def _get_shape(self, node, idx):
name = node.input[idx]
shape = []
if name in self.known_vi_:
shape = get_shape_from_type_proto(self.known_vi_[name].type)
elif name in self.initializers_:
assert name in self.initializers_
shape = list(self.initializers_[name].dims)
return shape
def _get_initializer_value(self, node, idx):
name = node.input[idx]
if name in self.initializers_:
value = numpy_helper.to_array(self.initializers_[name])
return value
return get_shape_from_type_proto(self.known_vi_[name].type)
else:
return False
assert name in self.initializers_
return list(self.initializers_[name].dims)
def _get_shape_rank(self, node, idx):
return len(self._get_shape(node, idx))
......@@ -373,7 +362,7 @@ class SymbolicShapeInference:
def _get_sympy_shape(self, node, idx):
sympy_shape = []
for d in self._get_shape(node, idx):
if type(d) is str:
if type(d) == str:
sympy_shape.append(self.symbolic_dims_[d] if d in
self.symbolic_dims_ else sympy.Symbol(
d, integer=True))
......@@ -416,10 +405,14 @@ class SymbolicShapeInference:
# run single node inference with self.known_vi_ shapes
# note that inference rely on initializer values is not handled
# as we don't copy initializer weights to tmp_graph for inference speed purpose
if node.op_type == 'SplitToSequence':
make_value_info_func = helper.make_sequence_value_info
else:
make_value_info_func = helper.make_tensor_value_info
tmp_graph = helper.make_graph(
[node], 'tmp', [self.known_vi_[i] for i in node.input if i], [
helper.make_tensor_value_info(i, onnx.TensorProto.UNDEFINED,
None) for i in node.output
make_value_info_func(i, onnx.TensorProto.UNDEFINED, None)
for i in node.output
])
self.tmp_mp_.graph.CopyFrom(tmp_graph)
self.tmp_mp_ = shape_inference.infer_shapes(self.tmp_mp_)
......@@ -465,6 +458,12 @@ class SymbolicShapeInference:
self.verbose_)
all_shapes_inferred = False
symbolic_shape_inference._preprocess(self.tmp_mp_)
# note that after _preprocess, Constant node will be converted to initializer and should be appended to subgraph.initializer
subgraph.initializer.extend([
i for i in symbolic_shape_inference.out_mp_.graph.initializer
if i.name not in subgraph_implicit_input and i.name not in
subgraph_inputs
])
symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy()
while symbolic_shape_inference.run_:
all_shapes_inferred = symbolic_shape_inference._infer_impl(
......@@ -533,18 +532,13 @@ class SymbolicShapeInference:
assert len(node.output) == 1
values = self._get_int_values(node, broadcast=True)
if all([v is not None for v in values]):
new_shape = []
is_list = [type(v) == list for v in values]
as_list = any(is_list)
if as_list:
data = [op_func(vs) for vs in zip(*values)]
self.sympy_data_[node.output[0]] = data
new_shape = np.array(data).shape
self.sympy_data_[node.output[
0]] = [op_func(vs) for vs in zip(*values)]
else:
data = op_func(values)
self.sympy_data_[node.output[0]] = data
new_shape = np.array(data).shape
vi = self.known_vi_[node.output[0]]
self.sympy_data_[node.output[0]] = op_func(values)
def _pass_on_sympy_data(self, node):
assert len(node.input) == 1 or node.op_type == 'Reshape'
......@@ -677,8 +671,8 @@ class SymbolicShapeInference:
lhs_reduce_dim = -1
rhs_reduce_dim = -2
new_shape = self._broadcast_shapes(
lhs_shape[:-2], rhs_shape[:-2]) + [lhs_shape[-2]
] + [rhs_shape[-1]]
lhs_shape[:-2],
rhs_shape[:-2]) + [lhs_shape[-2]] + [rhs_shape[-1]]
# merge reduce dim
self._check_merged_dims(
[lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]],
......@@ -706,6 +700,7 @@ class SymbolicShapeInference:
'Add': lambda l: l[0] + l[1],
'Div': lambda l: l[0] // l[1], # integer div in sympy
'Equal': lambda l: l[0] == l[1],
'Floor': lambda l: sympy.floor(l[0]),
'Max':
lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])),
'Min':
......@@ -731,15 +726,6 @@ class SymbolicShapeInference:
helper.make_tensor_value_info(node.output[0], output_type,
self._get_shape(node, 0)))
def _infer_Transpose(self, node):
input_shape = self._get_shape(node, 0)
perm = get_attribute(node, 'perm')
output_shape = np.array(input_shape)[perm].tolist()
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
node.input[0]].type.tensor_type.elem_type, output_shape))
def _infer_Compress(self, node):
input_shape = self._get_shape(node, 0)
# create a new symbolic dimension for Compress output
......@@ -851,11 +837,11 @@ class SymbolicShapeInference:
axis = handle_negative_axis(
get_attribute(node, 'axis', 0), len(data_shape))
indices_shape = self._get_shape(node, 1)
new_shape = data_shape[:axis] + indices_shape + data_shape[axis + 1:]
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[
0], vi.type.tensor_type.elem_type, new_shape))
helper.make_tensor_value_info(
node.output[0], vi.type.tensor_type.elem_type, data_shape[:axis]
+ indices_shape + data_shape[axis + 1:]))
if node.input[0] in self.sympy_data_:
assert 0 == get_attribute(node, 'axis',
0) # only handle 1D sympy compute
......@@ -863,9 +849,8 @@ class SymbolicShapeInference:
data = self.sympy_data_[node.input[0]]
if type(data) == list:
if type(idx) == np.ndarray and len(idx.shape) == 1:
self.sympy_data_[node.output[0]] = [
data[int(i)] for i in idx
]
self.sympy_data_[node.output[
0]] = [data[int(i)] for i in idx]
else:
self.sympy_data_[node.output[0]] = data[int(idx)]
else:
......@@ -896,8 +881,8 @@ class SymbolicShapeInference:
def _infer_If(self, node):
# special case for constant condition, in case there are mismatching shape from the non-executed branch
subgraphs = [
get_attribute(node, 'then_branch'),
get_attribute(node, 'else_branch')
get_attribute(node, 'then_branch'), get_attribute(node,
'else_branch')
]
cond = self._try_get_value(node, 0)
if cond is not None:
......@@ -976,11 +961,14 @@ class SymbolicShapeInference:
0], vi.type.tensor_type.elem_type, [input_rank, nz_len]))
def _infer_OneHot(self, node):
shape = self._get_shape(node, 0)
sympy_shape = self._get_sympy_shape(node, 0)
depth = self._try_get_value(node, 1)
axis = get_attribute(node, 'axis', -1)
axis = handle_negative_axis(axis, len(shape) + 1)
new_shape = shape[:axis] + [self._new_symbolic_dim_from_output(node)
] + shape[axis:]
axis = handle_negative_axis(axis, len(sympy_shape) + 1)
new_shape = get_shape_from_sympy_shape(sympy_shape[:axis] + [
self._new_symbolic_dim_from_output(node)
if not is_literal(depth) else depth
] + sympy_shape[axis:])
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
helper.make_tensor_value_info(node.output[0], self.known_vi_[
......@@ -1125,11 +1113,16 @@ class SymbolicShapeInference:
sympy.simplify(sympy.floor(s)) for s in sizes
]
self._update_computed_dims(new_sympy_shape)
elif roi is not None and scales is not None:
elif scales is not None:
rank = len(scales)
if get_attribute(node, 'coordinate_transformation_mode'
) == 'tf_crop_and_resize':
assert len(roi) == 2 * rank
roi_start = list(roi)[:rank]
roi_end = list(roi)[rank:]
else:
roi_start = [0] * rank
roi_end = [1] * rank
scales = list(scales)
new_sympy_shape = [
sympy.simplify(sympy.floor(d * (end - start) * scale))
......@@ -1265,8 +1258,8 @@ class SymbolicShapeInference:
e = new_sympy_shape[i]
except Exception:
print(
'Unable to determine if {} <= {}, treat as equal'
.format(e, new_sympy_shape[i]))
'Unable to determine if {} <= {}, treat as equal'.
format(e, new_sympy_shape[i]))
e = new_sympy_shape[i]
if is_literal(s) and int(s) < 0:
......@@ -1290,7 +1283,7 @@ class SymbolicShapeInference:
self.sympy_data_[node.output[0]] = self.sympy_data_[node.input[0]][
starts[0]:ends[0]]
def _infer_Split(self, node):
def _infer_Split_Common(self, node, make_value_info_func):
input_sympy_shape = self._get_sympy_shape(node, 0)
axis = handle_negative_axis(
get_attribute(node, 'axis', 0), len(input_sympy_shape))
......@@ -1306,14 +1299,20 @@ class SymbolicShapeInference:
for i_o in range(len(split)):
vi = self.known_vi_[node.output[i_o]]
vi.CopyFrom(
helper.make_tensor_value_info(
node.output[i_o], self.known_vi_[node.input[
0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(input_sympy_shape[:axis] + [
make_value_info_func(node.output[i_o], self.known_vi_[
node.input[0]].type.tensor_type.elem_type,
get_shape_from_sympy_shape(
input_sympy_shape[:axis] + [
split[i_o]
] + input_sympy_shape[axis + 1:])))
self.known_vi_[vi.name] = vi
def _infer_Split(self, node):
self._infer_Split_Common(node, helper.make_tensor_value_info)
def _infer_SplitToSequence(self, node):
self._infer_Split_Common(node, helper.make_sequence_value_info)
def _infer_Squeeze(self, node):
self._pass_on_sympy_data(node)
......@@ -1416,10 +1415,18 @@ class SymbolicShapeInference:
self._onnx_infer_single_node(node)
if node.op_type in self.dispatcher_:
self.dispatcher_[node.op_type](node)
elif node.op_type in ['ConvTranspose']:
# onnx shape inference ops like ConvTranspose may have empty shape for symbolic input
# before adding symbolic compute for them
# mark the output type as UNDEFINED to allow guessing of rank
vi = self.known_vi_[node.output[0]]
if len(vi.type.tensor_type.shape.dim) == 0:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
if self.verbose_ > 2:
print(node.op_type + ': ' + node.name)
for i, name in enumerate(node.input):
print(' Input {}: {} {}€5€5€5€5€5'.format(
print(' Input {}: {} {}'.format(
i, name, 'initializer'
if name in self.initializers_ else ''))
......@@ -1443,6 +1450,7 @@ class SymbolicShapeInference:
]
if len(in_dims) > 1:
self._check_merged_dims(in_dims, allow_broadcast=True)
for i_o in range(len(node.output)):
vi = self.known_vi_[node.output[i_o]]
out_type = vi.type
......@@ -1473,16 +1481,22 @@ class SymbolicShapeInference:
if node.op_type in [
'MatMul', 'MatMulInteger', 'MatMulInteger16'
]:
if None in out_shape:
idx = out_shape.index(None)
dim_idx = [
len(s) - len(out_shape) + idx
for s in shapes
]
# only support auto merge for MatMul for dim < rank-2 when rank > 2
assert len(shapes[0]) > 2 and dim_idx[0] < len(
shapes[0]) - 2
assert len(shapes[1]) > 2 and dim_idx[1] < len(
shapes[1]) - 2
assert len(shapes[0]) > 2 and dim_idx[
0] < len(shapes[0]) - 2
assert len(shapes[1]) > 2 and dim_idx[
1] < len(shapes[1]) - 2
elif node.op_type == 'Expand':
# auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq])
shapes = [
self._get_shape(node, 0),
self._get_value(node, 1)
self._get_shape(node, 0), self._get_value(node,
1)
]
else:
shapes = []
......@@ -1531,8 +1545,8 @@ class SymbolicShapeInference:
if self.verbose_ > 0:
if is_unknown_op:
print(
"Possible unknown op: {} node: {}, guessing {} shape"
.format(node.op_type, node.name,
"Possible unknown op: {} node: {}, guessing {} shape".
format(node.op_type, node.name,
vi.name))
if self.verbose_ > 2:
print(' {}: {} {}'.format(
......@@ -1555,24 +1569,26 @@ class SymbolicShapeInference:
if self.auto_merge_ and not out_type_undefined:
print('Merging: ' + str(self.suggested_merge_))
return False
self.run_ = False
return True
def _update_output_from_vi(self):
for output in self.out_mp_.graph.output:
if output.name in self.known_vi_:
tmp_output = self.known_vi_[output.name]
output.CopyFrom(tmp_output)
output.CopyFrom(self.known_vi_[output.name])
@staticmethod
def infer_shapes(in_mp,
int_max=2**31 - 1,
fixed_input_shape=None,
auto_merge=True,
int_max=2**31 - 1,
auto_merge=False,
guess_output_rank=False,
verbose=0):
if get_opset(in_mp) < 7:
print('Only support shape inferencing models of opset 7 and above.')
assert version.parse(onnx.__version__) >= version.parse("1.5.0")
onnx_opset = get_opset(in_mp)
if not onnx_opset or onnx_opset < 7:
print('Only support models of onnx opset 7 and above.')
return
symbolic_shape_inference = SymbolicShapeInference(
int_max, auto_merge, guess_output_rank, verbose)
......@@ -1588,9 +1604,8 @@ class SymbolicShapeInference:
print('!' * 10)
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_)
#onnx.save(symbolic_shape_inference.out_mp_, 'tmp.onnx')
except:
print('Stopping at incomplete shape inference')
print('Stopping at incomplete symbolic shape inference')
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_)
return symbolic_shape_inference.out_mp_.graph
......@@ -104,14 +104,6 @@ class OpSet9():
default_op_mapping = {
'Shape': ['shape', ['X'], ['Out']],
'Clip': [
'clip', ['X'], ['Out'], dict(), dict(
min=(np.asarray(
[255, 255, 127, 255], dtype=np.uint8).view(np.float32)[0]),
max=(np.asarray(
[255, 255, 127, 127], dtype=np.uint8).view(np.float32)[0]),
)
],
'Erf': ['erf', ['X'], ['Out']],
'Ceil': ['ceil', ['X'], ['Out']],
'ReduceMean': [
......@@ -357,6 +349,7 @@ class OpSet9():
'Warnning: paddle not support op:resize wiht mode: linear, we use bilinear replace linear'
)
fluid_op = 'resize_bilinear'
attr['align_corners'] = False
node.fluid_code.add_layer(
fluid_op, inputs=inputs, output=node, param_attr=attr)
......@@ -745,53 +738,59 @@ class OpSet9():
param_attr=None)
else:
input_inner_indices = node.layer_name + '_input_inner_indices'
shape = val_x.out_shapes[0]
node.fluid_code.add_layer(
'reshape',
inputs=indices.layer_name,
output=indices.layer_name,
param_attr={'shape': indices.out_shapes[0]})
zeros_like_val_x = val_x.layer_name + '_zeros'
node.fluid_code.add_layer(
'scatter_nd',
'zeros_like',
inputs=val_x,
output=zeros_like_val_x,
param_attr=None)
node.fluid_code.add_layer(
'scatter_nd_add',
inputs={
'shape': val_x.out_shapes[0],
'ref': zeros_like_val_x,
'index': indices,
'updates': updates
},
output=input_inner_indices,
param_attr=None)
indices_mask = node.layer_name + '_indices_mask'
constant_minus_one = node.layer_name + '_constant_minus_one'
# full_like support create tensor shape like input tensor
node.fluid_code.add_layer(
'fill_constant',
inputs=None,
'full_like',
inputs=updates,
output=constant_minus_one,
param_attr={
'shape': updates.out_shapes[0],
'dtype': string(updates.dtype),
'value': -1
})
indices_mask = node.layer_name + '_indices_mask'
param_attr={'dtype': string(updates.dtype),
'fill_value': -1})
node.fluid_code.add_layer(
'scatter_nd',
'scatter_nd_add',
inputs={
'shape': val_x.out_shapes[0],
'ref': zeros_like_val_x,
'index': indices,
'updates': constant_minus_one
},
output=indices_mask,
param_attr=None)
constant_1 = node.layer_name + '_constant_1'
constant_one = node.layer_name + '_constant_1'
# full_like support create tensor shape like input tensor
node.fluid_code.add_layer(
'fill_constant',
inputs=None,
output=constant_1,
param_attr={
'shape': val_x.out_shapes[0],
'dtype': string(val_x.dtype),
'value': 1
})
'full_like',
inputs=val_x,
output=constant_one,
param_attr={'dtype': string(val_x.dtype),
'fill_value': 1})
input_out_indices_mask = node.layer_name + '_input_out_indices_mask'
node.fluid_code.add_layer(
"elementwise_add",
inputs={"x": indices_mask,
"y": constant_1},
"y": constant_one},
output=input_out_indices_mask,
param_attr=None)
......@@ -831,27 +830,35 @@ class OpSet9():
if len(node.inputs) > 1:
starts = self.graph.get_input_node(node, idx=1, copy=True)
ends = self.graph.get_input_node(node, idx=2, copy=True)
starts_value = _const_weight_or_none(starts)
ends_value = _const_weight_or_none(ends)
if len(node.inputs) > 3:
axes = self.graph.get_input_node(node, idx=3, copy=True)
axes = _const_weight_or_none(axes, necessary=True)
if len(node.inputs) > 4:
steps = self.graph.get_input_node(node, idx=4, copy=True)
steps = _const_weight_or_none(steps)
if steps is not None:
assert steps == 1, "Only support convert op:Slice, which attribute:steps == 1"
attr = {
"axes": axes,
"starts": starts.layer_name,
"ends": ends.layer_name
}
starts_value = _const_weight_or_none(starts)
ends_value = _const_weight_or_none(ends)
if starts_value is not None and ends_value is not None:
self.omit_nodes.append(starts.layer_name)
self.omit_nodes.append(ends.layer_name)
starts_value = starts_value.copy()
ends_value = ends_value.copy()
#for idx in range(len(ends_value)):
# if ends_value[idx] > 2**31 - 1:
# ends_value[idx] = 2**31 - 1
#print(val_x.out_shapes)
for idx in range(len(ends_value)):
if ends_value[idx] > 2**31 - 1:
if starts_value[idx] >= val_x.out_shapes[0][axes[idx]]:
starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
ends_value[idx] = val_x.out_shapes[0][axes[idx]]
starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
elif ends_value[idx] > 2**31 - 1:
ends_value[idx] = 2**31 - 1
attr = {
"axes": axes,
......@@ -884,6 +891,11 @@ class OpSet9():
ends[idx] = 2**31 - 1
attr = {"axes": axes, "starts": starts, "ends": ends}
if steps is not None:
attr['strides'] = steps
node.fluid_code.add_layer(
'strided_slice', inputs=val_x, output=node, param_attr=attr)
else:
node.fluid_code.add_layer(
'slice', inputs=val_x, output=node, param_attr=attr)
......@@ -907,6 +919,38 @@ class OpSet9():
node.fluid_code.add_layer(
'fill_constant', inputs=None, output=node, param_attr=attr)
@print_mapping_info
def Clip(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
max_value, min_value = None, None
if len(node.inputs) == 1:
max_value = node.get_attr('max')
min_value = node.get_attr('min')
attr = {
'max': max_value,
'min': min_value,
}
node.fluid_code.add_layer(
'clip', inputs=val_x, output=node, param_attr=attr)
else:
max_ipt = self.graph.get_input_node(node, idx=1, copy=True)
min_ipt = self.graph.get_input_node(node, idx=2, copy=True)
max_value = _const_weight_or_none(max_ipt)
min_value = _const_weight_or_none(min_ipt)
self.omit_nodes.append(max_ipt.layer_name)
self.omit_nodes.append(min_ipt.layer_name)
if max_value.shape == (1, ):
max_value = max_value[0]
if min_value.shape == (1, ):
min_value = min_value[0]
if max_value is not None and min_value is not None:
attr = {'max': max_value, 'min': min_value}
node.fluid_code.add_layer(
'clip', inputs=val_x, output=node, param_attr=attr)
else:
raise
@print_mapping_info
def Split(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册