未验证 提交 ac63c094 编写于 作者: J Jason 提交者: GitHub

Merge pull request #313 from Channingss/paddle_onnx

fix bug & optimize code struct
......@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.op_mapper.onnx2paddle.opsets.opset9 import OpSet9
from x2paddle.op_mapper.onnx2paddle.opset9 import OpSet9, custom_layers
from x2paddle.core.op_mapper import OpMapper
from x2paddle.op_mapper.onnx_opsets.custom_layer import *
from x2paddle.decoder.onnx_decoder import ONNXGraph, ONNXGraphNode, ONNXGraphDataNode
......
from .opset import OpSet9
from .custom_layer import custom_layers
......@@ -17,6 +17,7 @@ from x2paddle.core.graph import GraphNode
from x2paddle.core.fluid_code import Layer
from x2paddle.core.fluid_code import FluidCode
from x2paddle.core.util import string
from x2paddle.op_mapper.onnx2paddle.opset9.custom_layer import *
from functools import reduce
import numpy as np
import onnx
......@@ -1467,8 +1468,8 @@ class OpSet9():
inputs=val_b,
output=var_bi + ',' + var_bh,
param_attr={
'axis': 1,
'split': [hidden_size * 3, hidden_size * 3],
'dim': 1,
'num_or_sections': [hidden_size * 3, hidden_size * 3],
'name': string(node.layer_name + '.b/split')
})
var_bi0 = node.layer_name + '_bi0'
......@@ -1480,7 +1481,7 @@ class OpSet9():
'name': string(var_bi0)})
node.fluid_code.add_layer(
'elmentwise_add',
'elementwise_add',
inputs=[var_mm, var_bi0],
output=var_fc,
param_attr={
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册