提交 71b51e82 编写于 作者: C Channingss

expand support dynamic shape

上级 fe75f532
......@@ -556,25 +556,26 @@ class OpSet9():
def Expand(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True)
if len(val_shape.outputs) == 1:
self.omit_nodes.append(val_shape.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
out_shape = node.out_shapes[0]
val_x_dtype = val_x.dtype
name_ones = node.layer_name + '_ones'
attr_ones = {'shape': out_shape, 'dtype': string(val_x_dtype)}
attr_ones = {
'shape': val_shape.layer_name,
'dtype': string(val_x_dtype),
'value': 1
}
node.fluid_code.add_layer(
'ones', inputs=None, output=name_ones, param_attr=attr_ones)
'fill_constant',
inputs=None,
output=name_ones,
param_attr=attr_ones)
inputs = {'x': name_ones, 'y': val_x}
attr = {'name': string(node.layer_name)}
node.fluid_code.add_layer(
'elementwise_mul',
inputs=inputs,
output=node.layer_name,
param_attr=attr)
param_attr=None)
@print_mapping_info
def Gather(self, node):
......@@ -1341,7 +1342,8 @@ class OpSet9():
if val_repeats.dtype != 'int32':
attr = {"dtype": string("int32")}
node.fluid_code.add_layer(
"cast", inputs=repeats,
"cast",
inputs=repeats,
output="{}.tmp".format(repeats),
param_attr=attr)
repeats = "{}.tmp".format(repeats)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册