提交 051a168e 编写于 作者: xiebaiyuan's avatar xiebaiyuan

add weights

上级 4d59c32d
......@@ -14,6 +14,7 @@ class Converter:
def __init__(self, mdl_json_path):
self.mdl_json_path = mdl_json_path
print mdl_json_path
self.mdl_json = load_mdl(self.mdl_json_path)
self.program_desc = framework_pb2.ProgramDesc()
......@@ -38,39 +39,57 @@ class Converter:
layers_ = self.mdl_json['layer']
for layer in layers_:
desc_ops = block_desc.ops.add()
desc_ops_add = block_desc.ops.add()
# print layer
# for i in layer:
# print i
if 'name' in layer:
l_name = layer['name']
if 'type' in layer:
l_type = layer['type']
# print l_type
# print mdl2fluid_op_layer_dict.get(l_type)
desc_ops_add.type = types.mdl2fluid_op_layer_dict.get(l_type)
if 'weight' in layer:
l_weights = layer['weight']
op_tup = types.op_io_dict.get(desc_ops_add.type).get(types.mdl_weight_key)
# print len(op_tup)
for i, val in enumerate(op_tup):
print i
print val
inputs_add = desc_ops_add.inputs.add()
# print w
inputs_add.parameter = op_tup[i]
inputs_add.arguments.append(l_weights[i])
# for w in l_weights:
# inputs_add = desc_ops_add.inputs.add()
# # print w
# inputs_add.parameter = op_tup[0]
# inputs_add.arguments.append(w)
if 'param' in layer:
l_params = layer['param']
if 'output' in layer:
l_outputs = layer['output']
for o in l_outputs:
# print o
outputs_add = desc_ops_add.outputs.add()
outputs_add.parameter = types.op_io_dict.get(desc_ops_add.type).get(types.mdl_outputs_key)
outputs_add.arguments.append(o)
if 'input' in layer:
l_inputs = layer['input']
inputs_add = desc_ops.inputs.add()
for i in l_inputs:
inputs_add = desc_ops_add.inputs.add()
# print i
inputs_add.parameter = ''
inputs_add.parameter = types.op_io_dict.get(desc_ops_add.type).get(types.mdl_inputs_key)
inputs_add.arguments.append(i)
if 'type' in layer:
l_type = layer['type']
# print l_type
# print mdl2fluid_op_layer_dict.get(l_type)
desc_ops.type = types.mdl2fluid_op_layer_dict.get(l_type)
mdl_path = "multiobjects/YOLO_Universal.json"
# print mdl_path
# # model
# mdl_model = load_mdl(mdl_path)
......@@ -107,6 +126,6 @@ mdl_path = "multiobjects/YOLO_Universal.json"
#
# package()
mdl_path = "/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/mdl2fluid/multiobjects/YOLO_Universal.json"
converter = Converter(mdl_path)
converter.convert()
......@@ -5,12 +5,25 @@ mdl2fluid_op_layer_dict = {
'PointwiseConvolutionLayer': 'fusion_conv_add'
}
mdl_outputs_key = "outputs"
mdl_inputs_key = "inputs"
mdl_weight_key = "weights"
# inputs_key = "inputs"
fusion_conv_add_dict = {
'inputs': 'Input',
'outputs': 'Out'
mdl_inputs_key: 'Input',
mdl_outputs_key: 'Out',
mdl_weight_key: ('Filter', 'Y')
}
relu_dict = {
'inputs': 'X',
'outputs': 'Out'
mdl_inputs_key: 'X',
mdl_outputs_key: 'Out',
mdl_weight_key: ()
}
op_io_dict = {
'fusion_conv_add': fusion_conv_add_dict,
'relu': relu_dict
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册