未验证 提交 20b9be65 编写于 作者: J JZ-LIANG 提交者: GitHub

[Tensor Parallelism] split fix bug (#33015)

上级 a2a45d8d
......@@ -977,6 +977,11 @@ def _parallel_linear(x,
group=None):
"""
Parallel Linear
axis: the dimension of the parameter of the linear layer.
axis = 0: the row dimension
axis = 1: the col dimension
"""
if group is not None and not group.is_member():
return
......@@ -1008,6 +1013,12 @@ def _parallel_linear(x,
main_block = paddle.static.default_main_program().global_block()
startup_block.vars[linear.weight.name].is_distributed = True
main_block.vars[linear.weight.name].is_distributed = True
# set is_distributed for the split bias
# if a linear layer is split by row, each rank holds a complete bias, and it should be the same on every rank.
# if a linear layer is split by col, the bias is also split across the ranks, like its weight.
if axis == 1 and linear._bias_attr != False:
startup_block.vars[linear.bias.name].is_distributed = True
main_block.vars[linear.bias.name].is_distributed = True
if not gather_out: return linear_out
......
......@@ -814,7 +814,7 @@ class DistributedStrategy(object):
"sharding_segment_strategy": "segment_broadcast_MB",
"segment_broadcast_MB": 32,
"sharding_degree": 8,
"sharding_degree": 2,
"dp_degree": 2,
"gradient_merge_acc_step": 4,
}
"""
......
......@@ -145,6 +145,7 @@ gray_list = {
'sign',
'cast',
'fused_bn_add_activation',
'c_identity',
}
# The set of ops that don't support fp16 calculation
......
......@@ -69,7 +69,7 @@ class TestColumnParallelLinearAPI(TestCollectiveAPIRunnerBase):
axis=1,
num_partitions=2,
weight_attr=param_attr,
bias_attr=False, )
bias_attr=True, )
return [linear_out]
......
......@@ -65,12 +65,12 @@ class TestRowParallelLinearAPI(TestCollectiveAPIRunnerBase):
linear_out = paddle.distributed.split(
data,
size=(1000, 8),
size=(1000, 16),
operation='linear',
axis=0,
num_partitions=2,
weight_attr=param_attr,
bias_attr=False, )
bias_attr=True, )
return [linear_out]
......
......@@ -154,7 +154,10 @@ class TestDistBase(unittest.TestCase):
#update environment
env0.update(envs)
env1.update(envs)
tr_cmd = "%s %s"
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
tr_cmd = "%s -m coverage run --branch -p %s"
else:
tr_cmd = "%s %s"
tr0_cmd = tr_cmd % (self._python_interp, model_file)
tr1_cmd = tr_cmd % (self._python_interp, model_file)
tr0_pipe = open("/tmp/tr0_err_%d.log" % os.getpid(), "w")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册