Unverified  Commit bd2d4fd0, authored by Baibaifan and committed by GitHub

fix_import_distribute_bugs (#40396)

Parent 135cf713
@@ -25,10 +25,9 @@ from collections import OrderedDict
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid import core
-import paddle.distributed as dist
 from paddle.optimizer import Optimizer
 from paddle.fluid.clip import ClipGradByGlobalNorm
-from paddle.distributed.collective import _get_global_group
+from paddle.distributed.collective import _get_global_group, new_group, broadcast, wait

 from ...utils.internal_storage import ParamStorage, GradStorage
 from ...meta_parallel.sharding.sharding_utils import Type, device_guard, ShardingClipGrad
@@ -91,8 +90,8 @@ class ShardingOptimizerStage2(Optimizer):
             filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
                    self._local_params))) > 0

-        self.group = dist.new_group(_get_global_group()
-                                    .ranks) if group is None else group
+        self.group = new_group(_get_global_group()
+                               .ranks) if group is None else group

         self.world_size = self.group.nranks
         self.rank = self.group.rank
@@ -141,14 +140,14 @@ class ShardingOptimizerStage2(Optimizer):
         """
         for p in self._local_params:
-            dist.broadcast(
+            broadcast(
                 p,
                 src=self._global_root_rank,
                 group=self.group,
                 use_calc_stream=True)

             # Multi stream operation will be supported later
-            dist.wait(tensor=p, group=self.group, use_calc_stream=True)
+            wait(tensor=p, group=self.group, use_calc_stream=True)

     def _generate_master_params(self, trainable_params):
         if self.offload:
...@@ -385,6 +384,12 @@ class ShardingOptimizerStage2(Optimizer): ...@@ -385,6 +384,12 @@ class ShardingOptimizerStage2(Optimizer):
raise RuntimeError( raise RuntimeError(
"optimizer.minimize() not support now, please use optimizer.step()") "optimizer.minimize() not support now, please use optimizer.step()")
def set_state_dict(self, state_dict):
self._optim.set_state_dict(state_dict)
def state_dict(self):
return self._optim.state_dict()
def _clear_cache(self): def _clear_cache(self):
self.__segment_params.clear() self.__segment_params.clear()
self._dtype_rank_params.clear() self._dtype_rank_params.clear()
@@ -399,14 +404,14 @@ class ShardingOptimizerStage2(Optimizer):
         # Exchange all the shards with the other ranks
         for dtype_per_rank in self.param_storages.values():
             for dst_rank, internal_storage in dtype_per_rank.items():
-                dist.broadcast(
+                broadcast(
                     tensor=internal_storage.buffer,
                     src=self.group.ranks[dst_rank],
                     group=self.group,
                     use_calc_stream=True)

                 # Multi stream operation will be supported later
-                dist.wait(
+                wait(
                     tensor=internal_storage.buffer,
                     group=self.group,
                     use_calc_stream=True)
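The `state_dict()` / `set_state_dict()` methods added above are thin delegates to the wrapped inner optimizer (`self._optim`), so the sharded optimizer can be checkpointed with `paddle.save` / `paddle.load` like an ordinary optimizer. A minimal usage sketch, assuming `optimizer` is an already-constructed `ShardingOptimizerStage2`; the variable name and file path are illustrative and mirror the round trip added to the tests below:

import paddle

# Hypothetical `optimizer` built elsewhere as a ShardingOptimizerStage2.
# Saving goes through the new state_dict() delegate to self._optim.
paddle.save(optimizer.state_dict(), "model.pdopt")

# Restoring forwards to self._optim.set_state_dict().
opt_state = paddle.load("model.pdopt")
optimizer.set_state_dict(opt_state)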
@@ -28,7 +28,7 @@ from types import MethodType
 import paddle
 from paddle import nn
-import paddle.distributed as dist
+from paddle.distributed import collective as dist
 from paddle.distributed.collective import _get_global_group

 from ...utils.internal_storage import GradStorage
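For reference, the import swap in this file (and in the stage-3 module further down) only changes what the `dist` alias is bound to; call sites that go through the `dist.` prefix, such as `dist.broadcast(...)` and `dist.wait(...)`, keep the same spelling but now resolve against `paddle.distributed.collective` directly. A minimal sketch of the before/after:

# Before: the alias pointed at the top-level distributed package.
import paddle.distributed as dist

# After: the alias points straight at the collective module, which provides
# broadcast, wait, new_group, etc., so dist.broadcast / dist.wait still work.
from paddle.distributed import collective as dist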
@@ -158,6 +158,17 @@ class ShardingStage2(nn.Layer):

         return fw

+    def set_state_dict(self, state_dict, use_structured_name=True):
+        self._layer.set_state_dict(
+            state_dict, use_structured_name=use_structured_name)
+
+    def state_dict(self,
+                   destination=None,
+                   include_sublayers=True,
+                   structured_name_prefix=""):
+        return self._layer.state_dict(
+            destination=destination, include_sublayers=include_sublayers, structured_name_prefix=structured_name_prefix)
+
     def _clear_gradients(self):
         """
         Set zero to the gradient of the optimizer's current rank trainable parameters.
......
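Likewise, the wrapper-level `state_dict()` / `set_state_dict()` added to `ShardingStage2` (and to `ShardingStage3` below) simply forward to the wrapped `self._layer`, which enables the usual checkpoint round trip on the wrapped model. A short sketch, assuming `model` is a layer already wrapped in `ShardingStage2`; the temporary path is illustrative:

import os
import tempfile

import paddle

# Hypothetical `model` wrapped in ShardingStage2 (or ShardingStage3).
ckpt_dir = tempfile.mkdtemp()
model_path = os.path.join(ckpt_dir, "model.pdmodel")

paddle.save(model.state_dict(), model_path)      # forwards to self._layer.state_dict()
model.set_state_dict(paddle.load(model_path))    # forwards to self._layer.set_state_dict()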
@@ -20,7 +20,6 @@ import logging
 import functools
 import numpy as np
 from itertools import chain
-from functools import reduce
 from types import MethodType
 from collections import deque, OrderedDict

@@ -28,9 +27,9 @@ import paddle
 from paddle import nn
 from paddle.autograd import PyLayer
 import paddle.fluid.core as core
-import paddle.distributed as dist
 from paddle.fluid.framework import ParamBase
 from paddle.fluid.clip import ClipGradByGlobalNorm
+from paddle.distributed import collective as dist
 from paddle.distributed.collective import _get_global_group

 from .sharding_utils import Type, ShardingClipGrad, device_guard
@@ -249,6 +248,17 @@ class ShardingStage3(nn.Layer):

         return fw

+    def set_state_dict(self, state_dict, use_structured_name=True):
+        self._layer.set_state_dict(
+            state_dict, use_structured_name=use_structured_name)
+
+    def state_dict(self,
+                   destination=None,
+                   include_sublayers=True,
+                   structured_name_prefix=""):
+        return self._layer.state_dict(
+            destination=destination, include_sublayers=include_sublayers, structured_name_prefix=structured_name_prefix)
+
     def _handle_unslice_params(self):
         buffer_size = dict()
         buffer_size[Type.fp32.value] = 0
@@ -523,7 +533,7 @@ class ShardingStage3(nn.Layer):
     def _get_allreduce_fn(self, param):
         @paddle.autograd.no_grad()
-        def reduce(*_):
+        def allreduce_(*_):
             if param.name in self._task_flow.full_grad.keys():
                 full_grad = self._task_flow.full_grad[param.name]
                 # Only support sync allreduce current rank's layer now
@@ -573,7 +583,7 @@ class ShardingStage3(nn.Layer):
             if self._offload:
                 param.fw_storage = _device2cpu(param.fw_storage, True)

-        return reduce
+        return allreduce_

     def _param2align(self, param):
         # CUDA alignment 256 bytes
......
@@ -21,7 +21,6 @@ import numpy as np
 from types import MethodType

 import paddle
-import paddle.distributed as dist
 from paddle import _C_ops
 from paddle.fluid import core
 from paddle.fluid import layers
......
@@ -14,8 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+import shutil
 import numpy as np
 import argparse
+import tempfile
 import ast
 import time
 import paddle
@@ -88,7 +91,8 @@ def train_mlp(model,
               batch_size=100,
               use_pure_fp16=False,
               accumulate_grad=False,
-              opt_group=False):
+              opt_group=False,
+              save_model=False):
     if sharding_stage == "dp":
         hcg = fleet.get_hybrid_communicate_group()
         group = hcg.get_check_parallel_group()
@@ -147,6 +151,9 @@ def train_mlp(model,
         if accumulate_grad:
             optimizer.step()
             optimizer.clear_grad()

+    if save_model:
+        return model, optimizer
+
     return model.parameters()
@@ -158,11 +165,13 @@ def test_dp_stage2():
     mlp3 = MLP()
     mlp4 = MLP()
     mlp5 = MLP()
+    mlp6 = MLP()
     mlp1.set_state_dict(state_dict)
     mlp2.set_state_dict(state_dict)
     mlp3.set_state_dict(state_dict)
     mlp4.set_state_dict(state_dict)
     mlp5.set_state_dict(state_dict)
+    mlp6.set_state_dict(state_dict)

     # DP VS stage2
     dp_params = train_mlp(
@@ -186,10 +195,29 @@ def test_dp_stage2():
     # stage2 param list VS param group
     stage2_params = train_mlp(
-        mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True)
+        mlp5, sharding_stage=2, use_pure_fp16=False, opt_group=True)
     for i in range(len(dp_params)):
         np.testing.assert_allclose(
             dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)

+    # save/load model
+    output_dir = tempfile.mkdtemp()
+    model_file = os.path.join(output_dir, "model.pdmodel")
+    optimizer_file = os.path.join(output_dir, "model.pdopt")
+    model_stage2, optimizer_stage2 = train_mlp(
+        mlp6,
+        sharding_stage=2,
+        use_pure_fp16=False,
+        opt_group=False,
+        save_model=True)
+    paddle.save(model_stage2.state_dict(), model_file)
+    paddle.save(optimizer_stage2.state_dict(), optimizer_file)
+    m_state_dict = paddle.load(model_file)
+    opt_state_dict = paddle.load(optimizer_file)
+    model_stage2.set_state_dict(m_state_dict)
+    optimizer_stage2.set_state_dict(opt_state_dict)
+    shutil.rmtree(output_dir)
+
     return
......
@@ -14,6 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+import shutil
+import tempfile
 import numpy as np
 import argparse
 import ast
@@ -84,7 +87,8 @@ def train_mlp(model,
               batch_size=100,
               opt_group=False,
               sync_comm=False,
-              test_minimize=False):
+              test_minimize=False,
+              save_model=False):
     group = paddle.distributed.new_group([0, 1])
     if opt_group:
         optimizer = optimizer_setting(
@@ -162,12 +166,15 @@ def train_mlp(model,
             optimizer.clear_grad()
     if sharding_stage == 3:
         model.get_all_parameters()
+
+    if save_model:
+        return model, optimizer
     return model.parameters()


 def test_stage2_stage3():
-    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8, mlp9 = MLP(), MLP(
-    ), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
+    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8, mlp9, mlp10 = MLP(
+    ), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
     state_dict = mlp.state_dict()
     mlp1.set_state_dict(state_dict)
     mlp2.set_state_dict(state_dict)
@@ -178,6 +185,7 @@ def test_stage2_stage3():
     mlp7.set_state_dict(state_dict)
     mlp8.set_state_dict(state_dict)
     mlp9.set_state_dict(state_dict)
+    mlp10.set_state_dict(state_dict)

     # fp32
     stage2_params = train_mlp(
@@ -238,9 +246,27 @@ def test_stage2_stage3():
         np.testing.assert_allclose(
             stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6)

+    # save/load model
+    output_dir = tempfile.mkdtemp()
+    model_file = os.path.join(output_dir, "model.pdmodel")
+    optimizer_file = os.path.join(output_dir, "model.pdopt")
+    model_stage3, optimizer_stage3 = train_mlp(
+        mlp9,
+        sharding_stage=3,
+        use_pure_fp16=False,
+        opt_group=False,
+        save_model=True)
+    paddle.save(model_stage3.state_dict(), model_file)
+    paddle.save(optimizer_stage3.state_dict(), optimizer_file)
+    m_state_dict = paddle.load(model_file)
+    opt_state_dict = paddle.load(optimizer_file)
+    model_stage3.set_state_dict(m_state_dict)
+    optimizer_stage3.set_state_dict(opt_state_dict)
+    shutil.rmtree(output_dir)
+
     # check optimizer.minimize() error
     train_mlp(
-        mlp9,
+        mlp10,
         sharding_stage=3,
         use_pure_fp16=False,
         opt_group=False,
......