You need to sign in or sign up before continuing.
未验证 提交 20b9be65 编写于 作者: J JZ-LIANG 提交者: GitHub

[Tensor Parallelism] split fix bug (#33015)

上级 a2a45d8d
...@@ -977,6 +977,11 @@ def _parallel_linear(x, ...@@ -977,6 +977,11 @@ def _parallel_linear(x,
group=None): group=None):
""" """
Parallel Linear Parallel Linear
axis: the dimension of the linear layer's parameter along which it is split.
axis = 0: the row dimension
axis = 1: the column dimension
""" """
if group is not None and not group.is_member(): if group is not None and not group.is_member():
return return
...@@ -1008,6 +1013,12 @@ def _parallel_linear(x, ...@@ -1008,6 +1013,12 @@ def _parallel_linear(x,
main_block = paddle.static.default_main_program().global_block() main_block = paddle.static.default_main_program().global_block()
startup_block.vars[linear.weight.name].is_distributed = True startup_block.vars[linear.weight.name].is_distributed = True
main_block.vars[linear.weight.name].is_distributed = True main_block.vars[linear.weight.name].is_distributed = True
# set is_distributed for the split bias
# if a linear layer is split by row, each rank holds a complete bias, and the bias should be identical on every rank.
# if a linear layer is split by column, the bias is also split across ranks, in the same way as its weight.
if axis == 1 and linear._bias_attr != False:
startup_block.vars[linear.bias.name].is_distributed = True
main_block.vars[linear.bias.name].is_distributed = True
if not gather_out: return linear_out if not gather_out: return linear_out
......
...@@ -814,7 +814,7 @@ class DistributedStrategy(object): ...@@ -814,7 +814,7 @@ class DistributedStrategy(object):
"sharding_segment_strategy": "segment_broadcast_MB", "sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 32, "segment_broadcast_MB": 32,
"sharding_degree": 8, "sharding_degree": 8,
"sharding_degree": 2, "dp_degree": 2,
"gradient_merge_acc_step": 4, "gradient_merge_acc_step": 4,
} }
""" """
......
...@@ -145,6 +145,7 @@ gray_list = { ...@@ -145,6 +145,7 @@ gray_list = {
'sign', 'sign',
'cast', 'cast',
'fused_bn_add_activation', 'fused_bn_add_activation',
'c_identity',
} }
# The set of ops that don't support fp16 calculation # The set of ops that don't support fp16 calculation
......
...@@ -69,7 +69,7 @@ class TestColumnParallelLinearAPI(TestCollectiveAPIRunnerBase): ...@@ -69,7 +69,7 @@ class TestColumnParallelLinearAPI(TestCollectiveAPIRunnerBase):
axis=1, axis=1,
num_partitions=2, num_partitions=2,
weight_attr=param_attr, weight_attr=param_attr,
bias_attr=False, ) bias_attr=True, )
return [linear_out] return [linear_out]
......
...@@ -65,12 +65,12 @@ class TestRowParallelLinearAPI(TestCollectiveAPIRunnerBase): ...@@ -65,12 +65,12 @@ class TestRowParallelLinearAPI(TestCollectiveAPIRunnerBase):
linear_out = paddle.distributed.split( linear_out = paddle.distributed.split(
data, data,
size=(1000, 8), size=(1000, 16),
operation='linear', operation='linear',
axis=0, axis=0,
num_partitions=2, num_partitions=2,
weight_attr=param_attr, weight_attr=param_attr,
bias_attr=False, ) bias_attr=True, )
return [linear_out] return [linear_out]
......
...@@ -154,7 +154,10 @@ class TestDistBase(unittest.TestCase): ...@@ -154,7 +154,10 @@ class TestDistBase(unittest.TestCase):
#update environment #update environment
env0.update(envs) env0.update(envs)
env1.update(envs) env1.update(envs)
tr_cmd = "%s %s" if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
tr_cmd = "%s -m coverage run --branch -p %s"
else:
tr_cmd = "%s %s"
tr0_cmd = tr_cmd % (self._python_interp, model_file) tr0_cmd = tr_cmd % (self._python_interp, model_file)
tr1_cmd = tr_cmd % (self._python_interp, model_file) tr1_cmd = tr_cmd % (self._python_interp, model_file)
tr0_pipe = open("/tmp/tr0_err_%d.log" % os.getpid(), "w") tr0_pipe = open("/tmp/tr0_err_%d.log" % os.getpid(), "w")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册