diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py
index 4f3a6f4768933d90782445edbc74f4f446a15a9b..5775a734c870690613d474f279ba2e20b720b759 100644
--- a/python/paddle/distributed/collective.py
+++ b/python/paddle/distributed/collective.py
@@ -1009,16 +1009,18 @@ def _parallel_linear(x,
         name=name)
 
     linear_out = linear(x)
-    startup_block = paddle.static.default_startup_program().global_block()
-    main_block = paddle.static.default_main_program().global_block()
-    startup_block.vars[linear.weight.name].is_distributed = True
-    main_block.vars[linear.weight.name].is_distributed = True
+    startup_block = paddle.static.default_startup_program().current_block()
+    main_block = paddle.static.default_main_program().current_block()
+    startup_block._find_var_recursive(linear.weight.name).is_distributed = True
+    main_block._find_var_recursive(linear.weight.name).is_distributed = True
+
     # set is_distributed for splited bias
     # if a linear layer is splited by row, each rank would hold a complete bias and they should be the same in each rank.
     # if a linear layer is splited by col, the bias would also be split into each rank as its weight
     if axis == 1 and linear._bias_attr != False:
-        startup_block.vars[linear.bias.name].is_distributed = True
-        main_block.vars[linear.bias.name].is_distributed = True
+        startup_block._find_var_recursive(
+            linear.bias.name).is_distributed = True
+        main_block._find_var_recursive(linear.bias.name).is_distributed = True
 
     if not gather_out:
         return linear_out
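
Note (not part of the patch): a minimal static-graph sketch, assuming Paddle 2.x, of the lookup behavior the change relies on. `Block.vars` only holds variables created in that exact block, whereas `Block._find_var_recursive` also searches ancestor blocks, so a parameter created in an enclosing block is still found when `current_block()` happens to be a sub-block. The variable and program names below are illustrative only.

# Illustrative sketch; assumes Paddle 2.x with static graph mode enabled.
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[None, 8], dtype="float32")
    linear = paddle.nn.Linear(8, 4)
    y = linear(x)

    block = main_prog.current_block()
    # Recursive lookup: finds the parameter even if it was created in an
    # ancestor block, and returns None (instead of raising KeyError) when
    # the name is absent, unlike a direct block.vars[name] access.
    weight_var = block._find_var_recursive(linear.weight.name)
    if weight_var is not None:
        weight_var.is_distributed = True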