diff --git a/core/trainers/single_infer.py b/core/trainers/single_infer.py index 4b3d137c1979d7af97e509dd4e75a7c3db9ee6ca..9028e6967aad202e512350ff3e8605383ae95fde 100755 --- a/core/trainers/single_infer.py +++ b/core/trainers/single_infer.py @@ -48,6 +48,12 @@ class SingleInfer(TranspileTrainer): envs.set_global_envs(self._config) envs.update_workspace() self._runner_name = envs.get_global_env("mode") + device = envs.get_global_env("runner." + self._runner_name + ".device") + if device == 'gpu': + self._place = fluid.CUDAPlace(0) + elif device == 'cpu': + self._place = fluid.CPUPlace() + self._exe = fluid.Executor(self._place) def processor_register(self): self.regist_context_processor('uninit', self.instance) @@ -189,7 +195,8 @@ class SingleInfer(TranspileTrainer): context['status'] = 'train_pass' def executor_train(self, context): - epochs = int(self._env["epochs"]) + epochs = int( + envs.get_global_env("runner." + self._runner_name + ".epochs")) for j in range(epochs): for model_dict in self._env["phase"]: if j == 0: diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py index 73d82b1ae0a0231ab3657e285f805c84e49aab91..3bf08196a6401823a8a10f18a43b8973b3d3b896 100755 --- a/core/trainers/single_trainer.py +++ b/core/trainers/single_trainer.py @@ -36,18 +36,18 @@ class SingleTrainer(TranspileTrainer): def __init__(self, config=None): super(TranspileTrainer, self).__init__(config) self._env = self._config - device = envs.get_global_env("device") - if device == 'gpu': - self._place = fluid.CUDAPlace(0) - elif device == 'cpu': - self._place = fluid.CPUPlace() - self._exe = fluid.Executor(self._place) self.processor_register() self._model = {} self._dataset = {} envs.set_global_envs(self._config) envs.update_workspace() self._runner_name = envs.get_global_env("mode") + device = envs.get_global_env("runner." + self._runner_name + ".device") + if device == 'gpu': + self._place = fluid.CUDAPlace(0) + elif device == 'cpu': + self._place = fluid.CPUPlace() + self._exe = fluid.Executor(self._place) def processor_register(self): self.regist_context_processor('uninit', self.instance) @@ -192,7 +192,8 @@ class SingleTrainer(TranspileTrainer): context['status'] = 'train_pass' def executor_train(self, context): - epochs = int(self._env["epochs"]) + epochs = int( + envs.get_global_env("runner." + self._runner_name + ".epochs")) for j in range(epochs): for model_dict in self._env["phase"]: if j == 0: