提交 891296ac 编写于 作者: xiebaiyuan's avatar xiebaiyuan

add persistable params

上级 c0d41a60
......@@ -17,7 +17,7 @@ class Converter:
print mdl_json_path
self.mdl_json = load_mdl(self.mdl_json_path)
self.program_desc = framework_pb2.ProgramDesc()
self.weight_list_ = []
# print(json_dick)
# layers = (json_dick['layer'])
# for layer in layers:
......@@ -33,6 +33,12 @@ class Converter:
self.package_vars(block_desc)
print 'blocks: '
print self.program_desc.blocks
print 'convert end.....'
desc_serialize_to_string = self.program_desc.SerializeToString()
f = open("newyolo/__model__", "wb")
f.write(desc_serialize_to_string)
f.close()
def package_ops(self, block_desc):
......@@ -50,8 +56,10 @@ class Converter:
# print i
if 'name' in layer:
l_name = layer['name']
if 'type' in layer:
self.package_ops_type(desc_ops_add, layer)
if 'weight' in layer:
self.package_ops_weight2inputs(desc_ops_add, layer)
......@@ -62,6 +70,7 @@ class Converter:
self.package_ops_inputs(desc_ops_add, layer)
self.package_ops_attrs(desc_ops_add, layer)
self.add_op_fetch(block_desc)
def add_op_feed(self, block_desc):
......@@ -205,18 +214,19 @@ class Converter:
outputs_add.parameter = types.op_io_dict.get(desc_ops_add.type).get(types.mdl_outputs_key)
outputs_add.arguments.append(o)
@staticmethod
def package_ops_weight2inputs(desc_ops_add, layer):
def package_ops_weight2inputs(self, desc_ops_add, layer):
l_weights = layer['weight']
for w in l_weights:
self.weight_list_.append(w)
op_weight_tup = types.op_io_dict.get(desc_ops_add.type).get(types.mdl_weight_key)
# print len(op_weight_tup)
for i, val in enumerate(op_weight_tup):
# print i
# print val
inputs_add = desc_ops_add.inputs.add()
# print w
inputs_add.parameter = op_weight_tup[i]
inputs_add.arguments.append(l_weights[i])
# for w in l_weights:
# inputs_add = desc_ops_add.inputs.add()
# # print w
......@@ -231,6 +241,25 @@ class Converter:
desc_ops_add.type = types.mdl2fluid_op_layer_dict.get(l_type)
def package_vars(self, block_desc):
# feed
# vars
# {
# name: "feed"
# type {
# type: FEED_MINIBATCH
# }
# persistable: true
# }
vars_add = block_desc.vars.add()
vars_add.name = 'feed'
vars_add.type.type = 9 # 9 is FEED_MINIBATCH
vars_add.persistable = 1
# fetch
vars_add = block_desc.vars.add()
vars_add.name = 'fetch'
vars_add.type.type = 10 # 10 is fetch list
vars_add.persistable = 1
json_matrix_ = self.mdl_json['matrix']
# print json_matrix_
for j in json_matrix_:
......@@ -242,7 +271,10 @@ class Converter:
tensor.data_type = 5 # 5 is FP32
for dims in json_matrix_.get(j):
tensor.dims.append(dims)
pass
if j in self.weight_list_:
vars_add.persistable = 1
else:
vars_add.persistable = 0
# print mdl_path
......
......@@ -26,5 +26,5 @@ def get_file_size(file_path):
return round(fsize, 2)
path = "/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/mdl2fluid/yolo/__model__"
path = "newyolo/__model__"
read_model(path)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册