提交 ae4db404 编写于 作者: W wjj19950828

fixed Where and ConstantOfShape

上级 a91d37bd
......@@ -1239,7 +1239,6 @@ class OpSet9():
@print_mapping_info
def ConstantOfShape(self, node):
val_shape = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
value = node.get_attr('value')
dtype = value.dtype
......@@ -1248,6 +1247,8 @@ class OpSet9():
'this is not supported')
if len(value) == 1:
value = value[0]
if value == float('inf') or value == float('-inf'):
value = string(value)
layer_attrs = {'dtype': string(dtype), 'fill_value': value}
self.paddle_graph.add_layer(
"paddle.full",
......@@ -1951,40 +1952,13 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_input_node(node, idx=2, copy=True)
not_condition = condition.name + '_not'
self.paddle_graph.add_layer(
"paddle.logical_not",
inputs={"x": condition.name},
outputs=[not_condition])
cast_not_condition = not_condition + '_cast'
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": not_condition},
outputs=[cast_not_condition],
dtype=string(val_x.dtype))
cast_condition = condition.name + '_cast'
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": condition.name},
outputs=[cast_condition],
dtype=string(val_x.dtype))
mul_val_x = val_x.name + '_mul'
self.paddle_graph.add_layer(
"paddle.multiply",
inputs={'x': val_x.name,
'y': cast_condition},
outputs=[mul_val_x])
mul_val_y = val_y.name + '_mul'
self.paddle_graph.add_layer(
"paddle.multiply",
inputs={'x': val_y.name,
'y': cast_not_condition},
outputs=[mul_val_y])
self.paddle_graph.add_layer(
"paddle.add",
inputs={'x': mul_val_x,
'y': mul_val_y},
"paddle.where",
inputs={
'condition': condition.name,
'x': val_x.name,
'y': val_y.name
},
outputs=[node.name])
@print_mapping_info
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册