提交 f55efcfc 编写于 作者: Y Yang Zhang

Remove scope hack

1. weights are bound to respective models
2. models can coexists in same scope (actual parameter names are uniqued)
3. this make model composition (e.g., transfer learning) harder
上级 fcae408a
......@@ -68,10 +68,6 @@ class StaticGraphAdapter(object):
self._startup_prog = fluid.default_startup_program()
self._main_prog = fluid.default_main_program()
# HACK separate models by cleanup global scope
self._scope = fluid.executor.global_scope()
fluid.executor.g_scope = fluid.core.Scope()
self._label_vars = None # label variables
self._endpoints = {}
self._loss_endpoint = None
......@@ -118,13 +114,11 @@ class StaticGraphAdapter(object):
if prog is None or self.model._optimizer is None:
print("optimizer not initialized, save parameters only")
prog = self._main_prog
with fluid.executor.scope_guard(self._scope):
fluid.save(prog, path)
fluid.save(prog, path)
def load(self, path):
prog = self._main_prog
with fluid.executor.scope_guard(self._scope):
fluid.load(prog, path, self._executor)
fluid.load(prog, path, self._executor)
def _run(self, inputs, labels=None, device='CPU', device_ids=None):
inputs = to_list(inputs)
......@@ -133,18 +127,17 @@ class StaticGraphAdapter(object):
assert len(inputs) == len(self._input_desc), "number of inputs" \
+ " does not match number of arguments of `forward` method"
with fluid.executor.scope_guard(self._scope):
if self._progs.get(self.mode, None) is None:
self._make_program(self._infer_input_vars(inputs))
if self._progs.get(self.mode, None) is None:
self._make_program(self._infer_input_vars(inputs))
ids = [str(i) for i in device_ids]
ids.sort()
prog_hash = '_'.join([self.mode] + ids)
compiled_prog = self._compiled_progs.get(prog_hash, None)
if compiled_prog is None:
compiled_prog = self._compile_and_initialize(
self._progs[self.mode], device, device_ids)
self._compiled_progs[prog_hash] = compiled_prog
ids = [str(i) for i in device_ids]
ids.sort()
prog_hash = '_'.join([self.mode] + ids)
compiled_prog = self._compiled_progs.get(prog_hash, None)
if compiled_prog is None:
compiled_prog = self._compile_and_initialize(
self._progs[self.mode], device, device_ids)
self._compiled_progs[prog_hash] = compiled_prog
feed = {}
input_names = [name for name in self._input_desc.keys()]
......@@ -157,7 +150,7 @@ class StaticGraphAdapter(object):
feed[v.name] = labels[idx]
outputs = self._executor.run(
compiled_prog, scope=self._scope, feed=feed,
compiled_prog, feed=feed,
fetch_list=self._endpoints[self.mode])
return outputs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册