未验证 提交 2fa50566 编写于 作者: J Jason 提交者: GitHub

Merge pull request #347 from Channingss/Conv

add cast(int64) for embedding, fix bug of python2, delete repeated Geater
...@@ -15,7 +15,7 @@ paddlepaddle >= 1.8.0 ...@@ -15,7 +15,7 @@ paddlepaddle >= 1.8.0
**按需安装以下依赖** **按需安装以下依赖**
tensorflow : tensorflow == 1.14.0 tensorflow : tensorflow == 1.14.0
caffe : 无 caffe : 无
onnx : onnx == 1.6.0 onnx : onnx >= 1.6.0
## 安装 ## 安装
### 安装方式一(推荐) ### 安装方式一(推荐)
......
...@@ -170,8 +170,8 @@ def onnx2paddle(model_path, save_dir, params_merge=False): ...@@ -170,8 +170,8 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
try: try:
import onnx import onnx
version = onnx.version.version version = onnx.version.version
if version != '1.6.0': if version < '1.6.0':
print("[ERROR] onnx==1.6.0 is required") print("[ERROR] onnx>=1.6.0 is required")
return return
except: except:
print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".") print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
......
...@@ -267,8 +267,9 @@ class SymbolicShapeInference: ...@@ -267,8 +267,9 @@ class SymbolicShapeInference:
if pending_nodes and self.verbose_ > 0: if pending_nodes and self.verbose_ > 0:
print('SymbolicShapeInference: orphaned nodes discarded: ') print('SymbolicShapeInference: orphaned nodes discarded: ')
print('\n'.join( for n in pending_nodes:
[n.op_type + ': ' + n.output[0] for n in pending_nodes])) print(n.op_type + ': ' + n.output[0])
if input_shapes is not None: if input_shapes is not None:
for input_name, shape in input_shapes.items(): for input_name, shape in input_shapes.items():
for idx in range(len(self.out_mp_.graph.input)): for idx in range(len(self.out_mp_.graph.input)):
......
...@@ -487,16 +487,6 @@ class OpSet9(): ...@@ -487,16 +487,6 @@ class OpSet9():
node.fluid_code.add_layer( node.fluid_code.add_layer(
'hard_shrink', inputs=val_x, output=node, param_attr=attr) 'hard_shrink', inputs=val_x, output=node, param_attr=attr)
def Greater(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
node.fluid_code.add_layer(
'greater_than',
inputs={'x': val_x,
'y': val_y},
output=node,
param_attr=None)
@print_mapping_info @print_mapping_info
def Constant(self, node): def Constant(self, node):
val_output = self.graph.get_node(node.layer.output[0], copy=True) val_output = self.graph.get_node(node.layer.output[0], copy=True)
...@@ -566,25 +556,26 @@ class OpSet9(): ...@@ -566,25 +556,26 @@ class OpSet9():
def Expand(self, node): def Expand(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True) val_shape = self.graph.get_input_node(node, idx=1, copy=True)
if len(val_shape.outputs) == 1: if len(val_shape.outputs) == 1:
self.omit_nodes.append(val_shape.layer_name) self.omit_nodes.append(val_shape.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
out_shape = node.out_shapes[0]
val_x_dtype = val_x.dtype val_x_dtype = val_x.dtype
name_ones = node.layer_name + '_ones' name_ones = node.layer_name + '_ones'
attr_ones = {'shape': out_shape, 'dtype': string(val_x_dtype)} attr_ones = {
'shape': val_shape.layer_name,
'dtype': string(val_x_dtype),
'value': 1
}
node.fluid_code.add_layer( node.fluid_code.add_layer(
'ones', inputs=None, output=name_ones, param_attr=attr_ones) 'fill_constant',
inputs=None,
output=name_ones,
param_attr=attr_ones)
inputs = {'x': name_ones, 'y': val_x} inputs = {'x': name_ones, 'y': val_x}
attr = {'name': string(node.layer_name)}
node.fluid_code.add_layer( node.fluid_code.add_layer(
'elementwise_mul', 'elementwise_mul',
inputs=inputs, inputs=inputs,
output=node.layer_name, output=node.layer_name,
param_attr=attr) param_attr=None)
@print_mapping_info @print_mapping_info
def Gather(self, node): def Gather(self, node):
...@@ -652,9 +643,15 @@ class OpSet9(): ...@@ -652,9 +643,15 @@ class OpSet9():
elif axis == 0 and len(indices_shape) > 1: elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance( if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode): val_x, ONNXGraphDataNode):
indices_cast = indices.layer_name + '_cast'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'embedding', 'cast',
inputs=indices, inputs=indices,
output=indices_cast,
param_attr={'dtype': string('int64')})
node.fluid_code.add_layer(
'embedding',
inputs=indices_cast,
output=node, output=node,
use_fluid=True, use_fluid=True,
param_attr={ param_attr={
...@@ -663,7 +660,6 @@ class OpSet9(): ...@@ -663,7 +660,6 @@ class OpSet9():
}) })
else: else:
from functools import reduce from functools import reduce
#indices_shape = [1,7]
reshape_shape = reduce(lambda x, y: x * y, indices_shape) reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.layer_name + '_shape' indices_reshape = indices.layer_name + '_shape'
node.fluid_code.add_layer( node.fluid_code.add_layer(
...@@ -703,7 +699,7 @@ class OpSet9(): ...@@ -703,7 +699,7 @@ class OpSet9():
perm = list(range(len(val_x.out_shapes[0]))) perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:] perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm} attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_trans' name_trans = val_x.layer_name + '_transpose'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'transpose', 'transpose',
inputs=val_x, inputs=val_x,
...@@ -715,8 +711,12 @@ class OpSet9(): ...@@ -715,8 +711,12 @@ class OpSet9():
'index': indices_reshape}, 'index': indices_reshape},
output=node, output=node,
param_attr=None) param_attr=None)
input_transpose = node.layer_name + '_transpose'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'transpose', inputs=node, output=node, param_attr=attr_trans) 'transpose',
inputs=node,
output=input_transpose,
param_attr=attr_trans)
val_x_shape = val_x.out_shapes[0] val_x_shape = val_x.out_shapes[0]
reshaped_shape = [] reshaped_shape = []
for i in perm: for i in perm:
...@@ -725,7 +725,7 @@ class OpSet9(): ...@@ -725,7 +725,7 @@ class OpSet9():
reshaped_shape.append(i) reshaped_shape.append(i)
node.fluid_code.add_layer( node.fluid_code.add_layer(
'reshape', 'reshape',
inputs=node, inputs=input_transpose,
output=node, output=node,
param_attr={'shape': reshaped_shape}) param_attr={'shape': reshaped_shape})
...@@ -859,17 +859,21 @@ class OpSet9(): ...@@ -859,17 +859,21 @@ class OpSet9():
} }
else: else:
if starts.dtype != 'int32': if starts.dtype != 'int32':
starts_cast = starts.layer_name + '_cast'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'cast', 'cast',
inputs=starts, inputs=starts,
output=starts, output=starts_cast,
param_attr={'dtype': string('int32')}) param_attr={'dtype': string('int32')})
attr['starts'] = starts_cast
if ends.dtype != 'int32': if ends.dtype != 'int32':
ends_cast = ends.layer_name + '_cast'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'cast', 'cast',
inputs=ends, inputs=ends,
output=ends, output=ends_cast,
param_attr={'dtype': string('int32')}) param_attr={'dtype': string('int32')})
attr['ends'] = ends_cast
else: else:
starts = node.get_attr('starts') starts = node.get_attr('starts')
ends = node.get_attr('ends') ends = node.get_attr('ends')
...@@ -1138,7 +1142,7 @@ class OpSet9(): ...@@ -1138,7 +1142,7 @@ class OpSet9():
x_shape = val_x.out_shapes[0] x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0] y_shape = val_y.out_shapes[0]
inputs = {"x": val_x, "y": val_y} inputs = {"x": val_x, "y": val_y}
if y_shape[0] == 1 and x_shape[-1] != 1: if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
y_squeeze = val_y.layer_name + '_squeeze' y_squeeze = val_y.layer_name + '_squeeze'
node.fluid_code.add_layer( node.fluid_code.add_layer(
"squeeze", "squeeze",
...@@ -1286,7 +1290,6 @@ class OpSet9(): ...@@ -1286,7 +1290,6 @@ class OpSet9():
'y': cast_condition}, 'y': cast_condition},
output=mul_val_x, output=mul_val_x,
param_attr=None) param_attr=None)
mul_val_y = val_y.layer_name + '_mul' mul_val_y = val_y.layer_name + '_mul'
node.fluid_code.add_layer( node.fluid_code.add_layer(
"elementwise_mul", "elementwise_mul",
...@@ -1339,7 +1342,8 @@ class OpSet9(): ...@@ -1339,7 +1342,8 @@ class OpSet9():
if val_repeats.dtype != 'int32': if val_repeats.dtype != 'int32':
attr = {"dtype": string("int32")} attr = {"dtype": string("int32")}
node.fluid_code.add_layer( node.fluid_code.add_layer(
"cast", inputs=repeats, "cast",
inputs=repeats,
output="{}.tmp".format(repeats), output="{}.tmp".format(repeats),
param_attr=attr) param_attr=attr)
repeats = "{}.tmp".format(repeats) repeats = "{}.tmp".format(repeats)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册