提交 f55efcfc 编写于 作者: Y Yang Zhang

Remove scope hack

1. weights are bound to respective models
2. models can coexists in same scope (actual parameter names are uniqued)
3. this make model composition (e.g., transfer learning) harder
上级 fcae408a
...@@ -68,10 +68,6 @@ class StaticGraphAdapter(object): ...@@ -68,10 +68,6 @@ class StaticGraphAdapter(object):
self._startup_prog = fluid.default_startup_program() self._startup_prog = fluid.default_startup_program()
self._main_prog = fluid.default_main_program() self._main_prog = fluid.default_main_program()
# HACK separate models by cleanup global scope
self._scope = fluid.executor.global_scope()
fluid.executor.g_scope = fluid.core.Scope()
self._label_vars = None # label variables self._label_vars = None # label variables
self._endpoints = {} self._endpoints = {}
self._loss_endpoint = None self._loss_endpoint = None
...@@ -118,13 +114,11 @@ class StaticGraphAdapter(object): ...@@ -118,13 +114,11 @@ class StaticGraphAdapter(object):
if prog is None or self.model._optimizer is None: if prog is None or self.model._optimizer is None:
print("optimizer not initialized, save parameters only") print("optimizer not initialized, save parameters only")
prog = self._main_prog prog = self._main_prog
with fluid.executor.scope_guard(self._scope): fluid.save(prog, path)
fluid.save(prog, path)
def load(self, path): def load(self, path):
prog = self._main_prog prog = self._main_prog
with fluid.executor.scope_guard(self._scope): fluid.load(prog, path, self._executor)
fluid.load(prog, path, self._executor)
def _run(self, inputs, labels=None, device='CPU', device_ids=None): def _run(self, inputs, labels=None, device='CPU', device_ids=None):
inputs = to_list(inputs) inputs = to_list(inputs)
...@@ -133,18 +127,17 @@ class StaticGraphAdapter(object): ...@@ -133,18 +127,17 @@ class StaticGraphAdapter(object):
assert len(inputs) == len(self._input_desc), "number of inputs" \ assert len(inputs) == len(self._input_desc), "number of inputs" \
+ " does not match number of arguments of `forward` method" + " does not match number of arguments of `forward` method"
with fluid.executor.scope_guard(self._scope): if self._progs.get(self.mode, None) is None:
if self._progs.get(self.mode, None) is None: self._make_program(self._infer_input_vars(inputs))
self._make_program(self._infer_input_vars(inputs))
ids = [str(i) for i in device_ids] ids = [str(i) for i in device_ids]
ids.sort() ids.sort()
prog_hash = '_'.join([self.mode] + ids) prog_hash = '_'.join([self.mode] + ids)
compiled_prog = self._compiled_progs.get(prog_hash, None) compiled_prog = self._compiled_progs.get(prog_hash, None)
if compiled_prog is None: if compiled_prog is None:
compiled_prog = self._compile_and_initialize( compiled_prog = self._compile_and_initialize(
self._progs[self.mode], device, device_ids) self._progs[self.mode], device, device_ids)
self._compiled_progs[prog_hash] = compiled_prog self._compiled_progs[prog_hash] = compiled_prog
feed = {} feed = {}
input_names = [name for name in self._input_desc.keys()] input_names = [name for name in self._input_desc.keys()]
...@@ -157,7 +150,7 @@ class StaticGraphAdapter(object): ...@@ -157,7 +150,7 @@ class StaticGraphAdapter(object):
feed[v.name] = labels[idx] feed[v.name] = labels[idx]
outputs = self._executor.run( outputs = self._executor.run(
compiled_prog, scope=self._scope, feed=feed, compiled_prog, feed=feed,
fetch_list=self._endpoints[self.mode]) fetch_list=self._endpoints[self.mode])
return outputs return outputs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册