提交 15f79444 编写于 作者: X xjqbest

fix

上级 98d75569
...@@ -48,6 +48,12 @@ class SingleInfer(TranspileTrainer): ...@@ -48,6 +48,12 @@ class SingleInfer(TranspileTrainer):
envs.set_global_envs(self._config) envs.set_global_envs(self._config)
envs.update_workspace() envs.update_workspace()
self._runner_name = envs.get_global_env("mode") self._runner_name = envs.get_global_env("mode")
device = envs.get_global_env("runner." + self._runner_name + ".device")
if device == 'gpu':
self._place = fluid.CUDAPlace(0)
elif device == 'cpu':
self._place = fluid.CPUPlace()
self._exe = fluid.Executor(self._place)
def processor_register(self): def processor_register(self):
self.regist_context_processor('uninit', self.instance) self.regist_context_processor('uninit', self.instance)
...@@ -189,7 +195,8 @@ class SingleInfer(TranspileTrainer): ...@@ -189,7 +195,8 @@ class SingleInfer(TranspileTrainer):
context['status'] = 'train_pass' context['status'] = 'train_pass'
def executor_train(self, context): def executor_train(self, context):
epochs = int(self._env["epochs"]) epochs = int(
envs.get_global_env("runner." + self._runner_name + ".epochs"))
for j in range(epochs): for j in range(epochs):
for model_dict in self._env["phase"]: for model_dict in self._env["phase"]:
if j == 0: if j == 0:
......
...@@ -36,18 +36,18 @@ class SingleTrainer(TranspileTrainer): ...@@ -36,18 +36,18 @@ class SingleTrainer(TranspileTrainer):
def __init__(self, config=None): def __init__(self, config=None):
super(TranspileTrainer, self).__init__(config) super(TranspileTrainer, self).__init__(config)
self._env = self._config self._env = self._config
device = envs.get_global_env("device")
if device == 'gpu':
self._place = fluid.CUDAPlace(0)
elif device == 'cpu':
self._place = fluid.CPUPlace()
self._exe = fluid.Executor(self._place)
self.processor_register() self.processor_register()
self._model = {} self._model = {}
self._dataset = {} self._dataset = {}
envs.set_global_envs(self._config) envs.set_global_envs(self._config)
envs.update_workspace() envs.update_workspace()
self._runner_name = envs.get_global_env("mode") self._runner_name = envs.get_global_env("mode")
device = envs.get_global_env("runner." + self._runner_name + ".device")
if device == 'gpu':
self._place = fluid.CUDAPlace(0)
elif device == 'cpu':
self._place = fluid.CPUPlace()
self._exe = fluid.Executor(self._place)
def processor_register(self): def processor_register(self):
self.regist_context_processor('uninit', self.instance) self.regist_context_processor('uninit', self.instance)
...@@ -192,7 +192,8 @@ class SingleTrainer(TranspileTrainer): ...@@ -192,7 +192,8 @@ class SingleTrainer(TranspileTrainer):
context['status'] = 'train_pass' context['status'] = 'train_pass'
def executor_train(self, context): def executor_train(self, context):
epochs = int(self._env["epochs"]) epochs = int(
envs.get_global_env("runner." + self._runner_name + ".epochs"))
for j in range(epochs): for j in range(epochs):
for model_dict in self._env["phase"]: for model_dict in self._env["phase"]:
if j == 0: if j == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册