提交 3d693567 编写于 作者: J jingqinghe

update code

上级 4ddc72bf
......@@ -124,6 +124,21 @@ class FLTrainer(object):
with open(model_path + ".pdmodel", "wb") as f:
f.write(self._main_program.desc.serialize_to_string())
def save_serving_model(self, model_path, client_conf_path):
feed_vars = {}
target_vars = {}
for target in self._target_names:
tmp_target = self._main_program.block(0)._find_var_recursive(
target)
target_vars[target] = tmp_target
for feed in self._feed_names:
tmp_feed = self._main_program.block(0)._find_var_recursive(feed)
feed_vars[feed] = tmp_feed
serving_io.save_model(model_path, client_conf_path, feed_vars,
target_vars, self._main_program)
def stop(self):
# ask for termination with master endpoint
# currently not open sourced, will release the code later
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle_serving_client import Client
client = Client()
client.load_client_config("imdb_client_conf/serving_client_conf.prototxt")
client.connect(["127.0.0.1:9292"])
data_dict = {}
for i in range(3):
data_dict[str(i)] = np.random.rand(1, 5).astype('float32')
fetch_map = client.predict(
feed={"0": data_dict['0'],
"1": data_dict['1'],
"2": data_dict['2']},
fetch=["fc_2.tmp_2"])
print("fetched result: ", fetch_map)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle_fl.paddle_fl as fl
from paddle_fl.paddle_fl.core.master.job_generator import JobGenerator
from paddle_fl.paddle_fl.core.strategy.fl_strategy_base import FLStrategyFactory
class Model(object):
def __init__(self):
pass
def mlp(self, inputs, label, hidden_size=128):
self.concat = fluid.layers.concat(inputs, axis=1)
self.fc1 = fluid.layers.fc(input=self.concat, size=256, act='relu')
self.fc2 = fluid.layers.fc(input=self.fc1, size=128, act='relu')
self.predict = fluid.layers.fc(input=self.fc2, size=2, act='softmax')
self.sum_cost = fluid.layers.cross_entropy(
input=self.predict, label=label)
self.accuracy = fluid.layers.accuracy(input=self.predict, label=label)
self.loss = fluid.layers.reduce_mean(self.sum_cost)
self.startup_program = fluid.default_startup_program()
inputs = [fluid.layers.data( \
name=str(slot_id), shape=[5],
dtype="float32")
for slot_id in range(3)]
label = fluid.layers.data( \
name="label",
shape=[1],
dtype='int64')
model = Model()
model.mlp(inputs, label)
job_generator = JobGenerator()
optimizer = fluid.optimizer.SGD(learning_rate=0.1)
job_generator.set_optimizer(optimizer)
job_generator.set_losses([model.loss])
job_generator.set_startup_program(model.startup_program)
job_generator.set_infer_feed_and_target_names([x.name for x in inputs],
[model.predict.name])
build_strategy = FLStrategyFactory()
build_strategy.fed_avg = True
build_strategy.inner_step = 10
strategy = build_strategy.create_fl_strategy()
# endpoints will be collected through the cluster
# in this example, we suppose endpoints have been collected
endpoints = ["127.0.0.1:8181"]
output = "fl_job_config"
job_generator.generate_fl_job(
strategy, server_endpoints=endpoints, worker_num=2, output=output)
# fl_job_config will be dispatched to workers
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle_fl.paddle_fl.core.scheduler.agent_master import FLScheduler
worker_num = 2
server_num = 1
# Define the number of worker/server and the port for scheduler
scheduler = FLScheduler(worker_num, server_num, port=9091)
scheduler.set_sample_worker_num(worker_num)
scheduler.init_env()
print("init env done.")
scheduler.start_fl_training()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle_fl.paddle_fl as fl
import paddle.fluid as fluid
from paddle_fl.paddle_fl.core.server.fl_server import FLServer
from paddle_fl.paddle_fl.core.master.fl_job import FLRunTimeJob
server = FLServer()
server_id = 0
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_server_job(job_path, server_id)
job._scheduler_ep = "127.0.0.1:9091" # IP address for scheduler
server.set_server_job(job)
server._current_ep = "127.0.0.1:8181" # IP address for server
server.start()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle_fl.paddle_fl.core.trainer.fl_trainer import FLTrainerFactory
from paddle_fl.paddle_fl.core.master.fl_job import FLRunTimeJob
import numpy as np
import paddle_serving_client.io as serving_io
import sys
import logging
import time
logging.basicConfig(
filename="test.log",
filemode="w",
format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
datefmt="%d-%M-%Y %H:%M:%S",
level=logging.DEBUG)
def reader():
for i in range(1000):
data_dict = {}
for i in range(3):
data_dict[str(i)] = np.random.rand(1, 5).astype('float32')
data_dict["label"] = np.random.randint(2, size=(1, 1)).astype('int64')
yield data_dict
trainer_id = int(sys.argv[1]) # trainer id for each guest
job_path = "fl_job_config"
job = FLRunTimeJob()
job.load_trainer_job(job_path, trainer_id)
job._scheduler_ep = "127.0.0.1:9091" # Inform the scheduler IP to trainer
trainer = FLTrainerFactory().create_fl_trainer(job)
trainer._current_ep = "127.0.0.1:{}".format(9000 + trainer_id)
place = fluid.CPUPlace()
trainer.start(place)
print("scheduler_ep is {}, current_ep is {}".format(trainer._scheduler_ep,
trainer._current_ep))
"""
feed_vars = {}
target_vars = {}
for target in trainer._target_names:
tmp_target = trainer._main_program.block(0)._find_var_recursive(target)
target_vars[target] = tmp_target
for feed in trainer._feed_names:
tmp_feed = trainer._main_program.block(0)._find_var_recursive(feed)
feed_vars[feed] = tmp_feed
"""
epoch_id = 0
while not trainer.stop():
if epoch_id > 10:
break
print("{} epoch {} start train".format(
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())),
epoch_id))
train_step = 0
for data in reader():
trainer.run(feed=data, fetch=[])
train_step += 1
if train_step == trainer._step:
break
epoch_id += 1
if epoch_id % 5 == 0:
# trainer.save_inference_program(output_folder)
trainer.save_serving_model("test", "imdb_client_conf")
# serving_io.save_model("test","imdb_client_conf", feed_vars, target_vars, trainer._main_program)
unset http_proxy
unset https_proxy
ps -ef | grep -E fl_ | grep -v grep | awk '{print $2}' | xargs kill -9
log_dir=${1:-$(pwd)}
mkdir -p ${log_dir}
python fl_master.py > ${log_dir}/master.log &
sleep 2
python -u fl_scheduler.py > ${log_dir}/scheduler.log &
sleep 5
python -u fl_server.py > ${log_dir}/server0.log &
sleep 2
for ((i=0;i<2;i++))
do
python -u fl_trainer.py $i > ${log_dir}/trainer$i.log &
sleep 2
done
model_dir=$1
python -m paddle_serving_server.serve --model $model_dir --thread 10 --port 9292 &
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册