From 34e0bd573f47b8fff0be45648b6778791a2c04d6 Mon Sep 17 00:00:00 2001 From: xiebaiyuan Date: Tue, 18 Sep 2018 19:27:07 +0800 Subject: [PATCH] add persistable params --- python/tools/mdl2fluid/mdl2fluid.py | 42 +++++++++++++++++++++++--- python/tools/mdl2fluid/model_reader.py | 2 +- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/python/tools/mdl2fluid/mdl2fluid.py b/python/tools/mdl2fluid/mdl2fluid.py index f9805f603d..6a56de13f8 100644 --- a/python/tools/mdl2fluid/mdl2fluid.py +++ b/python/tools/mdl2fluid/mdl2fluid.py @@ -17,7 +17,7 @@ class Converter: print mdl_json_path self.mdl_json = load_mdl(self.mdl_json_path) self.program_desc = framework_pb2.ProgramDesc() - + self.weight_list_ = [] # print(json_dick) # layers = (json_dick['layer']) # for layer in layers: @@ -33,6 +33,12 @@ class Converter: self.package_vars(block_desc) print 'blocks: ' print self.program_desc.blocks + print 'convert end.....' + desc_serialize_to_string = self.program_desc.SerializeToString() + + f = open("newyolo/__model__", "wb") + f.write(desc_serialize_to_string) + f.close() def package_ops(self, block_desc): @@ -50,8 +56,10 @@ class Converter: # print i if 'name' in layer: l_name = layer['name'] + if 'type' in layer: self.package_ops_type(desc_ops_add, layer) + if 'weight' in layer: self.package_ops_weight2inputs(desc_ops_add, layer) @@ -62,6 +70,7 @@ class Converter: self.package_ops_inputs(desc_ops_add, layer) self.package_ops_attrs(desc_ops_add, layer) + self.add_op_fetch(block_desc) def add_op_feed(self, block_desc): @@ -205,18 +214,19 @@ class Converter: outputs_add.parameter = types.op_io_dict.get(desc_ops_add.type).get(types.mdl_outputs_key) outputs_add.arguments.append(o) - @staticmethod - def package_ops_weight2inputs(desc_ops_add, layer): + def package_ops_weight2inputs(self, desc_ops_add, layer): l_weights = layer['weight'] + for w in l_weights: + self.weight_list_.append(w) op_weight_tup = types.op_io_dict.get(desc_ops_add.type).get(types.mdl_weight_key) # print len(op_weight_tup) for i, val in enumerate(op_weight_tup): # print i # print val inputs_add = desc_ops_add.inputs.add() - # print w inputs_add.parameter = op_weight_tup[i] inputs_add.arguments.append(l_weights[i]) + # for w in l_weights: # inputs_add = desc_ops_add.inputs.add() # # print w @@ -231,6 +241,25 @@ class Converter: desc_ops_add.type = types.mdl2fluid_op_layer_dict.get(l_type) def package_vars(self, block_desc): + # feed + # vars + # { + # name: "feed" + # type { + # type: FEED_MINIBATCH + # } + # persistable: true + # } + vars_add = block_desc.vars.add() + vars_add.name = 'feed' + vars_add.type.type = 9 # 9 is FEED_MINIBATCH + vars_add.persistable = 1 + # fetch + vars_add = block_desc.vars.add() + vars_add.name = 'fetch' + vars_add.type.type = 10 # 10 is fetch list + vars_add.persistable = 1 + json_matrix_ = self.mdl_json['matrix'] # print json_matrix_ for j in json_matrix_: @@ -242,7 +271,10 @@ class Converter: tensor.data_type = 5 # 5 is FP32 for dims in json_matrix_.get(j): tensor.dims.append(dims) - pass + if j in self.weight_list_: + vars_add.persistable = 1 + else: + vars_add.persistable = 0 # print mdl_path diff --git a/python/tools/mdl2fluid/model_reader.py b/python/tools/mdl2fluid/model_reader.py index 9a10c57f39..8d53350db2 100644 --- a/python/tools/mdl2fluid/model_reader.py +++ b/python/tools/mdl2fluid/model_reader.py @@ -26,5 +26,5 @@ def get_file_size(file_path): return round(fsize, 2) -path = "/Users/xiebaiyuan/PaddleProject/paddle-mobile/python/tools/mdl2fluid/yolo/__model__" +path = "newyolo/__model__" read_model(path) -- GitLab