fl_trainer.py 7.9 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
G
guru4elephant 已提交
15
import logging
Q
qjing666 已提交
16
from paddle_fl.core.scheduler.agent_master import FLWorkerAgent
17
import numpy
18 19
import hmac
from .diffiehellman.diffiehellman import DiffieHellman
G
guru4elephant 已提交
20 21 22 23 24 25 26 27 28 29 30 31 32 33

class FLTrainerFactory(object):
    """Factory that builds the concrete trainer matching a job's strategy."""

    def __init__(self):
        pass

    def create_fl_trainer(self, job):
        """Create and configure a trainer for ``job``.

        Args:
            job: compiled FL job; its ``_strategy`` boolean flags
                (``_fed_avg`` / ``_dpsgd`` / ``_sec_agg``) select the
                trainer implementation.

        Returns:
            FLTrainer: a trainer with ``set_trainer_job(job)`` already
            applied.

        Raises:
            ValueError: if none of the known strategy flags is set
                (previously this fell through and crashed with an
                ``AttributeError`` on ``None.set_trainer_job``).
        """
        strategy = job._strategy
        if strategy._fed_avg:
            trainer = FedAvgTrainer()
        elif strategy._dpsgd:
            trainer = FLTrainer()
        elif strategy._sec_agg:
            trainer = SecAggTrainer()
        else:
            raise ValueError(
                "job strategy enables no known mode "
                "(expected one of fed_avg, dpsgd, sec_agg)")
        # set_trainer_job used to be called twice (once per branch and
        # once more after the if/elif chain); calling it once is enough.
        trainer.set_trainer_job(job)
        return trainer


class FLTrainer(object):
    """Base federated-learning trainer.

    Runs the trainer-side programs of a compiled FL job on CPU and
    coordinates join/finish barriers with the scheduler through an
    ``FLWorkerAgent``.
    """

    def __init__(self):
        self._logger = logging.getLogger("FLTrainer")

    def set_trainer_job(self, job):
        """Cache the programs and metadata of ``job`` on this trainer.

        Args:
            job: compiled FL job providing the startup/main programs,
                the strategy's inner step count, feed/target names and
                the scheduler endpoint.
        """
        self._startup_program = \
            job._trainer_startup_program
        self._main_program = \
            job._trainer_main_program
        self._step = job._strategy._inner_step
        self._feed_names = job._feed_names
        self._target_names = job._target_names
        self._scheduler_ep = job._scheduler_ep
        # The worker's own endpoint is assigned by the launcher before
        # start() is called.
        self._current_ep = None
        self.cur_step = 0

    def start(self):
        """Connect to the scheduler and run the startup program on CPU."""
        self.agent = FLWorkerAgent(self._scheduler_ep, self._current_ep)
        self.agent.connect_scheduler()
        self.exe = fluid.Executor(fluid.CPUPlace())
        self.exe.run(self._startup_program)

    def run(self, feed, fetch):
        """Run one batch of the main program.

        Args:
            feed: feed dict for the main program.
            fetch: list of variables/names to fetch.

        Returns:
            The fetched results of this batch (previously discarded;
            returned now for consistency with ``FedAvgTrainer.run``).
        """
        self._logger.debug("begin to run")
        results = self.exe.run(self._main_program,
                               feed=feed,
                               fetch_list=fetch)
        self._logger.debug("end to run current batch")
        self.cur_step += 1
        return results

    def save_inference_program(self, output_folder):
        """Save an inference version of the main program.

        Args:
            output_folder: directory the inference model is written to.
        """
        target_vars = []
        infer_program = self._main_program.clone(for_test=True)
        for name in self._target_names:
            tmp_var = self._main_program.block(0)._find_var_recursive(name)
            target_vars.append(tmp_var)
        fluid.io.save_inference_model(
            output_folder,
            self._feed_names,
            target_vars,
            self.exe,
            main_program=infer_program)

    def stop(self):
        """Synchronize with the scheduler between training rounds.

        Blocks until all workers finished the previous round (if any
        step has been run) and this worker is permitted to join the
        next round.

        Returns:
            bool: always ``False`` (the caller's loop keeps running).
        """
        # TODO(guru4elephant): add connection with master
        if self.cur_step != 0:
            # Busy-wait until every worker reported completion of the
            # current round.
            while not self.agent.finish_training():
                print('wait others finish')
                continue
        # Busy-wait until the scheduler admits this worker again.
        while not self.agent.can_join_training():
            print("wait permit")
            continue
        print("ready to train")
        return False
G
guru4elephant 已提交
99

100

