提交 c653fd8a 编写于 作者: J jingqinghe

support CUDA Place

上级 9544170e
......@@ -57,11 +57,11 @@ class FLTrainer(object):
self._current_ep = None
self.cur_step = 0
def start(self):
def start(self, place):
#current_ep = "to be added"
self.agent = FLWorkerAgent(self._scheduler_ep, self._current_ep)
self.agent.connect_scheduler()
self.exe = fluid.Executor(fluid.CPUPlace())
self.exe = fluid.Executor(place)
self.exe.run(self._startup_program)
def run(self, feed, fetch):
......@@ -103,11 +103,11 @@ class FedAvgTrainer(FLTrainer):
super(FedAvgTrainer, self).__init__()
pass
def start(self):
def start(self, place):
#current_ep = "to be added"
self.agent = FLWorkerAgent(self._scheduler_ep, self._current_ep)
self.agent.connect_scheduler()
self.exe = fluid.Executor(fluid.CPUPlace())
self.exe = fluid.Executor(place)
self.exe.run(self._startup_program)
def set_trainer_job(self, job):
......@@ -185,10 +185,10 @@ class SecAggTrainer(FLTrainer):
def step_id(self, s):
self._step_id = s
def start(self):
def start(self, place):
self.agent = FLWorkerAgent(self._scheduler_ep, self._current_ep)
self.agent.connect_scheduler()
self.exe = fluid.Executor(fluid.CPUPlace())
self.exe = fluid.Executor(place)
self.exe.run(self._startup_program)
self.cur_step = 0
......
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy as np
......@@ -42,7 +42,8 @@ job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()
place = fluid.CPUPlace()
trainer.start(place)
print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model"
epoch_id = 0
......
......@@ -35,7 +35,8 @@ job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091" # Inform scheduler IP address to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()
place = fluid.CPUPlace()
trainer.start(place)
test_program = trainer._main_program.clone(for_test=True)
......
......@@ -37,7 +37,8 @@ job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
print(job._target_names)
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()
place = fluid.CPUPlace()
trainer.start(place)
print(trainer._step)
test_program = trainer._main_program.clone(for_test=True)
......
......@@ -36,7 +36,8 @@ job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.start()
place = fluid.CPUPlace()
trainer.start(place)
r = Gru4rec_Reader()
train_reader = r.reader(train_file_dir, place, batch_size=125)
......
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy as np
......@@ -47,7 +47,8 @@ trainer = FLTrainerFactory().create_fl_trainer(job)
#trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer._current_ep = os.environ['TRAINER0_SERVICE_HOST'] + ":" + os.environ[
'TRAINER0_SERVICE_PORT_TRAINER0']
trainer.start()
place = fluid.CPUPlace()
trainer.start(place)
print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model"
epoch_id = 0
......
......@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy as np
......@@ -47,7 +47,8 @@ trainer = FLTrainerFactory().create_fl_trainer(job)
#trainer._current_ep = "127.0.0.1:{}".format(9000+trainer_id)
trainer._current_ep = os.environ['TRAINER1_SERVICE_HOST'] + ":" + os.environ[
'TRAINER1_SERVICE_PORT_TRAINER1']
trainer.start()
place = fluid.CPUPlace()
trainer.start(place)
print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model"
epoch_id = 0
......
......@@ -53,7 +53,8 @@ trainer.trainer_id = trainer_id
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
trainer.trainer_num = trainer_num
trainer.key_dir = "./keys/"
trainer.start()
place = fluid.CPUPlace()
trainer.start(place)
output_folder = "fl_model"
epoch_id = 0
......
......@@ -99,7 +99,8 @@ else:
job._scheduler_ep = scheduler_conf["ENDPOINT"]
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = endpoint
trainer.start()
place = fluid.CPUPlace()
trainer.start(place)
print(trainer._scheduler_ep, trainer._current_ep)
output_folder = "fl_model"
epoch_id = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册