diff --git a/python/paddle/fluid/inferencer.py b/python/paddle/fluid/inferencer.py index 276bc03109cfc9d9c2b764f97c0ba6615576cd2e..21277cb493881562ce3565da4bcbc0e97dafc71f 100644 --- a/python/paddle/fluid/inferencer.py +++ b/python/paddle/fluid/inferencer.py @@ -20,9 +20,13 @@ __all__ = [ class Inferencer(object): def __init__(self, network_func, params, place=None): - self.network_func = network_func + # we need to generate a framework.Program by calling + # network_func reference: fluid.program_guard in test_word2vec.py + # move the default_main_program to self.program + # and run the default_startup program self.params = params self.place = place def infer(self, inputs): + # run self.program pass diff --git a/python/paddle/fluid/params.py b/python/paddle/fluid/params.py index fcdb8617a97b7ffe55b2b9682354465dc3eb5b89..8d9d8f213400d66426fee4a29d9b6644906e7ef5 100644 --- a/python/paddle/fluid/params.py +++ b/python/paddle/fluid/params.py @@ -27,7 +27,15 @@ class Params(object): self._load(path) def _load(self, path): + # reference: load_persistables in io.py pass def save(self, path): + # reference: save_persistables in io.py + pass + + def add_params(self, scope): + # take the keys from the scope, + # if not already exists in self.scope, + # add the key and value into self.scope. pass diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index a878ed9d780509a4fbc261c39a45d99ecba8edaf..7d4c2837c101cf3c673cc9490657eedcde3be64f 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -34,10 +34,15 @@ class Event(Enum): class Trainer(object): def __init__(self, network_func, optimizer, params=None, place=None): + # we need to generate a framework.Program by calling + # network_func reference: fluid.program_guard in test_word2vec.py + # move the default_main_program to self.program + # and run the default_startup program on an empty self.network_func = network_func self.optimizer = optimizer self.params = params self.place = place + # TODO(helin): support distributed training def train(self, reader, num_epochs, event_handler): pass