未验证 提交 d5cc7bff 编写于 作者: 李季 提交者: GitHub

update mp (#33194)

* update mp
上级 fdbdef0e
...@@ -1009,16 +1009,18 @@ def _parallel_linear(x, ...@@ -1009,16 +1009,18 @@ def _parallel_linear(x,
name=name) name=name)
linear_out = linear(x) linear_out = linear(x)
startup_block = paddle.static.default_startup_program().global_block() startup_block = paddle.static.default_startup_program().current_block()
main_block = paddle.static.default_main_program().global_block() main_block = paddle.static.default_main_program().current_block()
startup_block.vars[linear.weight.name].is_distributed = True startup_block._find_var_recursive(linear.weight.name).is_distributed = True
main_block.vars[linear.weight.name].is_distributed = True main_block._find_var_recursive(linear.weight.name).is_distributed = True
# set is_distributed for splited bias # set is_distributed for splited bias
# if a linear layer is splited by row, each rank would hold a complete bias and they should be the same in each rank. # if a linear layer is splited by row, each rank would hold a complete bias and they should be the same in each rank.
# if a linear layer is splited by col, the bias would also be split into each rank as its weight # if a linear layer is splited by col, the bias would also be split into each rank as its weight
if axis == 1 and linear._bias_attr != False: if axis == 1 and linear._bias_attr != False:
startup_block.vars[linear.bias.name].is_distributed = True startup_block._find_var_recursive(
main_block.vars[linear.bias.name].is_distributed = True linear.bias.name).is_distributed = True
main_block._find_var_recursive(linear.bias.name).is_distributed = True
if not gather_out: return linear_out if not gather_out: return linear_out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册