提交 23d489c3 编写于 作者: C Channingss

fix bug & optimize code struct

上级 5b6614fa
......@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.op_mapper.onnx2paddle.opsets.opset9 import OpSet9
from x2paddle.op_mapper.onnx2paddle.opset9 import OpSet9, custom_layers
from x2paddle.core.op_mapper import OpMapper
from x2paddle.op_mapper.onnx_opsets.custom_layer import *
from x2paddle.decoder.onnx_decoder import ONNXGraph, ONNXGraphNode, ONNXGraphDataNode
......
from .opset import OpSet9
from .custom_layer import custom_layers
......@@ -17,6 +17,7 @@ from x2paddle.core.graph import GraphNode
from x2paddle.core.fluid_code import Layer
from x2paddle.core.fluid_code import FluidCode
from x2paddle.core.util import string
from x2paddle.op_mapper.onnx2paddle.opset9.custom_layer import *
from functools import reduce
import numpy as np
import onnx
......@@ -1379,8 +1380,8 @@ class OpSet9():
node, idx=5 - miss_arg_num, copy=True)
x_shape = val_x.out_shapes[0]
assert x_shape[1] == 1, 'only X with batch_size = 1 supported'
print(x_shape)
#assert x_shape[1] == 1, 'only X with batch_size = 1 supported'
assert node.get_attr('clip', None) is None, 'clipping not supported'
hidden_size = node.get_attr('hidden_size', None)
......@@ -1467,8 +1468,8 @@ class OpSet9():
inputs=val_b,
output=var_bi + ',' + var_bh,
param_attr={
'axis': 1,
'split': [hidden_size * 3, hidden_size * 3],
'dim': 1,
'num_or_sections': [hidden_size * 3, hidden_size * 3],
'name': string(node.layer_name + '.b/split')
})
var_bi0 = node.layer_name + '_bi0'
......@@ -1480,7 +1481,7 @@ class OpSet9():
'name': string(var_bi0)})
node.fluid_code.add_layer(
'elmentwise_add',
'elementwise_add',
inputs=[var_mm, var_bi0],
output=var_fc,
param_attr={
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册