提交 71b51e82 编写于 作者: C Channingss

expand support dynamic shape

上级 fe75f532
...@@ -556,25 +556,26 @@ class OpSet9(): ...@@ -556,25 +556,26 @@ class OpSet9():
def Expand(self, node): def Expand(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True) val_shape = self.graph.get_input_node(node, idx=1, copy=True)
if len(val_shape.outputs) == 1: if len(val_shape.outputs) == 1:
self.omit_nodes.append(val_shape.layer_name) self.omit_nodes.append(val_shape.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
out_shape = node.out_shapes[0]
val_x_dtype = val_x.dtype val_x_dtype = val_x.dtype
name_ones = node.layer_name + '_ones' name_ones = node.layer_name + '_ones'
attr_ones = {'shape': out_shape, 'dtype': string(val_x_dtype)} attr_ones = {
'shape': val_shape.layer_name,
'dtype': string(val_x_dtype),
'value': 1
}
node.fluid_code.add_layer( node.fluid_code.add_layer(
'ones', inputs=None, output=name_ones, param_attr=attr_ones) 'fill_constant',
inputs=None,
output=name_ones,
param_attr=attr_ones)
inputs = {'x': name_ones, 'y': val_x} inputs = {'x': name_ones, 'y': val_x}
attr = {'name': string(node.layer_name)}
node.fluid_code.add_layer( node.fluid_code.add_layer(
'elementwise_mul', 'elementwise_mul',
inputs=inputs, inputs=inputs,
output=node.layer_name, output=node.layer_name,
param_attr=attr) param_attr=None)
@print_mapping_info @print_mapping_info
def Gather(self, node): def Gather(self, node):
...@@ -1341,7 +1342,8 @@ class OpSet9(): ...@@ -1341,7 +1342,8 @@ class OpSet9():
if val_repeats.dtype != 'int32': if val_repeats.dtype != 'int32':
attr = {"dtype": string("int32")} attr = {"dtype": string("int32")}
node.fluid_code.add_layer( node.fluid_code.add_layer(
"cast", inputs=repeats, "cast",
inputs=repeats,
output="{}.tmp".format(repeats), output="{}.tmp".format(repeats),
param_attr=attr) param_attr=attr)
repeats = "{}.tmp".format(repeats) repeats = "{}.tmp".format(repeats)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册