提交 8a8cc243 编写于 作者: J jingqinghe

add trainer.save api

上级 5a6c0869
......@@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.fluid as fluid
import logging
from paddle.fluid.executor import global_scope
import pickle
from paddle.fluid.io import is_belong_to_optimizer
from paddle_fl.paddle_fl.core.scheduler.agent_master import FLWorkerAgent
import numpy
import hmac
......@@ -83,6 +87,43 @@ class FLTrainer(object):
self.exe,
main_program=infer_program)
def save(self, parameter_dir, model_path):
    """Save parameters, optimizer state and the program desc to disk.

    Three files sharing the ``model_path`` prefix are written:

    * ``<model_path>.pdparams`` -- pickled dict mapping each parameter name
      listed in ``para_info`` to its value as a numpy array.
    * ``<model_path>.pdopt`` -- pickled dict mapping each optimizer variable
      of the main program to its value as a numpy array.
    * ``<model_path>.pdmodel`` -- the serialized desc of the main program.

    Args:
        parameter_dir: Directory containing a text file named ``para_info``
            with one parameter name per line.
        model_path: Output prefix in the form ``dirname/filename``. The
            directory part is created when it does not exist yet.

    Raises:
        AssertionError: If ``model_path`` has an empty file-name component.
    """
    base_name = os.path.basename(model_path)
    assert base_name != "", \
        "The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received model_path is empty string."
    dir_name = os.path.dirname(model_path)
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name)

    def get_tensor(var_name):
        # Look the variable up in the global scope and copy it into numpy.
        t = global_scope().find_var(var_name).get_tensor()
        return numpy.array(t)

    parameter_list = []
    with open(os.path.join(parameter_dir, 'para_info'), 'r') as fin:
        for line in fin:
            # Strip the line terminator instead of blindly dropping the last
            # character: the final line may have no trailing newline, and
            # slicing would then truncate the parameter name.
            current_para = line.strip()
            if current_para:  # ignore blank lines
                parameter_list.append(current_para)

    param_dict = {p: get_tensor(p) for p in parameter_list}
    with open(model_path + ".pdparams", 'wb') as f:
        pickle.dump(param_dict, f, protocol=2)

    optimizer_var_list = list(
        filter(is_belong_to_optimizer, self._main_program.list_vars()))
    opt_dict = {p.name: get_tensor(p.name) for p in optimizer_var_list}
    with open(model_path + ".pdopt", 'wb') as f:
        pickle.dump(opt_dict, f, protocol=2)

    # NOTE(review): the version is set on the clone while the desc that gets
    # serialized below belongs to the original program; statement order is
    # kept exactly as written -- confirm this is intended.
    main_program = self._main_program.clone()
    self._main_program.desc.flush()
    main_program.desc._set_version()
    fluid.core.save_op_compatible_info(self._main_program.desc)
    with open(model_path + ".pdmodel", "wb") as f:
        f.write(self._main_program.desc.serialize_to_string())
def stop(self):
# ask for termination with master endpoint
# currently not open sourced, will release the code later
......
......@@ -79,3 +79,7 @@ while not trainer.stop():
train_test_feed=feeder)
print("Test with epoch %d, accuracy: %s" % (epoch_id, acc_val))
# Only the trainer whose id is 0 saves the model; each epoch is written under
# its own "epoch_<epoch_id>" prefix inside output_folder.
# NOTE(review): the % operator is applied to the whole concatenated string, so
# a literal '%' inside output_folder would be treated as a format directive.
if trainer_id == 0:
save_dir = (output_folder + "/epoch_%d") % epoch_id
trainer.save(para_dir, save_dir)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册