未验证 提交 b1e82031 编写于 作者: Z zhaoyingli 提交者: GitHub

fix strategy (#46256)

上级 ffd35908
......@@ -45,7 +45,7 @@ set_field_default_config(BASE, "gradient_scale", True)
set_field_default_config(BASE, "use_cache", True)
set_field_default_config(BASE, "return_numpy", True)
set_field_default_config(BASE, "all_ranks", False)
set_field_default_config(BASE, "split_data", False)
set_field_default_config(BASE, "split_data", True)
set_field_default_config(BASE, "seed", None)
set_field_default_config(BASE, "reinit", False) # Only for debug
......
......@@ -978,9 +978,10 @@ class Engine:
# extract ckpts by specific model
if isinstance(self._model, paddle.nn.Layer):
if hasattr(
self._model, "gpt"
) and self._model.__class__.__name__ == 'GPTForPretraining':
if hasattr(self._model,
"gpt") and self._model.__class__.__name__ in [
'GPTForPretraining', 'GPTForPretrainingAuto'
]:
exact_ckpts = self._model.gpt.checkpoints
else:
exact_ckpts = recompute.checkpoints
......
......@@ -59,10 +59,11 @@ class BaseConfig(object):
return result_dict
def __repr__(self):
return yaml.dump(self.to_dict(),
default_flow_style=False,
sort_keys=True,
indent=4)
result_dict = self.to_dict()
string = "{"
for k, v in result_dict.items():
string += "\"%s\":\"%s\"," % (k, v)
return string + "}"
def __deepcopy__(self, memo):
cls = self.__class__
......
......@@ -81,6 +81,8 @@ def convert_to_dims_mapping(shard_spec, process_mesh):
for shard in shard_spec:
if shard is None:
dims_mapping.append(-1)
elif process_mesh.topology[process_mesh.dim_names.index(shard)] == 1:
dims_mapping.append(-1)
else:
dims_mapping.append(process_mesh.dim_names.index(shard))
return dims_mapping
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册