Commit 1faf669a authored by Yang Zhang

Refactor `_compile_and_initialize` a bit

Parent 67df9550
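
What the diff below appears to do, in short: the per-mode cache lookup in `self._compiled_progs` moves from the call site into `_compile_and_initialize` itself, device place selection collapses into a single list comprehension over `device_ids`, the startup program is still run only once (incrementally, for uninitialized variables), and data-parallel compilation is skipped when fewer than two devices are requested. A rough sketch of the new control flow, with the initialization and compilation steps folded into hypothetical helpers (`_ensure_initialized` and `_data_parallel_compile` are illustrative names, not functions from this codebase):

```python
import paddle.fluid as fluid

# Sketch only -- mirrors the shape of the refactored method, not its exact code.
def _compile_and_initialize(self, prog, device='CPU', device_ids=None):
    # the helper now owns the per-mode cache instead of the caller
    compiled_prog = self._compiled_progs.get(self.mode, None)
    if compiled_prog is not None:
        return compiled_prog

    places = [fluid.CUDAPlace(i) if device.lower() == 'gpu' else fluid.CPUPlace()
              for i in device_ids]

    self._ensure_initialized(places[0])  # hypothetical: run startup program once
    if len(device_ids) < 2:
        return prog                      # single device: use the plain program

    compiled_prog = self._data_parallel_compile(prog, places)  # hypothetical
    self._compiled_progs[self.mode] = compiled_prog
    return compiled_prog
```

With the single-device early return, only multi-device runs pay for `CompiledProgram` plus `with_data_parallel`, and only those results end up cached in `self._compiled_progs`.
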
@@ -245,13 +245,8 @@ class StaticGraphAdapter(object):
        if self._progs.get(self.mode, None) is None:
            self._make_program(self._infer_input_vars(inputs))

        ids = [str(i) for i in device_ids]
        ids.sort()
        compiled_prog = self._compiled_progs.get(self.mode, None)
        if compiled_prog is None:
            compiled_prog = self._compile_and_initialize(
                self._progs[self.mode], device, device_ids)
            self._compiled_progs[self.mode] = compiled_prog
        compiled_prog = self._compile_and_initialize(
            self._progs[self.mode], device, device_ids)

        feed = {}
        input_names = [name for name in self._input_desc.keys()]
@@ -320,36 +315,19 @@ class StaticGraphAdapter(object):
        return label_vars

    def _compile_and_initialize(self, prog, device='CPU', device_ids=None):
        if device.lower() == 'cpu':
            place = fluid.CPUPlace()
        elif device.lower() == 'gpu' and isinstance(device_ids, (list, tuple)):
            place = fluid.CUDAPlace(device_ids[0])
        else:
            raise "device not supported"
        compiled_prog = self._compiled_progs.get(self.mode, None)
        if compiled_prog is not None:
            return compiled_prog

        compiled_prog = fluid.CompiledProgram(prog)
        if device.lower() == 'gpu' and len(device_ids) > 0:
            places = [fluid.CUDAPlace(i) for i in device_ids]
            loss_name = None
            if self._loss_endpoint is not None:
                loss_name = self._loss_endpoint.name
            share_vars_from = None
            if self.mode == 'eval' and 'train' in self._compiled_progs:
                share_vars_from = self._compiled_progs['train']
            # HACK invalidate eval program if is compiled before train program
            # quite hackish, OTOH, it is generally uncommon that the eval
            # program will be run before the train program
            if self.mode == 'train' and 'eval' in self._compiled_progs:
                del self._compiled_progs['eval']
            compiled_prog = compiled_prog.with_data_parallel(
                loss_name=loss_name, places=places,
                share_vars_from=share_vars_from)
        places = [device.lower() == 'gpu' and fluid.CUDAPlace(i)
                  or fluid.CPUPlace() for i in device_ids]

        # XXX only run startup once as *ALL WEIGHTS* should have been
        # initialized upon construction of the model even if `forward()`
        # may run different code path for different mode
        if self._executor is None:
            self._executor = fluid.Executor(place)
            # XXX only run startup once as *ALL* weights should be initialized
            # upon construction of the model
            # XXX incremental initialization, lifted from GuoSheng code
            self._executor = fluid.Executor(places[0])
            # XXX incremental initialization
            uninitialized = []
            for var_py in self._startup_prog.list_vars():
                var = fluid.global_scope().find_var(var_py.name)
@@ -360,6 +338,26 @@ class StaticGraphAdapter(object):
                startup_prog = self._startup_prog._prune(uninitialized)
                self._executor.run(startup_prog)

        if len(device_ids) < 2:
            return prog

        share_vars_from = None
        if self.mode == 'eval' and 'train' in self._compiled_progs:
            share_vars_from = self._compiled_progs['train']
        # HACK invalidate eval program if is compiled before train program
        # quite hackish, OTOH, it is generally uncommon that the eval
        # program will be run before the train program
        if self.mode == 'train' and 'eval' in self._compiled_progs:
            del self._compiled_progs['eval']

        compiled_prog = fluid.CompiledProgram(prog)
        loss_name = None
        if self._loss_endpoint is not None:
            loss_name = self._loss_endpoint.name
        compiled_prog = compiled_prog.with_data_parallel(
            loss_name=loss_name, places=places,
            share_vars_from=share_vars_from)
        self._compiled_progs[self.mode] = compiled_prog
        return compiled_prog
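
For context on the `# XXX incremental initialization` comment: the executor is created once, and the startup program is pruned to just the variables that are not yet initialized in the global scope before being run, so repeated calls do not clobber weights that already exist. A hedged sketch of that idea, assuming a tensor-level initialization check such as `_is_initialized()` (the exact condition sits in the collapsed part of the diff):

```python
import paddle.fluid as fluid

def run_startup_incrementally(executor, startup_prog):
    # collect startup variables that have no initialized tensor in the scope yet
    uninitialized = []
    for var_py in startup_prog.list_vars():
        var = fluid.global_scope().find_var(var_py.name)
        if var is not None and var.get_tensor()._is_initialized():
            continue  # already initialized -- leave it alone
        uninitialized.append(var_py)
    if uninitialized:
        # prune the startup program down to just those variables and run it
        pruned = startup_prog._prune(uninitialized)
        executor.run(pruned)
```

In the method itself this runs under `if self._executor is None:`, so the cost is paid at most once per adapter regardless of how many modes later request a compiled program.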