Commit bfcd53da authored by: T tangwei

code fix

Parent af3dad94
@@ -75,13 +75,6 @@ class Model(object):
def __init__(self, config):
"""R
"""
self._config = config
self._name = config['name']
f = open(config['layer_file'], 'r')
self._build_nodes = yaml.safe_load(f.read())
self._build_phase = ['input', 'param', 'summary', 'layer']
self._build_param = {'layer': {}, 'inner_layer': {}, 'layer_extend': {}, 'model': {}}
self._inference_meta = {'dependency': {}, 'params': {}}
self._cost = None
self._metrics = {}
self._data_var = []
@@ -130,44 +123,6 @@ class Model(object):
"""
pass
def inference_params(self, inference_layer):
"""
get params name for inference_layer
Args:
inference_layer(str): layer for inference
Return:
params(list): params name list that for inference layer
"""
layer = inference_layer
if layer in self._inference_meta['params']:
return self._inference_meta['params'][layer]
self._inference_meta['params'][layer] = []
self._inference_meta['dependency'][layer] = self.get_dependency(self._build_param['inner_layer'], layer)
for node in self._build_nodes['layer']:
if node['name'] not in self._inference_meta['dependency'][layer]:
continue
if 'inference_param' in self._build_param['layer_extend'][node['name']]:
self._inference_meta['params'][layer] += \
self._build_param['layer_extend'][node['name']]['inference_param']['params']
return self._inference_meta['params'][layer]
def get_dependency(self, layer_graph, dest_layer):
"""
get model of dest_layer depends on
Args:
layer_graph(dict) : all model in graph
Return:
depend_layers(list) : sub-graph model for calculate dest_layer
"""
dependency_list = []
if dest_layer in layer_graph:
dependencys = copy.deepcopy(layer_graph[dest_layer]['input'])
dependency_list = copy.deepcopy(dependencys)
for dependency in dependencys:
dependency_list = dependency_list + self.get_dependency(layer_graph, dependency)
return list(set(dependency_list))
class YamlModel(Model):
"""R
@@ -177,7 +132,13 @@ class YamlModel(Model):
"""R
"""
Model.__init__(self, config)
pass
self._config = config
self._name = config['name']
f = open(config['layer_file'], 'r')
self._build_nodes = yaml.safe_load(f.read())
self._build_phase = ['input', 'param', 'summary', 'layer']
self._build_param = {'layer': {}, 'inner_layer': {}, 'layer_extend': {}, 'model': {}}
self._inference_meta = {'dependency': {}, 'params': {}}
def build_model(self):
"""R
@@ -289,3 +250,41 @@ class YamlModel(Model):
program, vars=params_var_list, filename=params_file_name)
else:
fluid.io.save_vars(executor, params_file_name, program, vars=params_var_list)
def inference_params(self, inference_layer):
"""
get params name for inference_layer
Args:
inference_layer(str): layer for inference
Return:
params(list): params name list that for inference layer
"""
layer = inference_layer
if layer in self._inference_meta['params']:
return self._inference_meta['params'][layer]
self._inference_meta['params'][layer] = []
self._inference_meta['dependency'][layer] = self.get_dependency(self._build_param['inner_layer'], layer)
for node in self._build_nodes['layer']:
if node['name'] not in self._inference_meta['dependency'][layer]:
continue
if 'inference_param' in self._build_param['layer_extend'][node['name']]:
self._inference_meta['params'][layer] += \
self._build_param['layer_extend'][node['name']]['inference_param']['params']
return self._inference_meta['params'][layer]
def get_dependency(self, layer_graph, dest_layer):
"""
get model of dest_layer depends on
Args:
layer_graph(dict) : all model in graph
Return:
depend_layers(list) : sub-graph model for calculate dest_layer
"""
dependency_list = []
if dest_layer in layer_graph:
dependencys = copy.deepcopy(layer_graph[dest_layer]['input'])
dependency_list = copy.deepcopy(dependencys)
for dependency in dependencys:
dependency_list = dependency_list + self.get_dependency(layer_graph, dependency)
return list(set(dependency_list))
\ No newline at end of file
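
Note on the relocated helpers: `get_dependency` walks the layer graph recursively, collecting the transitive inputs of a destination layer, and `inference_params` then filters the graph nodes against that set. Below is a minimal standalone sketch of the dependency walk, assuming only what the diff shows; the toy `layer_graph` is hypothetical and merely mirrors the `{'input': [...]}` shape the method reads.

```python
import copy


def get_dependency(layer_graph, dest_layer):
    """Recursively collect every layer that dest_layer transitively takes as input."""
    dependency_list = []
    if dest_layer in layer_graph:
        # Direct inputs first, then the inputs of those inputs, and so on.
        dependencys = copy.deepcopy(layer_graph[dest_layer]['input'])
        dependency_list = copy.deepcopy(dependencys)
        for dependency in dependencys:
            dependency_list = dependency_list + get_dependency(layer_graph, dependency)
    # Deduplicate; the result order is unspecified.
    return list(set(dependency_list))


# Hypothetical toy graph: ctr_output <- fc2 <- fc1 <- emb
layer_graph = {
    'emb': {'input': []},
    'fc1': {'input': ['emb']},
    'fc2': {'input': ['fc1']},
    'ctr_output': {'input': ['fc2']},
}

print(sorted(get_dependency(layer_graph, 'ctr_output')))
# -> ['emb', 'fc1', 'fc2']
```

The `set`-based deduplication discards ordering, which is harmless here because `inference_params` only tests membership of each node name in the returned dependency list.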