G
guru4elephant 已提交
101 102 103 104 105 106
class FedAvgTrainer(FLTrainer):
    """Trainer implementing Federated Averaging.

    Every ``_step`` inner steps it pulls the global parameters
    (recv program), trains locally, then pushes its update
    (send program).
    """

    def __init__(self):
        super(FedAvgTrainer, self).__init__()

    def start(self):
        """Connect to the scheduler and run the startup program on CPU.

        Note: the original body mixed tab- and space-indented lines,
        which is a SyntaxError under Python 3; indentation normalized.
        """
        self.agent = FLWorkerAgent(self._scheduler_ep, self._current_ep)
        self.agent.connect_scheduler()
        self.exe = fluid.Executor(fluid.CPUPlace())
        self.exe.run(self._startup_program)

    def set_trainer_job(self, job):
        """Cache job state plus the FedAvg send/recv programs."""
        super(FedAvgTrainer, self).set_trainer_job(job)
        self._send_program = job._trainer_send_program
        self._recv_program = job._trainer_recv_program

    def reset(self):
        """Restart the inner-step counter for a new round."""
        self.cur_step = 0

    def run(self, feed, fetch):
        """Run one batch; sync with the server at round boundaries.

        Args:
            feed: feed dict for the main program.
            fetch: list of variables/names to fetch.

        Returns:
            The fetched results (e.g. the loss) of this batch.
        """
        self._logger.debug(
            "begin to run FedAvgTrainer, cur_step=%d, inner_step=%d" %
            (self.cur_step, self._step))
        if self.cur_step % self._step == 0:
            self._logger.debug("begin to run recv program")
            self.exe.run(self._recv_program)
        self._logger.debug("begin to run current step")
        loss = self.exe.run(self._main_program,
                            feed=feed,
                            fetch_list=fetch)
        if self.cur_step % self._step == 0:
            self._logger.debug("begin to run send program")
            self.exe.run(self._send_program)
        self.cur_step += 1
        return loss

    def stop(self):
        """No-op termination hook; the training loop keeps running."""
        return False
       
 
class SecAggTrainer(FLTrainer):
    """Trainer implementing secure aggregation on top of FedAvg.

    Before sending its update, each trainer masks its parameters with
    pairwise noise derived (via Diffie-Hellman shared secrets and an
    HMAC of the step id) so the server only learns the aggregate; the
    pairwise terms cancel across trainers.
    """

    def __init__(self):
        super(SecAggTrainer, self).__init__()

    @property
    def trainer_id(self):
        # Index of this trainer among all participants.
        return self._trainer_id

    @trainer_id.setter
    def trainer_id(self, s):
        self._trainer_id = s

    @property
    def trainer_num(self):
        # Total number of participating trainers.
        return self._trainer_num

    @trainer_num.setter
    def trainer_num(self, s):
        self._trainer_num = s

    @property
    def key_dir(self):
        # Directory holding "<id>_priv_key.txt" / "<id>_pub_key.txt" files.
        return self._key_dir

    @key_dir.setter
    def key_dir(self, s):
        self._key_dir = s

    @property
    def step_id(self):
        # Round counter mixed into the HMAC so noise differs per round.
        return self._step_id

    @step_id.setter
    def step_id(self, s):
        self._step_id = s

    def start(self):
        """Run the startup program on CPU and reset the step counter."""
        self.exe = fluid.Executor(fluid.CPUPlace())
        self.exe.run(self._startup_program)
        self.cur_step = 0

    def set_trainer_job(self, job):
        """Cache job state plus send/recv programs and masked param names."""
        super(SecAggTrainer, self).set_trainer_job(job)
        self._send_program = job._trainer_send_program
        self._recv_program = job._trainer_recv_program
        # BUG FIX: was ``self_step = ...`` (a discarded local); the
        # intent was to (re)store the inner step count on the instance.
        self._step = job._strategy._inner_step
        self._param_name_list = job._strategy._param_name_list

    def reset(self):
        """Restart the inner-step counter for a new round."""
        self.cur_step = 0

    def run(self, feed, fetch):
        """Run one batch; mask and send parameters at round boundaries.

        Args:
            feed: feed dict for the main program.
            fetch: list of variables/names to fetch.

        Returns:
            The fetched results (e.g. the loss) of this batch.
        """
        self._logger.debug(
            "begin to run SecAggTrainer, cur_step=%d, inner_step=%d" %
            (self.cur_step, self._step))
        if self.cur_step % self._step == 0:
            self._logger.debug("begin to run recv program")
            self.exe.run(self._recv_program)
        scope = fluid.global_scope()
        self._logger.debug("begin to run current step")
        loss = self.exe.run(self._main_program,
                            feed=feed,
                            fetch_list=fetch)
        if self.cur_step % self._step == 0:
            self._logger.debug("begin to run send program")
            noise = 0.0
            # Noise magnitude is scaled down by 1e5 so masked parameters
            # stay in a usable numeric range.
            scale = pow(10.0, 5)
            digestmod = "SHA256"
            # 1. load our private key and every peer's public key,
            #    derive a pairwise shared secret with each peer.
            dh = DiffieHellman(group=15, key_length=256)
            dh.load_private_key(
                self._key_dir + str(self._trainer_id) + "_priv_key.txt")
            key = str(self._step_id).encode("utf-8")
            for i in range(self._trainer_num):
                if i != self._trainer_id:
                    # BUG FIX: the file handle was never closed.
                    with open(self._key_dir + str(i) + "_pub_key.txt",
                              "r") as f:
                        public_key = int(f.read())
                    dh.generate_shared_secret(
                        public_key, echo_return_key=True)
                    msg = dh.shared_key.encode("utf-8")
                    hex_res1 = hmac.new(
                        key=key, msg=msg, digestmod=digestmod).hexdigest()
                    current_noise = int(hex_res1[0:8], 16) / scale
                    # Sign convention makes pairwise terms cancel when the
                    # server sums updates from both trainers of a pair.
                    if i > self._trainer_id:
                        noise = noise + current_noise
                    else:
                        noise = noise - current_noise
            # 2. add the aggregate noise to each parameter before sending.
            for param_name in self._param_name_list:
                masked_name = param_name + str(self._trainer_id)
                scope.var(masked_name).get_tensor().set(
                    numpy.array(scope.find_var(masked_name).get_tensor()) +
                    noise, fluid.CPUPlace())
            self.exe.run(self._send_program)
        self.cur_step += 1
        return loss

    def stop(self):
        """No-op termination hook; the training loop keeps running."""
        return False
236