提交 7408da55 编写于 作者: T tangwei

fix ctr trainer

上级 9368aec8
...@@ -12,28 +12,24 @@ ...@@ -12,28 +12,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import sys import sys
import copy
import yaml
import time import time
import json import json
import datetime import datetime
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from .. utils import fs as fs
from .. utils import util as util
from .. metrics.auc_metrics import AUCMetric
from .. models import base as model_basic
from .. reader import dataset
from . import trainer
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
from eleps.utils import fs as fs
from eleps.utils import util as util
from eleps.metrics.auc_metrics import AUCMetric
from eleps.models import base as model_basic
from eleps.reader import dataset
from .trainer import Trainer
def wroker_numric_opt(value, env, opt): def wroker_numric_opt(value, env, opt):
""" """
...@@ -75,21 +71,24 @@ def worker_numric_max(value, env="mpi"): ...@@ -75,21 +71,24 @@ def worker_numric_max(value, env="mpi"):
return wroker_numric_opt(value, env, "max") return wroker_numric_opt(value, env, "max")
class CtrPaddleTrainer(trainer.Trainer): class CtrPaddleTrainer(Trainer):
"""R """R
""" """
def __init__(self, config): def __init__(self, config):
"""R """R
""" """
trainer.Trainer.__init__(self, config) Trainer.__init__(self, config)
config['output_path'] = util.get_absolute_path( config['output_path'] = util.get_absolute_path(
config['output_path'], config['io']['afs']) config['output_path'], config['io']['afs'])
self.global_config = config
self._place = fluid.CPUPlace() self._place = fluid.CPUPlace()
self._exe = fluid.Executor(self._place) self._exe = fluid.Executor(self._place)
self._exector_context = {} self._exector_context = {}
self.global_config = config
self._metrics = {} self._metrics = {}
self._path_generator = util.PathGenerator({ self._path_generator = util.PathGenerator({
'templates': [ 'templates': [
{'name': 'xbox_base_done', 'template': config['output_path'] + '/xbox_base_done.txt'}, {'name': 'xbox_base_done', 'template': config['output_path'] + '/xbox_base_done.txt'},
......
...@@ -28,7 +28,6 @@ from ..utils import envs ...@@ -28,7 +28,6 @@ from ..utils import envs
class TranspileTrainer(Trainer): class TranspileTrainer(Trainer):
def __init__(self, config=None): def __init__(self, config=None):
Trainer.__init__(self, config) Trainer.__init__(self, config)
self.exe = fluid.Executor(fluid.CPUPlace())
self.processor_register() self.processor_register()
self.inference_models = [] self.inference_models = []
...@@ -87,9 +86,9 @@ class TranspileTrainer(Trainer): ...@@ -87,9 +86,9 @@ class TranspileTrainer(Trainer):
dirname = os.path.join(dirname, str(epoch_id)) dirname = os.path.join(dirname, str(epoch_id))
if is_fleet: if is_fleet:
fleet.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe) fleet.save_inference_model(dirname, feed_varnames, fetch_vars)
else: else:
fluid.io.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe) fluid.io.save_inference_model(dirname, feed_varnames, fetch_vars, self._exe)
self.inference_models.append((epoch_id, dirname)) self.inference_models.append((epoch_id, dirname))
def save_persistables(): def save_persistables():
...@@ -104,9 +103,9 @@ class TranspileTrainer(Trainer): ...@@ -104,9 +103,9 @@ class TranspileTrainer(Trainer):
dirname = os.path.join(dirname, str(epoch_id)) dirname = os.path.join(dirname, str(epoch_id))
if is_fleet: if is_fleet:
fleet.save_persistables(self.exe, dirname) fleet.save_persistables(dirname)
else: else:
fluid.io.save_persistables(self.exe, dirname) fluid.io.save_persistables(self._exe, dirname)
self.increment_models.append((epoch_id, dirname)) self.increment_models.append((epoch_id, dirname))
save_persistables() save_persistables()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册