提交 1da94211 编写于 作者: C Channingss

support ONNX >=1.6.0

上级 9270b81d
...@@ -15,7 +15,7 @@ paddlepaddle >= 1.8.0 ...@@ -15,7 +15,7 @@ paddlepaddle >= 1.8.0
**按需安装以下依赖** **按需安装以下依赖**
tensorflow : tensorflow == 1.14.0 tensorflow : tensorflow == 1.14.0
caffe : 无 caffe : 无
onnx : onnx == 1.6.0 onnx : onnx >= 1.6.0
## 安装 ## 安装
### 安装方式一(推荐) ### 安装方式一(推荐)
......
...@@ -170,8 +170,8 @@ def onnx2paddle(model_path, save_dir, params_merge=False): ...@@ -170,8 +170,8 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
try: try:
import onnx import onnx
version = onnx.version.version version = onnx.version.version
if version != '1.6.0': if version < '1.6.0':
print("[ERROR] onnx==1.6.0 is required") print("[ERROR] onnx>=1.6.0 is required")
return return
except: except:
print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".") print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
......
...@@ -642,14 +642,15 @@ class OpSet9(): ...@@ -642,14 +642,15 @@ class OpSet9():
elif axis == 0 and len(indices_shape) > 1: elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance( if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode): val_x, ONNXGraphDataNode):
indices_cast = indices.layer_name + '_cast'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'cast', 'cast',
inputs=indices, inputs=indices,
output=indices, output=indices_cast,
param_attr={'dtype': string('int64')}) param_attr={'dtype': string('int64')})
node.fluid_code.add_layer( node.fluid_code.add_layer(
'embedding', 'embedding',
inputs=indices, inputs=indices_cast,
output=node, output=node,
use_fluid=True, use_fluid=True,
param_attr={ param_attr={
...@@ -1140,7 +1141,7 @@ class OpSet9(): ...@@ -1140,7 +1141,7 @@ class OpSet9():
x_shape = val_x.out_shapes[0] x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0] y_shape = val_y.out_shapes[0]
inputs = {"x": val_x, "y": val_y} inputs = {"x": val_x, "y": val_y}
if y_shape[0] == 1 and x_shape[-1] != 1: if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
y_squeeze = val_y.layer_name + '_squeeze' y_squeeze = val_y.layer_name + '_squeeze'
node.fluid_code.add_layer( node.fluid_code.add_layer(
"squeeze", "squeeze",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册