# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import errno
import logging
import os
import pickle
import re

import numpy as np

import paddle
from paddle.framework import core

from ..utils.log_utils import get_logger
from .process_group import _g_process_group_map
from .utils import get_dist_attr


def check_filename(re_exp, filename):
    """Return True if regex ``re_exp`` matches anywhere in ``filename``.

    Uses ``re.search`` semantics, so the pattern must carry its own
    ``^``/``$`` anchors if a whole-name match is required.
    """
    return re.search(re_exp, filename) is not None


def _process_path(path):
    """Split ``path`` into ``(dirname, filename)`` and ensure the directory exists.

    ``dirname`` may be the empty string when ``path`` has no directory part,
    in which case files are written relative to the current directory.

    Raises:
        ValueError: if ``path`` has no filename component.
        OSError: if the directory cannot be created for any reason other
            than it already existing.
    """
    filename = os.path.basename(path)
    if filename == "":
        raise ValueError(
            "path should be of 'dirname/filename' format, but received filename is empty string"
        )
    dirname = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually names one.
    if dirname:
        try:
            os.makedirs(dirname)
        except OSError as e:
            # An already-existing directory is fine; every rank writes its
            # own files into the same directory.
            if e.errno != errno.EEXIST:
                raise
    return dirname, filename


class DistributedSaver:
    """Save/load per-rank distributed programs, attributes and parameters.

    Files produced by :meth:`save` for a ``path`` of ``dir/name``:

    * ``name_serial.pdmodel``          -- serial program (rank 0 only)
    * ``rank_mapping.csv``             -- process-group mapping (rank 0 only)
    * ``name_dist<rank>.pdmodel``      -- this rank's distributed program
    * ``name_dist<rank>.pdattr``       -- pickled distributed attributes
    * ``name_dist<rank>.pdparams``     -- pickled parameter state
    * ``name_dist<rank>.pdopt``        -- pickled optimizer state
    """

    def __init__(self):
        self._logger = get_logger(logging.INFO)

    def save(self, path, serial_program, dist_main_program, dist_context):
        """Save programs, distributed attributes and state for this rank.

        Args:
            path: ``dirname/filename`` prefix for all output files; the
                directory is created if missing.
            serial_program: program whose desc is written as
                ``<filename>_serial.pdmodel`` (rank 0 only).
            dist_main_program: this rank's distributed program; its desc,
                parameter state and optimizer state are written.
            dist_context: passed to ``get_dist_attr`` together with
                ``dist_main_program`` to obtain the attributes to pickle.
        """

        def _save_state(program, path, mode="param"):
            # Convert every state tensor to a numpy array before pickling.
            state = {
                k: np.array(v) for k, v in program.state_dict(mode).items()
            }
            with open(path, "wb") as f:
                pickle.dump(state, f)

        dirname, filename = _process_path(path)
        rank_id = paddle.distributed.get_rank()

        # save serial program when rank id is 0
        if rank_id == 0:
            self._save_rank_mapping(dirname)
            serial_model_filename = filename + "_serial.pdmodel"
            serial_model_path = os.path.join(dirname, serial_model_filename)
            with open(serial_model_path, "wb") as f:
                f.write(serial_program.desc.serialize_to_string())

        # Every per-rank artifact shares this "<filename>_dist<rank>" stem.
        dist_stem = filename + "_dist" + str(rank_id)

        # save distributed main program
        dist_model_path = os.path.join(dirname, dist_stem + ".pdmodel")
        with open(dist_model_path, "wb") as f:
            f.write(dist_main_program.desc.serialize_to_string())

        # save distributed attribute
        dist_attr_path = os.path.join(dirname, dist_stem + ".pdattr")
        dist_attrs = get_dist_attr(dist_main_program, dist_context)
        with open(dist_attr_path, "wb") as f:
            pickle.dump(dist_attrs, f)

        # save distributed params
        dist_param_path = os.path.join(dirname, dist_stem + ".pdparams")
        _save_state(dist_main_program, dist_param_path)

        # save distributed opt states
        dist_opt_path = os.path.join(dirname, dist_stem + ".pdopt")
        _save_state(dist_main_program, dist_opt_path, "opt")

        # TODO:save cluster.json

    def load(self, path, load_optimizer=True):
        """Load state and distributed attributes written by :meth:`save`.

        Args:
            path: the same ``dirname/filename`` prefix given to :meth:`save`.
            load_optimizer: when False, ``*.pdopt`` files are not read.

        Returns:
            ``(state_dict, dist_attr)``: ``state_dict`` maps each name to a
            list of numpy arrays, one per matching file, in ascending rank
            order; ``dist_attr`` maps each name to the attribute from the
            first (lowest-rank) ``*.pdattr`` file containing it.

        Raises:
            ValueError: if ``path`` has no filename component.
        """
        # TODO: if `program` is None, load `path.pdmodel`.

        def _load_file(filename, dirname, suffix="pdparams"):
            # Match "<filename>...dist...<suffix>" as a whole name. The
            # user-supplied parts are escaped so "." and other metacharacters
            # are literal, and the anchors stop e.g. "model" from also
            # picking up "old_model_dist0.pdparams" or "*.pdparams.bak".
            pattern = rf'^{re.escape(filename)}(.*)_dist(.*)\.{re.escape(suffix)}$'
            rank_re = re.compile(rf'_dist(\d+)\.{re.escape(suffix)}$')

            def _rank_key(file_path):
                # Order numerically by rank so "_dist10" follows "_dist9";
                # names without a numeric rank go last, by name.
                match = rank_re.search(file_path)
                rank = int(match.group(1)) if match else float('inf')
                return rank, file_path

            # An empty dirname means the current directory
            # (os.listdir("") raises).
            file_list = [
                os.path.join(dirname, file)
                for file in os.listdir(dirname or os.curdir)
                if check_filename(pattern, file)
            ]
            file_list.sort(key=_rank_key)
            return file_list

        def _load_state(filename, dirname, suffix="pdparams"):
            file_list = _load_file(filename, dirname, suffix)
            state_dict = {}
            for file in file_list:
                # NOTE(security): pickle.load can execute arbitrary code;
                # only load checkpoints from trusted locations.
                with open(file, 'rb') as f:
                    state_dict_info = pickle.load(f, encoding='latin1')
                for name, value in state_dict_info.items():
                    state_dict.setdefault(name, []).append(np.array(value))
            self._logger.info(f"Load param file: {file_list}")
            return state_dict

        filename = os.path.basename(path)
        if filename == "":
            raise ValueError(
                "path should be of 'dirname/filename' format, but received filename is empty string"
            )
        dirname = os.path.dirname(path)

        # load path.pdparam and path.pdopt
        param_state_dict = _load_state(filename, dirname)
        opt_state_dict = (
            _load_state(filename, dirname, "pdopt") if load_optimizer else {}
        )
        state_dict = dict(param_state_dict, **opt_state_dict)

        # load path.pdattr
        dist_attr_file_list = _load_file(filename, dirname, "pdattr")
        self._logger.info(
            f"Load distributed attribute file: {dist_attr_file_list}"
        )
        dist_attr = {}
        for dist_attr_file in dist_attr_file_list:
            # NOTE(security): see the pickle note above.
            with open(dist_attr_file, 'rb') as f:
                dist_attr_info = pickle.load(f, encoding='latin1')
            for name, attr in dist_attr_info.items():
                # Keep the first occurrence only.
                dist_attr.setdefault(name, attr)

        return state_dict, dist_attr

    def save_inference_model(self, path, feed_vars, fetch_vars, exe, **kwargs):
        """Save this rank's pruned inference program as ``<filename>_dist<rank>``.

        Feed/fetch variable lists are extended with the outputs of forward
        ``read``/``feed``/``recv_v2`` ops and the inputs of forward
        ``send_v2`` ops. Each list is then filtered to names actually used by
        ops up to the last forward op that sends or produces a fetch var.

        Args:
            path: ``dirname/filename`` prefix; the directory is created if
                missing.
            feed_vars: variables whose ``.name`` seeds the feed list.
            fetch_vars: variables whose ``.name`` seeds the fetch list.
            exe: executor forwarded to ``paddle.static.save_inference_model``.
            **kwargs: ``program`` (defaults to
                ``paddle.static.default_main_program()`` when missing/falsy)
                and ``legacy_format`` (default False).
        """
        dirname, filename = _process_path(path)

        # save distributed inference program
        rank_id = paddle.distributed.get_rank()
        if rank_id == 0:
            self._save_rank_mapping(dirname)
        op_role_key = core.op_proto_and_checker_maker.kOpRoleAttrName()
        op_role_forward = int(core.op_proto_and_checker_maker.OpRole.Forward)

        dist_main_prog = kwargs.get('program', None)
        if not dist_main_prog:
            dist_main_prog = paddle.static.default_main_program()
        global_block = dist_main_prog.global_block()

        ops = global_block.ops
        feed_vars_names = [x.name for x in feed_vars]
        fetch_vars_names = [x.name for x in fetch_vars]

        # Find the index of the last forward op that either sends data or
        # produces one of the (growing) fetch variables.
        last_idx = -1
        for idx, op in enumerate(ops):
            if op.attr(op_role_key) != op_role_forward:
                continue
            if op.type in ("read", "feed", "recv_v2"):
                feed_vars_names += op.output("Out")
            if op.type == "send_v2":
                fetch_vars_names += op.input("X")
                last_idx = max(idx, last_idx)
            for out_name in op.output_arg_names:
                if out_name in fetch_vars_names:
                    last_idx = max(idx, last_idx)

        # Names read/written by ops up to and including last_idx; sets give
        # O(1) membership tests below (order comes from the name lists).
        used_inputs = set()
        used_outputs = set()
        for idx, op in enumerate(ops):
            if idx > last_idx:
                break
            used_inputs.update(op.input_arg_names)
            used_outputs.update(op.output_arg_names)

        # delete duplicated elements and keep order
        dist_feed_vars_names = [
            var_name
            for var_name in dict.fromkeys(feed_vars_names)
            if var_name in used_inputs
        ]
        dist_fetch_vars_names = [
            var_name
            for var_name in dict.fromkeys(fetch_vars_names)
            if var_name in used_outputs
        ]

        # NOTE(review): feed vars are passed in reverse discovery order --
        # presumably intentional; confirm before changing.
        dist_feed_vars = list(
            reversed([global_block.vars[name] for name in dist_feed_vars_names])
        )
        dist_fetch_vars = [
            global_block.vars[name] for name in dist_fetch_vars_names
        ]

        dist_filename = filename + "_dist" + str(rank_id)
        dist_path = os.path.join(dirname, dist_filename)
        legacy_format = kwargs.get("legacy_format", False)
        paddle.static.save_inference_model(
            dist_path,
            dist_feed_vars,
            dist_fetch_vars,
            exe,
            program=dist_main_prog,
            legacy_format=legacy_format,
        )

    def _save_rank_mapping(self, dirname):
        """Write ``rank_mapping.csv`` describing every registered process group.

        The file has two sections: ``[ring_id -> ranks]`` with one
        ``ring_id,rank,...`` line per group, then ``[rank -> ring_ids]`` with
        one ``rank,ring_id,...`` line per rank, in first-seen order.
        """
        path = os.path.join(dirname, 'rank_mapping.csv')
        rank_to_ring_ids = {}
        # `with` guarantees the file is closed even if a write fails.
        with open(path, 'w') as f:
            f.write('[ring_id -> ranks]\n')
            for process_group in _g_process_group_map.values():
                ring_id = process_group._group_id
                ranks = [str(rank) for rank in process_group._ranks]
                f.write(str(ring_id) + "," + ",".join(ranks) + '\n')
                # Build the inverse mapping in the same pass; dict insertion
                # order reproduces the original two-pass ordering.
                for rank in process_group._ranks:
                    rank_to_ring_ids.setdefault(rank, []).append(str(ring_id))

            f.write('[rank -> ring_ids]\n')
            for rank, ring_ids in rank_to_ring_ids.items():
                f.write(str(rank) + "," + ",".join(ring_ids) + "\n")