diff --git a/fleetrec/models/base.py b/fleetrec/models/base.py index a309d37fafa83bf9fd146ddf9e2def8ac61c1538..22cedc067c45a192da58e4796bcba5b4aab0d3c8 100644 --- a/fleetrec/models/base.py +++ b/fleetrec/models/base.py @@ -75,13 +75,6 @@ class Model(object): def __init__(self, config): """R """ - self._config = config - self._name = config['name'] - f = open(config['layer_file'], 'r') - self._build_nodes = yaml.safe_load(f.read()) - self._build_phase = ['input', 'param', 'summary', 'layer'] - self._build_param = {'layer': {}, 'inner_layer': {}, 'layer_extend': {}, 'model': {}} - self._inference_meta = {'dependency': {}, 'params': {}} self._cost = None self._metrics = {} self._data_var = [] @@ -130,44 +123,6 @@ class Model(object): """ pass - def inference_params(self, inference_layer): - """ - get params name for inference_layer - Args: - inference_layer(str): layer for inference - Return: - params(list): params name list that for inference layer - """ - layer = inference_layer - if layer in self._inference_meta['params']: - return self._inference_meta['params'][layer] - - self._inference_meta['params'][layer] = [] - self._inference_meta['dependency'][layer] = self.get_dependency(self._build_param['inner_layer'], layer) - for node in self._build_nodes['layer']: - if node['name'] not in self._inference_meta['dependency'][layer]: - continue - if 'inference_param' in self._build_param['layer_extend'][node['name']]: - self._inference_meta['params'][layer] += \ - self._build_param['layer_extend'][node['name']]['inference_param']['params'] - return self._inference_meta['params'][layer] - - def get_dependency(self, layer_graph, dest_layer): - """ - get model of dest_layer depends on - Args: - layer_graph(dict) : all model in graph - Return: - depend_layers(list) : sub-graph model for calculate dest_layer - """ - dependency_list = [] - if dest_layer in layer_graph: - dependencys = copy.deepcopy(layer_graph[dest_layer]['input']) - dependency_list = copy.deepcopy(dependencys) - for dependency in dependencys: - dependency_list = dependency_list + self.get_dependency(layer_graph, dependency) - return list(set(dependency_list)) - class YamlModel(Model): """R @@ -177,7 +132,13 @@ class YamlModel(Model): """R """ Model.__init__(self, config) - pass + self._config = config + self._name = config['name'] + f = open(config['layer_file'], 'r') + self._build_nodes = yaml.safe_load(f.read()) + self._build_phase = ['input', 'param', 'summary', 'layer'] + self._build_param = {'layer': {}, 'inner_layer': {}, 'layer_extend': {}, 'model': {}} + self._inference_meta = {'dependency': {}, 'params': {}} def build_model(self): """R @@ -289,3 +250,41 @@ class YamlModel(Model): program, vars=params_var_list, filename=params_file_name) else: fluid.io.save_vars(executor, params_file_name, program, vars=params_var_list) + + def inference_params(self, inference_layer): + """ + get params name for inference_layer + Args: + inference_layer(str): layer for inference + Return: + params(list): params name list that for inference layer + """ + layer = inference_layer + if layer in self._inference_meta['params']: + return self._inference_meta['params'][layer] + + self._inference_meta['params'][layer] = [] + self._inference_meta['dependency'][layer] = self.get_dependency(self._build_param['inner_layer'], layer) + for node in self._build_nodes['layer']: + if node['name'] not in self._inference_meta['dependency'][layer]: + continue + if 'inference_param' in self._build_param['layer_extend'][node['name']]: + self._inference_meta['params'][layer] += \ + self._build_param['layer_extend'][node['name']]['inference_param']['params'] + return self._inference_meta['params'][layer] + + def get_dependency(self, layer_graph, dest_layer): + """ + get model of dest_layer depends on + Args: + layer_graph(dict) : all model in graph + Return: + depend_layers(list) : sub-graph model for calculate dest_layer + """ + dependency_list = [] + if dest_layer in layer_graph: + dependencys = copy.deepcopy(layer_graph[dest_layer]['input']) + dependency_list = copy.deepcopy(dependencys) + for dependency in dependencys: + dependency_list = dependency_list + self.get_dependency(layer_graph, dependency) + return list(set(dependency_list)) \ No newline at end of file