提交 8e14d8f9 编写于 作者: D dongdaxiang

add data_generator package into setup.py

上级 17790188
......@@ -72,7 +72,11 @@ Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + \
trainer_desc.__all__ + inferencer.__all__ + transpiler.__all__ + \
parallel_executor.__all__ + lod_tensor.__all__ + \
<<<<<<< HEAD
data_feed_desc.__all__ + async_executor.__all__ + compiler.__all__ + [
=======
data_feed_desc.__all__ + async_executor.__all__ + compiler.__all__ + [
>>>>>>> add data_generator package into setup.py
'io',
'initializer',
'layers',
......
......@@ -654,7 +654,7 @@ class Executor(object):
trainer._set_thread(thread)
trainer._set_debug(debug)
trainer._set_fetch_var_and_info(fetch_list, fetch_info, print_period)
return trainer
return scope, trainer
def infer_from_dataset(self,
program=None,
......@@ -702,7 +702,7 @@ class Executor(object):
dataset=dataset)
"""
trainer = self._prepare_trainer(
scope, trainer = self._prepare_trainer(
program=program,
dataset=dataset,
scope=scope,
......@@ -775,7 +775,7 @@ class Executor(object):
"""
trainer = self._prepare_trainer(
scope, trainer = self._prepare_trainer(
program=program,
dataset=dataset,
scope=scope,
......
......@@ -75,14 +75,14 @@ class MultiTrainer(TrainerDesc):
pass
def _set_program(self, program):
super(MultiTrainer, self).set_program(program)
super(MultiTrainer, self)._set_program(program)
self.program_ = program
def _gen_trainer_desc(self):
super(MultiTrainer, self).gen_trainer_desc()
super(MultiTrainer, self)._gen_trainer_desc()
self.proto_desc.class_name = "MultiTrainer"
self.device_worker_.set_infer(self.infer_)
self.device_worker_.gen_worker_desc(self.proto_desc)
self.device_worker_._set_infer(self.infer_)
self.device_worker_._gen_worker_desc(self.proto_desc)
class DistMultiTrainer(TrainerDesc):
......@@ -91,14 +91,14 @@ class DistMultiTrainer(TrainerDesc):
pass
def _set_program(self, program):
super(DistMultiTrainer, self).set_program(program)
super(DistMultiTrainer, self)._set_program(program)
self.program_ = program
def _gen_trainer_desc(self):
super(DistMultiTrainer, self).gen_trainer_desc()
super(DistMultiTrainer, self)._gen_trainer_desc()
self.proto_desc.class_name = "DistMultiTrainer"
if self.program_ == None:
print("None program")
self.device_worker_.set_infer(self.infer_)
self.device_worker_.set_program(self.program_)
self.device_worker_.gen_worker_desc(self.proto_desc)
self.device_worker_._set_infer(self.infer_)
self.device_worker_._set_program(self.program_)
self.device_worker_._gen_worker_desc(self.proto_desc)
......@@ -29,13 +29,13 @@ class TrainerFactory(object):
# default is MultiTrainer + Hogwild
trainer = MultiTrainer()
device_worker = Hogwild()
trainer.set_device_worker(device_worker)
trainer._set_device_worker(device_worker)
else:
trainer_class = opt_info["trainer"]
device_worker_class = opt_info["device_worker"]
trainer = globals()[trainer_class]()
device_worker = globals()[device_worker_class]()
device_worker.set_fleet_desc(opt_info["fleet_desc"])
trainer.set_device_worker(device_worker)
trainer.set_fleet_desc(opt_info["fleet_desc"])
device_worker._set_fleet_desc(opt_info["fleet_desc"])
trainer._set_device_worker(device_worker)
trainer._set_fleet_desc(opt_info["fleet_desc"])
return trainer
......@@ -122,6 +122,7 @@ packages=['paddle',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details',
'paddle.fluid.incubate',
'paddle.fluid.incubate.data_generator',
'paddle.fluid.incubate.fleet',
'paddle.fluid.incubate.fleet.base',
'paddle.fluid.incubate.fleet.parameter_server',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册