提交 7408da55 作者: tangwei

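Fix CTR trainer: import from the absolute `eleps` package instead of relative paths, and have TranspileTrainer use `self._exe` rather than creating its own `self.exe` executor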

上级 9368aec8
...@@ -12,28 +12,24 @@ ...@@ -12,28 +12,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import sys import sys
import copy
import yaml
import time import time
import json import json
import datetime import datetime
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from .. utils import fs as fs
from .. utils import util as util
from .. metrics.auc_metrics import AUCMetric
from .. models import base as model_basic
from .. reader import dataset
from . import trainer
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
from eleps.utils import fs as fs
from eleps.utils import util as util
from eleps.metrics.auc_metrics import AUCMetric
from eleps.models import base as model_basic
from eleps.reader import dataset
from .trainer import Trainer
def wroker_numric_opt(value, env, opt): def wroker_numric_opt(value, env, opt):
""" """
...@@ -75,21 +71,24 @@ def worker_numric_max(value, env="mpi"): ...@@ -75,21 +71,24 @@ def worker_numric_max(value, env="mpi"):
return wroker_numric_opt(value, env, "max") return wroker_numric_opt(value, env, "max")
class CtrPaddleTrainer(trainer.Trainer): class CtrPaddleTrainer(Trainer):
"""R """R
""" """
def __init__(self, config): def __init__(self, config):
"""R """R
""" """
trainer.Trainer.__init__(self, config) Trainer.__init__(self, config)
config['output_path'] = util.get_absolute_path( config['output_path'] = util.get_absolute_path(
config['output_path'], config['io']['afs']) config['output_path'], config['io']['afs'])
self.global_config = config
self._place = fluid.CPUPlace() self._place = fluid.CPUPlace()
self._exe = fluid.Executor(self._place) self._exe = fluid.Executor(self._place)
self._exector_context = {} self._exector_context = {}
self.global_config = config
self._metrics = {} self._metrics = {}
self._path_generator = util.PathGenerator({ self._path_generator = util.PathGenerator({
'templates': [ 'templates': [
{'name': 'xbox_base_done', 'template': config['output_path'] + '/xbox_base_done.txt'}, {'name': 'xbox_base_done', 'template': config['output_path'] + '/xbox_base_done.txt'},
...@@ -202,7 +201,7 @@ class CtrPaddleTrainer(trainer.Trainer): ...@@ -202,7 +201,7 @@ class CtrPaddleTrainer(trainer.Trainer):
"""R """R
""" """
cost_printer = util.CostPrinter(util.print_cost, cost_printer = util.CostPrinter(util.print_cost,
{'master': True, 'log_format': 'save model cost %s sec'}) {'master': True, 'log_format': 'save model cost %s sec'})
model_path = self._path_generator.generate_path('batch_model', {'day': day, 'pass_id': pass_index}) model_path = self._path_generator.generate_path('batch_model', {'day': day, 'pass_id': pass_index})
save_mode = 0 # just save all save_mode = 0 # just save all
if pass_index < 1: # batch_model if pass_index < 1: # batch_model
...@@ -224,8 +223,8 @@ class CtrPaddleTrainer(trainer.Trainer): ...@@ -224,8 +223,8 @@ class CtrPaddleTrainer(trainer.Trainer):
model_path = "" model_path = ""
xbox_model_donefile = "" xbox_model_donefile = ""
cost_printer = util.CostPrinter(util.print_cost, {'master': True, \ cost_printer = util.CostPrinter(util.print_cost, {'master': True, \
'log_format': 'save xbox model cost %s sec', 'log_format': 'save xbox model cost %s sec',
'stdout': stdout_str}) 'stdout': stdout_str})
if pass_index < 1: if pass_index < 1:
save_mode = 2 save_mode = 2
xbox_patch_id = xbox_base_key xbox_patch_id = xbox_base_key
...@@ -239,8 +238,8 @@ class CtrPaddleTrainer(trainer.Trainer): ...@@ -239,8 +238,8 @@ class CtrPaddleTrainer(trainer.Trainer):
cost_printer.done() cost_printer.done()
cost_printer = util.CostPrinter(util.print_cost, {'master': True, cost_printer = util.CostPrinter(util.print_cost, {'master': True,
'log_format': 'save cache model cost %s sec', 'log_format': 'save cache model cost %s sec',
'stdout': stdout_str}) 'stdout': stdout_str})
model_file_handler = fs.FileHandler(self.global_config['io']['afs']) model_file_handler = fs.FileHandler(self.global_config['io']['afs'])
if self.global_config['save_cache_model']: if self.global_config['save_cache_model']:
cache_save_num = fleet.save_cache_model(None, model_path, mode=save_mode) cache_save_num = fleet.save_cache_model(None, model_path, mode=save_mode)
...@@ -255,8 +254,8 @@ class CtrPaddleTrainer(trainer.Trainer): ...@@ -255,8 +254,8 @@ class CtrPaddleTrainer(trainer.Trainer):
'save_combine': True 'save_combine': True
} }
cost_printer = util.CostPrinter(util.print_cost, {'master': True, cost_printer = util.CostPrinter(util.print_cost, {'master': True,
'log_format': 'save dense model cost %s sec', 'log_format': 'save dense model cost %s sec',
'stdout': stdout_str}) 'stdout': stdout_str})
if fleet._role_maker.is_first_worker(): if fleet._role_maker.is_first_worker():
for executor in self.global_config['executor']: for executor in self.global_config['executor']:
if 'layer_for_inference' not in executor: if 'layer_for_inference' not in executor:
...@@ -333,8 +332,8 @@ class CtrPaddleTrainer(trainer.Trainer): ...@@ -333,8 +332,8 @@ class CtrPaddleTrainer(trainer.Trainer):
self._train_pass = util.TimeTrainPass(self.global_config) self._train_pass = util.TimeTrainPass(self.global_config)
if not self.global_config['cold_start']: if not self.global_config['cold_start']:
cost_printer = util.CostPrinter(util.print_cost, cost_printer = util.CostPrinter(util.print_cost,
{'master': True, 'log_format': 'load model cost %s sec', {'master': True, 'log_format': 'load model cost %s sec',
'stdout': stdout_str}) 'stdout': stdout_str})
self.print_log("going to load model %s" % self._train_pass._checkpoint_model_path, {'master': True}) self.print_log("going to load model %s" % self._train_pass._checkpoint_model_path, {'master': True})
# if config.need_reqi_changeslot and config.reqi_dnn_plugin_day >= self._train_pass.date() # if config.need_reqi_changeslot and config.reqi_dnn_plugin_day >= self._train_pass.date()
# and config.reqi_dnn_plugin_pass >= self._pass_id: # and config.reqi_dnn_plugin_pass >= self._pass_id:
...@@ -373,7 +372,7 @@ class CtrPaddleTrainer(trainer.Trainer): ...@@ -373,7 +372,7 @@ class CtrPaddleTrainer(trainer.Trainer):
util.rank0_print("shrink table") util.rank0_print("shrink table")
cost_printer = util.CostPrinter(util.print_cost, cost_printer = util.CostPrinter(util.print_cost,
{'master': True, 'log_format': 'shrink table done, cost %s sec'}) {'master': True, 'log_format': 'shrink table done, cost %s sec'})
fleet.shrink_sparse_table() fleet.shrink_sparse_table()
for executor in self._exector_context: for executor in self._exector_context:
self._exector_context[executor]['model'].shrink({ self._exector_context[executor]['model'].shrink({
...@@ -402,8 +401,8 @@ class CtrPaddleTrainer(trainer.Trainer): ...@@ -402,8 +401,8 @@ class CtrPaddleTrainer(trainer.Trainer):
train_begin_time = time.time() train_begin_time = time.time()
cost_printer = util.CostPrinter(util.print_cost, \ cost_printer = util.CostPrinter(util.print_cost, \
{'master': True, 'log_format': 'load into memory done, cost %s sec', {'master': True, 'log_format': 'load into memory done, cost %s sec',
'stdout': stdout_str}) 'stdout': stdout_str})
current_dataset = {} current_dataset = {}
for name in self._dataset: for name in self._dataset:
current_dataset[name] = self._dataset[name].load_dataset({ current_dataset[name] = self._dataset[name].load_dataset({
...@@ -437,7 +436,7 @@ class CtrPaddleTrainer(trainer.Trainer): ...@@ -437,7 +436,7 @@ class CtrPaddleTrainer(trainer.Trainer):
for executor in self.global_config['executor']: for executor in self.global_config['executor']:
self.run_executor(executor, current_dataset[executor['dataset_name']], stdout_str) self.run_executor(executor, current_dataset[executor['dataset_name']], stdout_str)
cost_printer = util.CostPrinter(util.print_cost, \ cost_printer = util.CostPrinter(util.print_cost, \
{'master': True, 'log_format': 'release_memory cost %s sec'}) {'master': True, 'log_format': 'release_memory cost %s sec'})
for name in current_dataset: for name in current_dataset:
current_dataset[name].release_memory() current_dataset[name].release_memory()
pure_train_cost = time.time() - pure_train_begin pure_train_cost = time.time() - pure_train_begin
......
...@@ -28,7 +28,6 @@ from ..utils import envs ...@@ -28,7 +28,6 @@ from ..utils import envs
class TranspileTrainer(Trainer): class TranspileTrainer(Trainer):
def __init__(self, config=None): def __init__(self, config=None):
Trainer.__init__(self, config) Trainer.__init__(self, config)
self.exe = fluid.Executor(fluid.CPUPlace())
self.processor_register() self.processor_register()
self.inference_models = [] self.inference_models = []
...@@ -87,9 +86,9 @@ class TranspileTrainer(Trainer): ...@@ -87,9 +86,9 @@ class TranspileTrainer(Trainer):
dirname = os.path.join(dirname, str(epoch_id)) dirname = os.path.join(dirname, str(epoch_id))
if is_fleet: if is_fleet:
fleet.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe) fleet.save_inference_model(dirname, feed_varnames, fetch_vars)
else: else:
fluid.io.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe) fluid.io.save_inference_model(dirname, feed_varnames, fetch_vars, self._exe)
self.inference_models.append((epoch_id, dirname)) self.inference_models.append((epoch_id, dirname))
def save_persistables(): def save_persistables():
...@@ -104,9 +103,9 @@ class TranspileTrainer(Trainer): ...@@ -104,9 +103,9 @@ class TranspileTrainer(Trainer):
dirname = os.path.join(dirname, str(epoch_id)) dirname = os.path.join(dirname, str(epoch_id))
if is_fleet: if is_fleet:
fleet.save_persistables(self.exe, dirname) fleet.save_persistables(dirname)
else: else:
fluid.io.save_persistables(self.exe, dirname) fluid.io.save_persistables(self._exe, dirname)
self.increment_models.append((epoch_id, dirname)) self.increment_models.append((epoch_id, dirname))
save_persistables() save_persistables()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册