提交 4519e2e8 编写于 作者: C Channingss

merge paddle/develop

上级 d43f75b2
......@@ -41,6 +41,27 @@ class OpSet11(OpSet10):
outputs=op.output('Out'), )
return [min_node, max_node, node]
def pad2d(self, op, block):
x_shape = block.var(op.input('X')[0]).shape
paddings = op.attr('paddings')
onnx_pads = []
#TODO support pads is Variable
if op.attr('data_format') == 'NCHW':
pads = [
0, 0, paddings[0], paddings[2], 0, 0, paddings[1], paddings[3]
]
else:
pads = [
0, paddings[0], paddings[2], 0, 0, paddings[1], paddings[3], 0
]
pads_name = self.get_name(op.type, 'pads')
pads_node = self.make_constant_node(pads_name,
onnx_pb.TensorProto.INT64, pads)
constant_value_name = self.get_name(op.type, 'constant_value')
constant_value_node = self.make_constant_node(constant_value_name,
onnx_pb.TensorProto.FLOAT,
op.attr('pad_value'))
def clip(self, op, block):
min_name = self.get_name(op.type, 'min')
max_name = self.get_name(op.type, 'max')
......
......@@ -59,7 +59,7 @@ class OpSet9(object):
'Constant', inputs=[], outputs=[name], value=tensor)
return node
def convert_weights(self, program):
def convert_weights(self, program, scope=None):
var_names = program.global_block().vars
nodes = list()
for name in var_names:
......@@ -68,7 +68,7 @@ class OpSet9(object):
continue
if not var.persistable:
continue
weight = np.array(fluid.global_scope().find_var(name).get_tensor())
weight = np.array(scope.find_var(name).get_tensor())
tensor = helper.make_tensor(
name=name,
dims=var.shape,
......@@ -236,6 +236,28 @@ class OpSet9(object):
pads=op.attr('paddings') + op.attr('paddings'))
return node
def pad2d(self, op, block):
x_shape = block.var(op.input('X')[0]).shape
paddings = op.attr('paddings')
onnx_pads = []
if op.attr('data_format') == 'NCHW':
pads = [
0, 0, paddings[0], paddings[2], 0, 0, paddings[1], paddings[3]
]
else:
pads = [
0, paddings[0], paddings[2], 0, 0, paddings[1], paddings[3], 0
]
#TODO support pads is Variable
node = helper.make_node(
'Pad',
inputs=op.input('X'),
outputs=op.output('Out'),
mode=op.attr('mode'),
value=op.attr('pad_value'),
pads=pads)
return node
def softmax(self, op, block):
axis = op.attr('axis')
shape = block.var(op.output('Out')[0]).shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册