# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import copy
import os
import pickle

import numpy as np

import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.static.converter import Converter
from paddle.distributed.auto_parallel.static.dist_context import (
    get_default_distributed_context,
)
from paddle.distributed.auto_parallel.static.utils import (
    is_backward_op,
    is_forward_op,
    is_loss_op,
)
from paddle.fluid import core
from paddle.fluid.framework import Program
from paddle.static.io import deserialize_program

_valid_types = [
    core.VarDesc.VarType.LOD_TENSOR,
    core.VarDesc.VarType.SELECTED_ROWS,
    core.VarDesc.VarType.LOD_TENSOR_ARRAY,
]

paddle.enable_static()


class AutoAlignTool:
    """
    An automatic parallel precision alignment tool.
    """

    def __init__(self, program: Program, step=1, fetch_list=None):
        """Set the initialization information of the tool.

        step: the step at which the requested variables are returned.
        fetch_list: the initial fetch_list. It is returned unchanged until
            the given step is reached. The tool can be combined with the
            Engine class, e.g. inside Engine.fit:

            try:
                fetch_list = []
                align_tool = AutoAlignTool(self.main_program, 0, fetch_names)
                level = 0
                fetch_list = align_tool.get_var(level, step)
                outs = self._executor.run(
                    self.main_program,
                    fetch_list=fetch_list,
                    use_program_cache=self._strategy.use_cache,
                    return_numpy=self._strategy.return_numpy,
                )
                if fetch_list != fetch_names:
                    align_tool.save(
                        dir_path,
                        outs,
                        fetch_list,
                        self._dist_contexts["train"],
                    )
                    exit(0)
            except core.EOFException:
                break
        """
        assert isinstance(program, Program)
        self._program = program
        self._blocks = program.blocks
        self._step = step
        self._fetch_list = fetch_list
        assert self._blocks is not None

    def set_step(self, step):
        self._step = step
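
    # Alignment levels accepted by get_var() below:
    #   0: loss and learning rate variables
    #   1: data (feed) variables
    #   2: parameters
    #   3: parameter gradients
    #   4: forward temporary variables
    #   5: backward temporary variables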
""" if step != self._step or step == -1: return self._fetch_list if level == 0: return self.get_loss_lr_var() elif level == 1: return self.get_data_var() elif level == 2: return self.get_param_var() elif level == 3: return self.get_param_grad_var() elif level == 4: return self.get_forward_tmp_var() elif level == 5: return self.get_backward_tmp_var() else: raise ValueError() def set_program(self, program: Program): assert isinstance(program, Program) self._program = program self._blocks = program.blocks assert self._blocks is not None def get_loss_lr_var(self): """ Returns the variable name of learning rate and loss """ fetch_set = set() loss_ops = [] for block in self._blocks: for op in block.ops: if is_loss_op(op): assert ( len(op.desc.output_arg_names()) == 1 ), "loss op should only output loss var" loss_ops.append(op) for block in self._blocks: for varname in block.vars: var = block._find_var_recursive(varname) if var is None or var.type not in _valid_types: continue if "learning_rate" in var.name: fetch_set.add(var.name) for loss_op in loss_ops: fetch_set.add(loss_op.output_arg_names[0]) return list(fetch_set) def get_data_var(self): """ Returns the variable name of data. """ fetch_set = set() for block in self._blocks: for varname in block.vars: var = block._find_var_recursive(varname) if var is None or var.type not in _valid_types: continue if var.is_data: fetch_set.add(var.name) return list(fetch_set) def get_param_var(self): """ Returns the variable name of parameters. """ fetch_set = set() for block in self._blocks: for op in block.ops: if is_backward_op(op): break for varname in op.input_arg_names + op.output_arg_names: var = block._find_var_recursive(varname) if var is None or var.type not in _valid_types: continue if var.is_parameter: fetch_set.add(varname) return list(fetch_set) def get_param_grad_var(self): """ Returns the variable name of parameters' gradient. 
""" fetch_set = set() for block in self._blocks: for op in block.ops: if is_forward_op(op): continue for varname in op.input_arg_names + op.output_arg_names: if "@GRAD" not in varname: continue fwd_varname = varname.split("@GRAD")[0] fwd_var = block._find_var_recursive(fwd_varname) if fwd_var is None or fwd_var.type not in _valid_types: continue if fwd_var.is_parameter is False: continue var = block._find_var_recursive(varname) if var is None or var.type not in _valid_types: continue fetch_set.add(varname) return list(fetch_set) def get_forward_tmp_var(self): """ Returns the name of the temporary variable in the forward propagation """ fetch_set = set() loss_lr_list = self.get_loss_lr_var() for block in self._blocks: for op in block.ops: if is_backward_op(op): break for varname in op.input_arg_names + op.output_arg_names: if varname in loss_lr_list: continue var = block._find_var_recursive(varname) if var is None or var.type not in _valid_types: continue if var.is_data or var.is_parameter: continue fetch_set.add(varname) return list(fetch_set) def get_backward_tmp_var(self): """ Returns the name of a temporary variable in back-propagation """ fetch_set = set() loss_lr_list = self.get_loss_lr_var() forward_tmp_list = self.get_forward_tmp_var() for block in self._blocks: for op in block.ops: if is_backward_op(op): for varname in op.input_arg_names + op.output_arg_names: if ( varname in loss_lr_list or varname in forward_tmp_list ): continue if "@GRAD" in varname: fwd_varname = varname.split("@GRAD")[0] fwd_var = block._find_var_recursive(fwd_varname) if ( fwd_var is not None and fwd_var.type in _valid_types ): if fwd_var.is_parameter: continue var = block._find_var_recursive(varname) if var is None or var.type not in _valid_types: continue if var.is_data or var.is_parameter: continue fetch_set.add(varname) return list(fetch_set) def save(self, save_dir, vars, fetch_list, dist_context=None): """ save fetch variables, distributed properties of variables and program. 
""" if os.path.exists(save_dir) is False: os.mkdir(save_dir) if dist_context is None: dist_context = get_default_distributed_context() assert os.path.exists(save_dir) if dist.get_world_size() == 1: vars_path = os.path.join(save_dir, "vars.pkl") program_path = os.path.join(save_dir, "program.pdmodel") dist_attr_path = os.path.join(save_dir, "dist_attr.pkl") else: vars_path = os.path.join( save_dir, f"vars_rank{dist.get_rank()}.pkl" ) program_path = os.path.join( save_dir, f"program_rank{dist.get_rank()}.pdmodel" ) dist_attr_path = os.path.join( save_dir, f"dist_attr_rank{dist.get_rank()}.pkl" ) if vars is not None: vars_dict = {} assert len(fetch_list) == len(vars) for i in range(len(fetch_list)): if vars[i] is None: continue vars_dict[fetch_list[i]] = vars[i] with open(vars_path, "wb") as f: pickle.dump(vars_dict, f) dist_attr = {} for var in self._program.list_vars(): if var.name not in fetch_list: continue tensor_dist_attr = ( dist_context.get_tensor_dist_attr_for_program(var) ) if tensor_dist_attr is None: continue process_mesh = tensor_dist_attr.process_mesh dims_mapping = tensor_dist_attr.dims_mapping dist_attr[var.name] = { "process_shape": process_mesh.shape, "process_group": process_mesh.process_ids, "dims_mapping": dims_mapping, } if len(dist_attr) > 0: with open(dist_attr_path, "wb") as f: pickle.dump(dist_attr, f) if self._program is not None: with open(program_path, "wb") as f: f.write(self._program.desc.serialize_to_string()) @staticmethod def load(save_dir): assert os.path.exists(save_dir) filename_list = sorted(os.listdir(save_dir)) vars_list = [] program_list = [] dist_attr_list = [] for filename in filename_list: filepath = os.path.join(save_dir, filename) assert os.path.isfile(filepath) if "vars" in filename: assert filename.endswith("pkl") with (open(filepath, "rb")) as f: vars_list.append(pickle.load(f)) elif "program" in filename: assert filename.endswith("pdmodel") with open(filepath, "rb") as f: program_string = f.read() program_list.append(deserialize_program(program_string)) elif "dist_attr" in filename: assert filename.endswith("pkl") with (open(filepath, "rb")) as f: dist_attr_list.append(pickle.load(f)) dist_attr_map = {} for dist_attrs in dist_attr_list: for dist_attr_name in dist_attrs.keys(): if dist_attr_name not in dist_attr_map: dist_attr_map[dist_attr_name] = dist_attrs[dist_attr_name] assert len(vars_list) == len(program_list) return vars_list, program_list, dist_attr_map @staticmethod def convert_src_tensor_2_dst_tensor(vars_list, src_attr_map, dst_attr_map): """ Converter is a class object for auto parallel to convert tensors from one parallel strategy to another one. Tensors will merge and slice value with their strategy when strategies are different. But like dp to pp or dp to serial is not supported. 
""" assert len(vars_list) >= 1 # if dist_attr_map is None or len(dist_attr_map) == 0 or len(vars_list) == 1: if src_attr_map is None or len(src_attr_map) == 0: return vars_list[0] dst_strategys = {} src_strategys = {} tensors_dict = {} convert_tensor_dict = None for var_name in src_attr_map.keys(): assert var_name not in dst_strategys dist_vars = [] for vars in vars_list: if var_name in vars.keys(): dist_vars.append(vars[var_name]) if len(dist_vars) == 0: continue if var_name in dst_attr_map and var_name in src_attr_map: dst_strategys[var_name] = copy.deepcopy(dst_attr_map[var_name]) src_strategys[var_name] = copy.deepcopy(src_attr_map[var_name]) tensors_dict[var_name] = dist_vars if src_attr_map == dst_attr_map: return tensors_dict converter = Converter(tensors_dict, src_strategys, dst_strategys) convert_tensor_dict = converter.convert() return convert_tensor_dict @staticmethod def find_diff_vars(fixed_vars_map, query_vars_map): """ Found two variable names with different variable lists """ diff_var_name_list = set() for var_name in fixed_vars_map.keys(): if var_name in query_vars_map: fixed_vars = fixed_vars_map[var_name] query_vars = query_vars_map[var_name] if isinstance(fixed_vars, np.ndarray): fixed_vars = [fixed_vars] if isinstance(query_vars, np.ndarray): query_vars = [query_vars] length = min(len(fixed_vars), len(query_vars)) if len(fixed_vars) != len(query_vars): print() for i in range(length): if not np.allclose(fixed_vars[i], query_vars[i]): diff_var_name_list.add(var_name) return diff_var_name_list @staticmethod def diff_informations(right_dir, wrong_dir): """ Find the corresponding operator according to the variable name. """ ( right_vars_list, right_program_list, right_dist_attr_map, ) = AutoAlignTool.load(right_dir) ( wrong_vars_list, wrong_program_list, wrong_dist_attr_map, ) = AutoAlignTool.load(wrong_dir) right_tensors_dict = AutoAlignTool.convert_src_tensor_2_dst_tensor( right_vars_list, right_dist_attr_map, right_dist_attr_map ) wrong_tensors_dict = AutoAlignTool.convert_src_tensor_2_dst_tensor( wrong_vars_list, wrong_dist_attr_map, right_dist_attr_map ) diff_var_name_list = AutoAlignTool.find_diff_vars( right_tensors_dict, wrong_tensors_dict ) diff_ops_varname_dict = collections.OrderedDict() for program in wrong_program_list: for block in program.blocks: for op in block.ops: for varname in op.input_arg_names + op.output_arg_names: if varname in diff_var_name_list: if len(diff_ops_varname_dict) == 0: print( "first different op:\n", op, f"\ndifferent varname is:{varname}", ) if op not in diff_ops_varname_dict: diff_ops_varname_dict[op] = [varname] else: diff_ops_varname_dict[op].append(varname) return diff_ops_varname_dict @staticmethod def diff_informations_from_dirs(right_dirs, wrong_dirs): right_vars_list = [] right_program_list = [] right_dist_attr_map = {} for right_dir in right_dirs: ( tmp_vars_list, right_program_list, tmp_dist_attr_map, ) = AutoAlignTool.load(right_dir) if len(right_vars_list) == 0: right_vars_list = tmp_vars_list else: for i in range(len(tmp_vars_list)): vars_list = tmp_vars_list[i] for key in vars_list.keys(): if key not in right_vars_list[i].keys(): right_vars_list[i][key] = vars_list[key] for key in tmp_dist_attr_map.keys(): if key not in right_dist_attr_map: right_dist_attr_map[key] = tmp_dist_attr_map[key] wrong_vars_list = [] wrong_program_list = [] wrong_dist_attr_map = {} for wrong_dir in wrong_dirs: ( tmp_vars_list, wrong_program_list, tmp_dist_attr_map, ) = AutoAlignTool.load(wrong_dir) if len(wrong_vars_list) == 0: wrong_vars_list = 

    @staticmethod
    def diff_informations_from_dirs(right_dirs, wrong_dirs):
        """
        Same as diff_informations, but merges the dumps from several right
        and wrong directories before comparing.
        """
        right_vars_list = []
        right_program_list = []
        right_dist_attr_map = {}
        for right_dir in right_dirs:
            (
                tmp_vars_list,
                right_program_list,
                tmp_dist_attr_map,
            ) = AutoAlignTool.load(right_dir)
            if len(right_vars_list) == 0:
                right_vars_list = tmp_vars_list
            else:
                for i in range(len(tmp_vars_list)):
                    vars_list = tmp_vars_list[i]
                    for key in vars_list.keys():
                        if key not in right_vars_list[i].keys():
                            right_vars_list[i][key] = vars_list[key]
            for key in tmp_dist_attr_map.keys():
                if key not in right_dist_attr_map:
                    right_dist_attr_map[key] = tmp_dist_attr_map[key]

        wrong_vars_list = []
        wrong_program_list = []
        wrong_dist_attr_map = {}
        for wrong_dir in wrong_dirs:
            (
                tmp_vars_list,
                wrong_program_list,
                tmp_dist_attr_map,
            ) = AutoAlignTool.load(wrong_dir)
            if len(wrong_vars_list) == 0:
                wrong_vars_list = tmp_vars_list
            else:
                for i in range(len(tmp_vars_list)):
                    vars_list = tmp_vars_list[i]
                    for key in vars_list.keys():
                        if key not in wrong_vars_list[i].keys():
                            wrong_vars_list[i][key] = vars_list[key]
            for key in tmp_dist_attr_map.keys():
                if key not in wrong_dist_attr_map:
                    wrong_dist_attr_map[key] = tmp_dist_attr_map[key]

        right_tensors_dict = AutoAlignTool.convert_src_tensor_2_dst_tensor(
            right_vars_list, right_dist_attr_map, right_dist_attr_map
        )
        wrong_tensors_dict = AutoAlignTool.convert_src_tensor_2_dst_tensor(
            wrong_vars_list, wrong_dist_attr_map, right_dist_attr_map
        )
        diff_var_name_list = AutoAlignTool.find_diff_vars(
            right_tensors_dict, wrong_tensors_dict
        )

        diff_ops_varname_dict = collections.OrderedDict()

        for program in wrong_program_list:
            for block in program.blocks:
                for op in block.ops:
                    for varname in op.input_arg_names + op.output_arg_names:
                        if varname in diff_var_name_list:
                            if len(diff_ops_varname_dict) == 0:
                                print(
                                    "first different op:\n",
                                    op,
                                    f"\ndifferent varname is:{varname}",
                                )
                            if op not in diff_ops_varname_dict:
                                diff_ops_varname_dict[op] = [varname]
                            else:
                                diff_ops_varname_dict[op].append(varname)

        return diff_ops_varname_dict
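

# A minimal single-process usage sketch (not part of the tool itself): build a
# tiny static program, fetch the level-0 variables (loss and learning rate),
# run one step and dump the results with save(). The network, feed data and
# the "./align_dump" path are illustrative assumptions only.
if __name__ == "__main__":
    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
        y = paddle.static.data(name="y", shape=[4, 1], dtype="float32")
        out = paddle.static.nn.fc(x, size=1)
        loss = paddle.mean((out - y) ** 2)
        paddle.optimizer.SGD(learning_rate=0.1).minimize(loss)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_prog)

    align_tool = AutoAlignTool(main_prog, step=0, fetch_list=[loss.name])
    fetch_list = align_tool.get_var(level=0, step=0)
    outs = exe.run(
        main_prog,
        feed={
            "x": np.random.rand(4, 8).astype("float32"),
            "y": np.random.rand(4, 1).astype("float32"),
        },
        fetch_list=fetch_list,
    )
    # Writes vars.pkl / program.pdmodel (and dist_attr.pkl when distributed
    # attributes exist) into the hypothetical "./align_dump" directory.
    align_tool.save("./align_dump", outs, fetch_list)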