提交 8e14d8f9 编写于 作者: D dongdaxiang

add data_generator package into setup.py

上级 17790188
...@@ -72,7 +72,11 @@ Tensor = LoDTensor ...@@ -72,7 +72,11 @@ Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + \ __all__ = framework.__all__ + executor.__all__ + \
trainer_desc.__all__ + inferencer.__all__ + transpiler.__all__ + \ trainer_desc.__all__ + inferencer.__all__ + transpiler.__all__ + \
parallel_executor.__all__ + lod_tensor.__all__ + \ parallel_executor.__all__ + lod_tensor.__all__ + \
<<<<<<< HEAD
data_feed_desc.__all__ + async_executor.__all__ + compiler.__all__ + [ data_feed_desc.__all__ + async_executor.__all__ + compiler.__all__ + [
=======
data_feed_desc.__all__ + async_executor.__all__ + compiler.__all__ + [
>>>>>>> add data_generator package into setup.py
'io', 'io',
'initializer', 'initializer',
'layers', 'layers',
......
...@@ -654,7 +654,7 @@ class Executor(object): ...@@ -654,7 +654,7 @@ class Executor(object):
trainer._set_thread(thread) trainer._set_thread(thread)
trainer._set_debug(debug) trainer._set_debug(debug)
trainer._set_fetch_var_and_info(fetch_list, fetch_info, print_period) trainer._set_fetch_var_and_info(fetch_list, fetch_info, print_period)
return trainer return scope, trainer
def infer_from_dataset(self, def infer_from_dataset(self,
program=None, program=None,
...@@ -702,7 +702,7 @@ class Executor(object): ...@@ -702,7 +702,7 @@ class Executor(object):
dataset=dataset) dataset=dataset)
""" """
trainer = self._prepare_trainer( scope, trainer = self._prepare_trainer(
program=program, program=program,
dataset=dataset, dataset=dataset,
scope=scope, scope=scope,
...@@ -775,7 +775,7 @@ class Executor(object): ...@@ -775,7 +775,7 @@ class Executor(object):
""" """
trainer = self._prepare_trainer( scope, trainer = self._prepare_trainer(
program=program, program=program,
dataset=dataset, dataset=dataset,
scope=scope, scope=scope,
......
...@@ -75,14 +75,14 @@ class MultiTrainer(TrainerDesc): ...@@ -75,14 +75,14 @@ class MultiTrainer(TrainerDesc):
pass pass
def _set_program(self, program): def _set_program(self, program):
super(MultiTrainer, self).set_program(program) super(MultiTrainer, self)._set_program(program)
self.program_ = program self.program_ = program
def _gen_trainer_desc(self): def _gen_trainer_desc(self):
super(MultiTrainer, self).gen_trainer_desc() super(MultiTrainer, self)._gen_trainer_desc()
self.proto_desc.class_name = "MultiTrainer" self.proto_desc.class_name = "MultiTrainer"
self.device_worker_.set_infer(self.infer_) self.device_worker_._set_infer(self.infer_)
self.device_worker_.gen_worker_desc(self.proto_desc) self.device_worker_._gen_worker_desc(self.proto_desc)
class DistMultiTrainer(TrainerDesc): class DistMultiTrainer(TrainerDesc):
...@@ -91,14 +91,14 @@ class DistMultiTrainer(TrainerDesc): ...@@ -91,14 +91,14 @@ class DistMultiTrainer(TrainerDesc):
pass pass
def _set_program(self, program): def _set_program(self, program):
super(DistMultiTrainer, self).set_program(program) super(DistMultiTrainer, self)._set_program(program)
self.program_ = program self.program_ = program
def _gen_trainer_desc(self): def _gen_trainer_desc(self):
super(DistMultiTrainer, self).gen_trainer_desc() super(DistMultiTrainer, self)._gen_trainer_desc()
self.proto_desc.class_name = "DistMultiTrainer" self.proto_desc.class_name = "DistMultiTrainer"
if self.program_ == None: if self.program_ == None:
print("None program") print("None program")
self.device_worker_.set_infer(self.infer_) self.device_worker_._set_infer(self.infer_)
self.device_worker_.set_program(self.program_) self.device_worker_._set_program(self.program_)
self.device_worker_.gen_worker_desc(self.proto_desc) self.device_worker_._gen_worker_desc(self.proto_desc)
...@@ -29,13 +29,13 @@ class TrainerFactory(object): ...@@ -29,13 +29,13 @@ class TrainerFactory(object):
# default is MultiTrainer + Hogwild # default is MultiTrainer + Hogwild
trainer = MultiTrainer() trainer = MultiTrainer()
device_worker = Hogwild() device_worker = Hogwild()
trainer.set_device_worker(device_worker) trainer._set_device_worker(device_worker)
else: else:
trainer_class = opt_info["trainer"] trainer_class = opt_info["trainer"]
device_worker_class = opt_info["device_worker"] device_worker_class = opt_info["device_worker"]
trainer = globals()[trainer_class]() trainer = globals()[trainer_class]()
device_worker = globals()[device_worker_class]() device_worker = globals()[device_worker_class]()
device_worker.set_fleet_desc(opt_info["fleet_desc"]) device_worker._set_fleet_desc(opt_info["fleet_desc"])
trainer.set_device_worker(device_worker) trainer._set_device_worker(device_worker)
trainer.set_fleet_desc(opt_info["fleet_desc"]) trainer._set_fleet_desc(opt_info["fleet_desc"])
return trainer return trainer
...@@ -122,6 +122,7 @@ packages=['paddle', ...@@ -122,6 +122,7 @@ packages=['paddle',
'paddle.fluid.transpiler', 'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details', 'paddle.fluid.transpiler.details',
'paddle.fluid.incubate', 'paddle.fluid.incubate',
'paddle.fluid.incubate.data_generator',
'paddle.fluid.incubate.fleet', 'paddle.fluid.incubate.fleet',
'paddle.fluid.incubate.fleet.base', 'paddle.fluid.incubate.fleet.base',
'paddle.fluid.incubate.fleet.parameter_server', 'paddle.fluid.incubate.fleet.parameter_server',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册