Unverified · Commit b1e82031 authored by zhaoyingli, committed by GitHub

fix strategy (#46256)

Parent ffd35908
@@ -45,7 +45,7 @@ set_field_default_config(BASE, "gradient_scale", True)
 set_field_default_config(BASE, "use_cache", True)
 set_field_default_config(BASE, "return_numpy", True)
 set_field_default_config(BASE, "all_ranks", False)
-set_field_default_config(BASE, "split_data", False)
+set_field_default_config(BASE, "split_data", True)
 set_field_default_config(BASE, "seed", None)
 set_field_default_config(BASE, "reinit", False)  # Only for debug
...
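The only change in this hunk flips the default of `split_data` from False to True, so the auto-parallel strategy now splits input batches across ranks unless the user overrides it. Below is a minimal, hypothetical sketch of how a default registry in the style of `set_field_default_config` could be wired up; the dict layout and the `BASE` value are assumptions for illustration, not Paddle's actual internals.

```python
# Hypothetical sketch of a strategy-default registry (illustrative, not Paddle's code).
_default_configs = {}

BASE = "base"

def set_field_default_config(category, field, default_value):
    """Register the default value of `field` under a strategy category."""
    _default_configs.setdefault(category, {})[field] = default_value

# Mirrors the new default set in this commit.
set_field_default_config(BASE, "split_data", True)
set_field_default_config(BASE, "all_ranks", False)

print(_default_configs[BASE]["split_data"])  # True
```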
@@ -978,9 +978,10 @@ class Engine:
         # extract ckpts by specific model
         if isinstance(self._model, paddle.nn.Layer):
-            if hasattr(
-                    self._model, "gpt"
-            ) and self._model.__class__.__name__ == 'GPTForPretraining':
+            if hasattr(self._model,
+                       "gpt") and self._model.__class__.__name__ in [
+                           'GPTForPretraining', 'GPTForPretrainingAuto'
+                       ]:
                 exact_ckpts = self._model.gpt.checkpoints
             else:
                 exact_ckpts = recompute.checkpoints
...
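The single-class equality check is widened to a membership test so that the auto-parallel GPT variant is also recognized and its built-in recompute checkpoints are used. A hedged sketch of that selection logic as a standalone function (the function name and `recompute_config` parameter are illustrative, not the actual Engine API):

```python
# Illustrative sketch of the checkpoint-selection branch above.
def extract_checkpoints(model, recompute_config):
    """Prefer the model's own activation checkpoints for known GPT variants."""
    if hasattr(model, "gpt") and model.__class__.__name__ in [
            'GPTForPretraining', 'GPTForPretrainingAuto'
    ]:
        # Both pretraining variants expose the segments chosen by the model itself.
        return model.gpt.checkpoints
    # Otherwise fall back to the checkpoints configured on the recompute strategy.
    return recompute_config.checkpoints
```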
@@ -59,10 +59,11 @@ class BaseConfig(object):
         return result_dict

     def __repr__(self):
-        return yaml.dump(self.to_dict(),
-                         default_flow_style=False,
-                         sort_keys=True,
-                         indent=4)
+        result_dict = self.to_dict()
+        string = "{"
+        for k, v in result_dict.items():
+            string += "\"%s\":\"%s\"," % (k, v)
+        return string + "}"

     def __deepcopy__(self, memo):
         cls = self.__class__
...
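`BaseConfig.__repr__` no longer goes through `yaml.dump`; it now builds a compact single-line string from `to_dict()`. A small stand-in class (hypothetical, only the `__repr__` body matches the patch) showing the resulting output:

```python
# Stand-in config class to show the effect of the new __repr__ (illustrative only).
class DemoConfig:

    def to_dict(self):
        return {"enable": True, "micro_batch_size": 2}

    def __repr__(self):
        result_dict = self.to_dict()
        string = "{"
        for k, v in result_dict.items():
            string += "\"%s\":\"%s\"," % (k, v)
        return string + "}"

print(repr(DemoConfig()))
# {"enable":"True","micro_batch_size":"2",}
```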
@@ -81,6 +81,8 @@ def convert_to_dims_mapping(shard_spec, process_mesh):
     for shard in shard_spec:
         if shard is None:
             dims_mapping.append(-1)
+        elif process_mesh.topology[process_mesh.dim_names.index(shard)] == 1:
+            dims_mapping.append(-1)
         else:
             dims_mapping.append(process_mesh.dim_names.index(shard))
     return dims_mapping
...
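The new `elif` branch treats a shard placed on a mesh dimension of size 1 as not sharded at all, mapping it to `-1` just like a `None` entry. A self-contained sketch with a toy mesh stand-in (`ToyMesh` is an assumption; only the `topology` and `dim_names` attributes mirror what the real `process_mesh` exposes):

```python
from collections import namedtuple

# Toy stand-in exposing just the two ProcessMesh attributes used above.
ToyMesh = namedtuple("ToyMesh", ["topology", "dim_names"])

def convert_to_dims_mapping(shard_spec, process_mesh):
    dims_mapping = []
    for shard in shard_spec:
        if shard is None:
            dims_mapping.append(-1)
        elif process_mesh.topology[process_mesh.dim_names.index(shard)] == 1:
            # A degenerate mesh dimension (size 1) cannot actually shard a tensor
            # axis, so it is treated as replicated.
            dims_mapping.append(-1)
        else:
            dims_mapping.append(process_mesh.dim_names.index(shard))
    return dims_mapping

mesh = ToyMesh(topology=[1, 4], dim_names=["dp", "mp"])
print(convert_to_dims_mapping(["dp", "mp", None], mesh))  # [-1, 1, -1]
```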