提交 7408da55 编写于 作者: T tangwei

fix ctr trainer

上级 9368aec8
......@@ -12,28 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import sys
import copy
import yaml
import time
import json
import datetime
import numpy as np
import paddle.fluid as fluid
from .. utils import fs as fs
from .. utils import util as util
from .. metrics.auc_metrics import AUCMetric
from .. models import base as model_basic
from .. reader import dataset
from . import trainer
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
from eleps.utils import fs as fs
from eleps.utils import util as util
from eleps.metrics.auc_metrics import AUCMetric
from eleps.models import base as model_basic
from eleps.reader import dataset
from .trainer import Trainer
def wroker_numric_opt(value, env, opt):
"""
......@@ -75,21 +71,24 @@ def worker_numric_max(value, env="mpi"):
return wroker_numric_opt(value, env, "max")
class CtrPaddleTrainer(trainer.Trainer):
class CtrPaddleTrainer(Trainer):
"""R
"""
def __init__(self, config):
"""R
"""
trainer.Trainer.__init__(self, config)
Trainer.__init__(self, config)
config['output_path'] = util.get_absolute_path(
config['output_path'], config['io']['afs'])
self.global_config = config
self._place = fluid.CPUPlace()
self._exe = fluid.Executor(self._place)
self._exector_context = {}
self.global_config = config
self._metrics = {}
self._path_generator = util.PathGenerator({
'templates': [
{'name': 'xbox_base_done', 'template': config['output_path'] + '/xbox_base_done.txt'},
......@@ -202,7 +201,7 @@ class CtrPaddleTrainer(trainer.Trainer):
"""R
"""
cost_printer = util.CostPrinter(util.print_cost,
{'master': True, 'log_format': 'save model cost %s sec'})
{'master': True, 'log_format': 'save model cost %s sec'})
model_path = self._path_generator.generate_path('batch_model', {'day': day, 'pass_id': pass_index})
save_mode = 0 # just save all
if pass_index < 1: # batch_model
......@@ -224,8 +223,8 @@ class CtrPaddleTrainer(trainer.Trainer):
model_path = ""
xbox_model_donefile = ""
cost_printer = util.CostPrinter(util.print_cost, {'master': True, \
'log_format': 'save xbox model cost %s sec',
'stdout': stdout_str})
'log_format': 'save xbox model cost %s sec',
'stdout': stdout_str})
if pass_index < 1:
save_mode = 2
xbox_patch_id = xbox_base_key
......@@ -239,8 +238,8 @@ class CtrPaddleTrainer(trainer.Trainer):
cost_printer.done()
cost_printer = util.CostPrinter(util.print_cost, {'master': True,
'log_format': 'save cache model cost %s sec',
'stdout': stdout_str})
'log_format': 'save cache model cost %s sec',
'stdout': stdout_str})
model_file_handler = fs.FileHandler(self.global_config['io']['afs'])
if self.global_config['save_cache_model']:
cache_save_num = fleet.save_cache_model(None, model_path, mode=save_mode)
......@@ -255,8 +254,8 @@ class CtrPaddleTrainer(trainer.Trainer):
'save_combine': True
}
cost_printer = util.CostPrinter(util.print_cost, {'master': True,
'log_format': 'save dense model cost %s sec',
'stdout': stdout_str})
'log_format': 'save dense model cost %s sec',
'stdout': stdout_str})
if fleet._role_maker.is_first_worker():
for executor in self.global_config['executor']:
if 'layer_for_inference' not in executor:
......@@ -333,8 +332,8 @@ class CtrPaddleTrainer(trainer.Trainer):
self._train_pass = util.TimeTrainPass(self.global_config)
if not self.global_config['cold_start']:
cost_printer = util.CostPrinter(util.print_cost,
{'master': True, 'log_format': 'load model cost %s sec',
'stdout': stdout_str})
{'master': True, 'log_format': 'load model cost %s sec',
'stdout': stdout_str})
self.print_log("going to load model %s" % self._train_pass._checkpoint_model_path, {'master': True})
# if config.need_reqi_changeslot and config.reqi_dnn_plugin_day >= self._train_pass.date()
# and config.reqi_dnn_plugin_pass >= self._pass_id:
......@@ -373,7 +372,7 @@ class CtrPaddleTrainer(trainer.Trainer):
util.rank0_print("shrink table")
cost_printer = util.CostPrinter(util.print_cost,
{'master': True, 'log_format': 'shrink table done, cost %s sec'})
{'master': True, 'log_format': 'shrink table done, cost %s sec'})
fleet.shrink_sparse_table()
for executor in self._exector_context:
self._exector_context[executor]['model'].shrink({
......@@ -402,8 +401,8 @@ class CtrPaddleTrainer(trainer.Trainer):
train_begin_time = time.time()
cost_printer = util.CostPrinter(util.print_cost, \
{'master': True, 'log_format': 'load into memory done, cost %s sec',
'stdout': stdout_str})
{'master': True, 'log_format': 'load into memory done, cost %s sec',
'stdout': stdout_str})
current_dataset = {}
for name in self._dataset:
current_dataset[name] = self._dataset[name].load_dataset({
......@@ -437,7 +436,7 @@ class CtrPaddleTrainer(trainer.Trainer):
for executor in self.global_config['executor']:
self.run_executor(executor, current_dataset[executor['dataset_name']], stdout_str)
cost_printer = util.CostPrinter(util.print_cost, \
{'master': True, 'log_format': 'release_memory cost %s sec'})
{'master': True, 'log_format': 'release_memory cost %s sec'})
for name in current_dataset:
current_dataset[name].release_memory()
pure_train_cost = time.time() - pure_train_begin
......
......@@ -28,7 +28,6 @@ from ..utils import envs
class TranspileTrainer(Trainer):
def __init__(self, config=None):
Trainer.__init__(self, config)
self.exe = fluid.Executor(fluid.CPUPlace())
self.processor_register()
self.inference_models = []
......@@ -87,9 +86,9 @@ class TranspileTrainer(Trainer):
dirname = os.path.join(dirname, str(epoch_id))
if is_fleet:
fleet.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe)
fleet.save_inference_model(dirname, feed_varnames, fetch_vars)
else:
fluid.io.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe)
fluid.io.save_inference_model(dirname, feed_varnames, fetch_vars, self._exe)
self.inference_models.append((epoch_id, dirname))
def save_persistables():
......@@ -104,9 +103,9 @@ class TranspileTrainer(Trainer):
dirname = os.path.join(dirname, str(epoch_id))
if is_fleet:
fleet.save_persistables(self.exe, dirname)
fleet.save_persistables(dirname)
else:
fluid.io.save_persistables(self.exe, dirname)
fluid.io.save_persistables(self._exe, dirname)
self.increment_models.append((epoch_id, dirname))
save_persistables()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册