From 15f79444be4a54289a46a0e923f67f221cafd9bc Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Thu, 28 May 2020 19:24:47 +0800 Subject: [PATCH] fix --- core/trainers/single_infer.py | 9 ++++++++- core/trainers/single_trainer.py | 15 ++++++++------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/core/trainers/single_infer.py b/core/trainers/single_infer.py index 4b3d137c..9028e696 100755 --- a/core/trainers/single_infer.py +++ b/core/trainers/single_infer.py @@ -48,6 +48,12 @@ class SingleInfer(TranspileTrainer): envs.set_global_envs(self._config) envs.update_workspace() self._runner_name = envs.get_global_env("mode") + device = envs.get_global_env("runner." + self._runner_name + ".device") + if device == 'gpu': + self._place = fluid.CUDAPlace(0) + elif device == 'cpu': + self._place = fluid.CPUPlace() + self._exe = fluid.Executor(self._place) def processor_register(self): self.regist_context_processor('uninit', self.instance) @@ -189,7 +195,8 @@ class SingleInfer(TranspileTrainer): context['status'] = 'train_pass' def executor_train(self, context): - epochs = int(self._env["epochs"]) + epochs = int( + envs.get_global_env("runner." + self._runner_name + ".epochs")) for j in range(epochs): for model_dict in self._env["phase"]: if j == 0: diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py index 73d82b1a..3bf08196 100755 --- a/core/trainers/single_trainer.py +++ b/core/trainers/single_trainer.py @@ -36,18 +36,18 @@ class SingleTrainer(TranspileTrainer): def __init__(self, config=None): super(TranspileTrainer, self).__init__(config) self._env = self._config - device = envs.get_global_env("device") - if device == 'gpu': - self._place = fluid.CUDAPlace(0) - elif device == 'cpu': - self._place = fluid.CPUPlace() - self._exe = fluid.Executor(self._place) self.processor_register() self._model = {} self._dataset = {} envs.set_global_envs(self._config) envs.update_workspace() self._runner_name = envs.get_global_env("mode") + device = envs.get_global_env("runner." + self._runner_name + ".device") + if device == 'gpu': + self._place = fluid.CUDAPlace(0) + elif device == 'cpu': + self._place = fluid.CPUPlace() + self._exe = fluid.Executor(self._place) def processor_register(self): self.regist_context_processor('uninit', self.instance) @@ -192,7 +192,8 @@ class SingleTrainer(TranspileTrainer): context['status'] = 'train_pass' def executor_train(self, context): - epochs = int(self._env["epochs"]) + epochs = int( + envs.get_global_env("runner." + self._runner_name + ".epochs")) for j in range(epochs): for model_dict in self._env["phase"]: if j == 0: -- GitLab