Unverified commit 2fa50566, authored by Jason, committed by GitHub

Merge pull request #347 from Channingss/Conv

add cast(int64) for embedding, fix bug of python2, delete repeated Greater
......@@ -15,7 +15,7 @@ paddlepaddle >= 1.8.0
**Install the following dependencies as needed**
tensorflow : tensorflow == 1.14.0
caffe : 无
onnx : onnx == 1.6.0
onnx : onnx >= 1.6.0
## Installation
### Installation Method 1 (Recommended)
......
......@@ -170,8 +170,8 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
try:
import onnx
version = onnx.version.version
if version != '1.6.0':
print("[ERROR] onnx==1.6.0 is required")
if version < '1.6.0':
print("[ERROR] onnx>=1.6.0 is required")
return
except:
print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
......
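The relaxed check above compares version strings with `<`, which in Python is lexicographic; a minimal standalone sketch of a numeric comparison for plain `major.minor.patch` strings (a hypothetical helper written for illustration, not part of this PR or of X2Paddle):

```python
def version_at_least(version, minimum='1.6.0'):
    """Compare plain 'major.minor.patch' version strings numerically."""
    def as_tuple(v):
        return tuple(int(p) for p in v.split('.')[:3])
    return as_tuple(version) >= as_tuple(minimum)

# A lexicographic string comparison would reject '1.10.0'
# ('1.10.0' < '1.6.0' is True), while a numeric comparison accepts it.
assert version_at_least('1.6.0')
assert version_at_least('1.10.0')
assert not version_at_least('1.5.0')
```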
......@@ -267,8 +267,9 @@ class SymbolicShapeInference:
if pending_nodes and self.verbose_ > 0:
print('SymbolicShapeInference: orphaned nodes discarded: ')
print('\n'.join(
[n.op_type + ': ' + n.output[0] for n in pending_nodes]))
for n in pending_nodes:
print(n.op_type + ': ' + n.output[0])
if input_shapes is not None:
for input_name, shape in input_shapes.items():
for idx in range(len(self.out_mp_.graph.input)):
......
......@@ -487,16 +487,6 @@ class OpSet9():
node.fluid_code.add_layer(
'hard_shrink', inputs=val_x, output=node, param_attr=attr)
def Greater(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
node.fluid_code.add_layer(
'greater_than',
inputs={'x': val_x,
'y': val_y},
output=node,
param_attr=None)
@print_mapping_info
def Constant(self, node):
val_output = self.graph.get_node(node.layer.output[0], copy=True)
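For reference, the duplicated Greater handler deleted above mapped ONNX Greater onto fluid's greater_than layer; a minimal standalone sketch of that layer follows (shapes are illustrative, and this only builds the op in the default fluid program rather than reproducing X2Paddle's code generation):

```python
import paddle.fluid as fluid

# Elementwise x > y returning a boolean tensor, i.e. the layer the removed
# handler emitted via add_layer('greater_than', inputs={'x': ..., 'y': ...}).
x = fluid.data(name='x', shape=[3], dtype='float32')
y = fluid.data(name='y', shape=[3], dtype='float32')
out = fluid.layers.greater_than(x, y)
```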
......@@ -566,25 +556,26 @@ class OpSet9():
def Expand(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True)
if len(val_shape.outputs) == 1:
self.omit_nodes.append(val_shape.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
out_shape = node.out_shapes[0]
val_x_dtype = val_x.dtype
name_ones = node.layer_name + '_ones'
attr_ones = {'shape': out_shape, 'dtype': string(val_x_dtype)}
attr_ones = {
'shape': val_shape.layer_name,
'dtype': string(val_x_dtype),
'value': 1
}
node.fluid_code.add_layer(
'ones', inputs=None, output=name_ones, param_attr=attr_ones)
'fill_constant',
inputs=None,
output=name_ones,
param_attr=attr_ones)
inputs = {'x': name_ones, 'y': val_x}
attr = {'name': string(node.layer_name)}
node.fluid_code.add_layer(
'elementwise_mul',
inputs=inputs,
output=node.layer_name,
param_attr=attr)
param_attr=None)
@print_mapping_info
def Gather(self, node):
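The rewritten Expand handler above materializes a tensor of ones with the runtime target shape (fill_constant with value 1) and broadcast-multiplies it by the input; a NumPy sketch of the same idea (array values and shapes are illustrative):

```python
import numpy as np

def expand_via_broadcast(x, target_shape):
    # The tensor of ones plays the role of the fill_constant(value=1) output;
    # elementwise multiplication then broadcasts x into the target shape.
    ones = np.ones(target_shape, dtype=x.dtype)
    return ones * x

x = np.array([[1.0], [2.0], [3.0]], dtype=np.float32)   # shape (3, 1)
print(expand_via_broadcast(x, (3, 4)))                  # broadcast to (3, 4)
```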
......@@ -652,9 +643,15 @@ class OpSet9():
elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode):
indices_cast = indices.layer_name + '_cast'
node.fluid_code.add_layer(
'embedding',
'cast',
inputs=indices,
output=indices_cast,
param_attr={'dtype': string('int64')})
node.fluid_code.add_layer(
'embedding',
inputs=indices_cast,
output=node,
use_fluid=True,
param_attr={
......@@ -663,7 +660,6 @@ class OpSet9():
})
else:
from functools import reduce
#indices_shape = [1,7]
reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.layer_name + '_shape'
node.fluid_code.add_layer(
......@@ -703,7 +699,7 @@ class OpSet9():
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_trans'
name_trans = val_x.layer_name + '_transpose'
node.fluid_code.add_layer(
'transpose',
inputs=val_x,
......@@ -715,8 +711,12 @@ class OpSet9():
'index': indices_reshape},
output=node,
param_attr=None)
input_transpose = node.layer_name + '_transpose'
node.fluid_code.add_layer(
'transpose', inputs=node, output=node, param_attr=attr_trans)
'transpose',
inputs=node,
output=input_transpose,
param_attr=attr_trans)
val_x_shape = val_x.out_shapes[0]
reshaped_shape = []
for i in perm:
......@@ -725,7 +725,7 @@ class OpSet9():
reshaped_shape.append(i)
node.fluid_code.add_layer(
'reshape',
inputs=node,
inputs=input_transpose,
output=node,
param_attr={'shape': reshaped_shape})
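The main change in the Gather handler above inserts a cast of the indices to int64 before the embedding lookup, because fluid.embedding expects int64 ids; a minimal sketch of that pattern built directly with the fluid API (the vocabulary size and embedding dimension are hypothetical, and this is not the code X2Paddle generates):

```python
import paddle.fluid as fluid

ids = fluid.data(name='ids', shape=[None, 1], dtype='int32')
# The cast added by this PR: fluid.embedding requires int64 indices.
ids_int64 = fluid.layers.cast(ids, dtype='int64')
emb = fluid.embedding(input=ids_int64, size=[1000, 32])   # [vocab, dim], hypothetical
```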
......@@ -859,17 +859,21 @@ class OpSet9():
}
else:
if starts.dtype != 'int32':
starts_cast = starts.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=starts,
output=starts,
output=starts_cast,
param_attr={'dtype': string('int32')})
attr['starts'] = starts_cast
if ends.dtype != 'int32':
ends_cast = ends.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=ends,
output=ends,
output=ends_cast,
param_attr={'dtype': string('int32')})
attr['ends'] = ends_cast
else:
starts = node.get_attr('starts')
ends = node.get_attr('ends')
......@@ -1138,7 +1142,7 @@ class OpSet9():
x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0]
inputs = {"x": val_x, "y": val_y}
if y_shape[0] == 1 and x_shape[-1] != 1:
if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
y_squeeze = val_y.layer_name + '_squeeze'
node.fluid_code.add_layer(
"squeeze",
......@@ -1286,7 +1290,6 @@ class OpSet9():
'y': cast_condition},
output=mul_val_x,
param_attr=None)
mul_val_y = val_y.layer_name + '_mul'
node.fluid_code.add_layer(
"elementwise_mul",
......@@ -1339,7 +1342,8 @@ class OpSet9():
if val_repeats.dtype != 'int32':
attr = {"dtype": string("int32")}
node.fluid_code.add_layer(
"cast", inputs=repeats,
"cast",
inputs=repeats,
output="{}.tmp".format(repeats),
param_attr=attr)
repeats = "{}.tmp".format(repeats)
......
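Several hunks above (the Gather indices, the Slice starts/ends, the Tile repeats) follow the same pattern: emit the dtype cast into a fresh '<name>_cast' (or '.tmp') variable and point the consuming layer at it, instead of overwriting the original variable. A hedged sketch of that pattern as a helper (add_cast is a hypothetical illustration, not an X2Paddle API; fluid_code stands for the node's fluid_code object seen in the hunks above):

```python
def add_cast(fluid_code, var_name, target_dtype):
    # Write the cast result to a new '<name>_cast' variable so the original
    # tensor is left untouched, then return the new name for the consumer.
    cast_name = var_name + '_cast'
    fluid_code.add_layer(
        'cast',
        inputs=var_name,
        output=cast_name,
        param_attr={'dtype': "'{}'".format(target_dtype)})
    return cast_name

# e.g. attr['starts'] = add_cast(node.fluid_code, starts.layer_name, 'int32')
```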