提交 8e2a592b 编写于 作者: X Xin Pan

fix

test=develop
上级 7526ac14
...@@ -101,6 +101,10 @@ class CompiledProgram(object): ...@@ -101,6 +101,10 @@ class CompiledProgram(object):
self._exec_strategy = exec_strategy self._exec_strategy = exec_strategy
self._loss_name = loss_name self._loss_name = loss_name
self._share_vars_from = share_vars_from self._share_vars_from = share_vars_from
if self._exec_strategy is None:
self._exec_strategy = ExecutionStrategy()
if self._build_strategy is None:
self._build_strategy = BuildStrategy()
return self return self
def _with_distributed(self): def _with_distributed(self):
...@@ -124,12 +128,6 @@ class CompiledProgram(object): ...@@ -124,12 +128,6 @@ class CompiledProgram(object):
else: else:
self._local_scopes = [] self._local_scopes = []
self._places = []
if self._exec_strategy is None:
self._exec_strategy = ExecutionStrategy()
if self._build_strategy is None:
self._build_strategy = BuildStrategy()
self._exec_strategy.use_cuda = isinstance(self._place, core.CUDAPlace) self._exec_strategy.use_cuda = isinstance(self._place, core.CUDAPlace)
if self._exec_strategy.use_cuda: if self._exec_strategy.use_cuda:
gpus_env = os.getenv("FLAGS_selected_gpus") gpus_env = os.getenv("FLAGS_selected_gpus")
...@@ -194,6 +192,7 @@ class CompiledProgram(object): ...@@ -194,6 +192,7 @@ class CompiledProgram(object):
if place and self._place != place: if place and self._place != place:
raise ValueError("Cannot compile with different place") raise ValueError("Cannot compile with different place")
return self return self
self._compiled = True
self._scope = scope self._scope = scope
self._place = place self._place = place
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册