提交 23d489c3 编写于 作者: C Channingss

fix bug & optimize code struct

上级 5b6614fa
...@@ -12,9 +12,8 @@ ...@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from x2paddle.op_mapper.onnx2paddle.opsets.opset9 import OpSet9 from x2paddle.op_mapper.onnx2paddle.opset9 import OpSet9, custom_layers
from x2paddle.core.op_mapper import OpMapper from x2paddle.core.op_mapper import OpMapper
from x2paddle.op_mapper.onnx_opsets.custom_layer import *
from x2paddle.decoder.onnx_decoder import ONNXGraph, ONNXGraphNode, ONNXGraphDataNode from x2paddle.decoder.onnx_decoder import ONNXGraph, ONNXGraphNode, ONNXGraphDataNode
......
from .opset import OpSet9
from .custom_layer import custom_layers
...@@ -17,6 +17,7 @@ from x2paddle.core.graph import GraphNode ...@@ -17,6 +17,7 @@ from x2paddle.core.graph import GraphNode
from x2paddle.core.fluid_code import Layer from x2paddle.core.fluid_code import Layer
from x2paddle.core.fluid_code import FluidCode from x2paddle.core.fluid_code import FluidCode
from x2paddle.core.util import string from x2paddle.core.util import string
from x2paddle.op_mapper.onnx2paddle.opset9.custom_layer import *
from functools import reduce from functools import reduce
import numpy as np import numpy as np
import onnx import onnx
...@@ -1379,8 +1380,8 @@ class OpSet9(): ...@@ -1379,8 +1380,8 @@ class OpSet9():
node, idx=5 - miss_arg_num, copy=True) node, idx=5 - miss_arg_num, copy=True)
x_shape = val_x.out_shapes[0] x_shape = val_x.out_shapes[0]
print(x_shape)
assert x_shape[1] == 1, 'only X with batch_size = 1 supported' #assert x_shape[1] == 1, 'only X with batch_size = 1 supported'
assert node.get_attr('clip', None) is None, 'clipping not supported' assert node.get_attr('clip', None) is None, 'clipping not supported'
hidden_size = node.get_attr('hidden_size', None) hidden_size = node.get_attr('hidden_size', None)
...@@ -1467,8 +1468,8 @@ class OpSet9(): ...@@ -1467,8 +1468,8 @@ class OpSet9():
inputs=val_b, inputs=val_b,
output=var_bi + ',' + var_bh, output=var_bi + ',' + var_bh,
param_attr={ param_attr={
'axis': 1, 'dim': 1,
'split': [hidden_size * 3, hidden_size * 3], 'num_or_sections': [hidden_size * 3, hidden_size * 3],
'name': string(node.layer_name + '.b/split') 'name': string(node.layer_name + '.b/split')
}) })
var_bi0 = node.layer_name + '_bi0' var_bi0 = node.layer_name + '_bi0'
...@@ -1480,7 +1481,7 @@ class OpSet9(): ...@@ -1480,7 +1481,7 @@ class OpSet9():
'name': string(var_bi0)}) 'name': string(var_bi0)})
node.fluid_code.add_layer( node.fluid_code.add_layer(
'elmentwise_add', 'elementwise_add',
inputs=[var_mm, var_bi0], inputs=[var_mm, var_bi0],
output=var_fc, output=var_fc,
param_attr={ param_attr={
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册