提交 0fe0fc12 编写于 作者: C Channingss

update elementwise_ops for paddle1.8

上级 1efaccda
......@@ -243,62 +243,7 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
val_y_shape = val_y.out_shapes[0]
val_x_shape = val_x.out_shapes[0]
inputs = {}
if len(val_x_shape) < len(val_y_shape):
if node.layer_type in ['Mul', 'Add']:
val_x, val_y = val_y, val_x
val_y_shape, val_x_shape = val_x_shape, val_y_shape
inputs = {'x': val_x, 'y': val_y}
elif node.layer_type in ['Sub', 'Div', 'Pow']:
val_x_expand = val_x.layer_name + '_expand'
x_value = _const_weight_or_none(val_x)
if (val_x_shape == [1] or len(val_x_shape) == 0) and x_value:
attr = {
'shape': val_y_shape,
'dtype': string(val_x.dtype),
'value': x_value
if len(val_x_shape) == 0 else x_value[0]
}
node.fluid_code.add_layer(
'fill_constant',
inputs=None,
output=val_x_expand,
param_attr=attr)
val_x_shape = val_y_shape
inputs = {'x': val_x_expand, 'y': val_y}
else:
assert 'Unsupported situation happened.'
else:
inputs = {'x': val_x, 'y': val_y}
print(node.layer_name)
print(val_x_shape, val_y_shape)
str_y_shape = ','.join(str(e) for e in val_y_shape)
str_x_shape = ','.join(str(e) for e in val_x_shape)
slice_idx = 0
if str_y_shape not in str_x_shape:
for dim in val_y_shape:
if dim == 1:
slice_idx += 1
else:
break
if slice_idx < len(val_y_shape) and slice_idx > 0:
val_y_reshaped = val_y_shape[slice_idx:]
var_y_reshaped = val_y.layer_name + '_reshaped'
attr_reshaped = {
'shape': val_y_reshaped,
'name': string(var_y_reshaped)
}
node.fluid_code.add_layer(
'reshape',
inputs=val_y,
output=var_y_reshaped,
param_attr=attr_reshaped)
inputs['y'] = var_y_reshaped
node.fluid_code.add_layer(
op_type, inputs=inputs, output=node, param_attr=None)
else:
node.fluid_code.add_layer(
op_type, inputs=inputs, output=node, param_attr=None)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册