Unverified commit 89ce6db8, authored by xiayanming, committed by GitHub

[Auto Parallel] Add general gradient merge pass to support auto parallel (#38259)

* [Auto Parallel] add gradient merge pass

* fix ci issue

* fix ci issue

* fix ci issue

* fix ci issue

* fix ci issue

* fix ci issue

* fix ci issue

* fix ci issue

* fix ci issue

* fix pr review

* fix pr review

* fix pr review

* fix pr review

* fix pr review

* fix pr review
Parent 8898dce1
@@ -131,6 +131,11 @@ class DistributedContext:
         else:
             return None
 
+    def del_dist_op_for_program(self, serial_tensor):
+        serial_tensor_id = serial_tensor.desc.id()
+        if self._dist_ops_for_program.get(serial_tensor_id, None):
+            del self._dist_ops_for_program[serial_tensor_id]
+
     def get_dist_op_for_graph(self, serial_op_node):
         serial_op_node_id = serial_op_node.id()
         return self._dist_ops_for_graph.get(serial_op_node_id, None)
...
@@ -47,6 +47,7 @@ from .mapper import mapping
 from .dist_op import DistributedOperator
 from .dist_tensor import DistributedTensor
 from .planner import Planner
+from paddle.distributed.passes import new_pass, PassContext
 
 _logger = get_logger(logging.INFO)
@@ -78,6 +79,8 @@ class AutoParallelizer:
             self._enable_auto_mapping = False
         else:
             self._enable_auto_mapping = True
+        self._pass_context = PassContext()
+
         self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
         self._need_rank_mapping = True if self._need_rank_mapping and \
             self._need_rank_mapping.lower() == 'true' else False
@@ -164,6 +167,15 @@ class AutoParallelizer:
             auto_parallel_sharding_pass.apply(
                 [main_program], [startup_program], self._pass_context)
 
+        if self._dist_strategy.gradient_merge:
+            config = copy.deepcopy(self._dist_strategy.gradient_merge_configs)
+            config["dist_context"] = self._dist_context
+            config["params_grads"] = params_grads
+            auto_parallel_gradient_merge_pass = new_pass(
+                "auto_parallel_gradient_merge_pass", config)
+            auto_parallel_gradient_merge_pass.apply(
+                [main_program], [startup_program], self._pass_context)
+
     def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
         completed_main_program = None
         serial_main_program = self._main_program.clone()
@@ -204,6 +216,7 @@ class AutoParallelizer:
         make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)
         reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)
+
         self._apply_post_optimization_passed(dist_main_prog, dist_startup_prog,
                                              rank, dist_params_grads)
         g_process_group_map = None
...
@@ -14,6 +14,7 @@
 from .pass_base import new_pass, PassManager, PassContext
 from .fuse_all_reduce import *
+from .auto_parallel_gradient_merge import *
 from .auto_parallel_sharding import *
 from .cpp_pass import *
...
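
Taken together, the hunks above register the new pass and apply it after sharding whenever `gradient_merge` is enabled in the distributed strategy. A minimal standalone sketch of the same flow (illustrative only, not part of the commit; modeled on the unit test added below, with `fc` and `SGDOptimizer` standing in for any static-graph network and optimizer):

import paddle
import paddle.static as static
from paddle.distributed.passes import new_pass, PassContext

paddle.enable_static()
main_prog, startup_prog = static.Program(), static.Program()
with static.program_guard(main_prog, startup_prog):
    x = static.data(name="x", shape=[8, 16], dtype="float32")
    loss = paddle.mean(paddle.static.nn.fc(x, size=1))
    opt = paddle.fluid.optimizer.SGDOptimizer(learning_rate=0.01)
    _, params_grads = opt.minimize(loss, startup_prog)

# accumulate 4 micro-steps, average, then run the optimizer ops once
config = {"k_steps": 4, "avg": True, "params_grads": params_grads}
gm_pass = new_pass("auto_parallel_gradient_merge_pass", config)
gm_pass.apply([main_prog], [startup_prog], PassContext())

The new file below implements the pass itself.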
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from collections import OrderedDict
from typing import List, Tuple, Dict, Any
import paddle
from paddle.framework import core
from paddle.fluid.framework import program_guard, device_guard
from paddle.fluid import unique_name, layers
from paddle.fluid.clip import append_gradient_clip_ops
from .pass_base import PassBase, PassType, register_pass
def _is_the_backward_op(op):
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OpRole = core.op_proto_and_checker_maker.OpRole
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)
def _is_the_optimizer_op(op):
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
OpRole = core.op_proto_and_checker_maker.OpRole
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Optimize)
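
# OpRole is a bitmask attribute stamped on every op, so the bitwise-and
# checks above also match combined roles such as Backward | Loss.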
def _remove_and_get_optimizer_op(main_program, dist_context):
    # 1. create a temp block
    # 2. move the optimizer ops from the global block to the temp block
    # 3. delete the moved ops from dist_context
from paddle.distributed.fleet.meta_optimizers.common import OpRole
main_block = main_program.global_block()
temp_block = main_program._create_block()
removed_op_idx = []
optimize_ops_desc = []
for idx, op in enumerate(main_block.ops):
if _is_the_optimizer_op(op):
# append optimizer op to tmp block
new_op_desc = temp_block.desc.append_op()
new_op_desc.copy_from(op.desc)
optimize_ops_desc.append(new_op_desc)
removed_op_idx.append(idx)
# del op from dist_context
if dist_context:
dist_context.del_dist_op_for_program(op)
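    # remove the moved ops from the global block back-to-front so that the
    # remaining indices stay valid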
for idx in removed_op_idx[::-1]:
main_block._remove_op(idx)
return optimize_ops_desc
def _remove_op_role_var(param, grad):
op_maker = core.op_proto_and_checker_maker
op = grad.op
assert _is_the_backward_op(op), \
'grad.op={} is not the backward op which produces the grad={}' \
.format(op, grad.name)
if op.has_attr(op_maker.kOpRoleVarAttrName()):
op._remove_attr(op_maker.kOpRoleVarAttrName())
def _get_gm_cond_var(main_program, k_steps):
main_block = main_program.global_block()
# Add const var
k_step_var = layers.create_global_var(
name="gradient_merge_k",
shape=[1],
value=int(k_steps),
dtype='int32',
persistable=True,
force_cpu=True)
zero_var = layers.create_global_var(
name="gradient_merge_zero",
shape=[1],
value=int(0),
dtype='int32',
persistable=True,
force_cpu=True)
# Add step var & cond var
step_var = layers.create_global_var(
name="gradient_merge_step",
shape=[1],
value=int(0),
dtype='int32',
persistable=True,
force_cpu=True)
cond_var = layers.create_global_var(
name="gradient_merge_cond",
shape=[1],
value=bool(0),
dtype='bool',
persistable=False,
force_cpu=True)
with device_guard("cpu"):
# step_var = (step_var + 1) % k_step
layers.increment(x=step_var, value=1.0, in_place=True)
main_block.append_op(
type='elementwise_mod',
inputs={'X': step_var,
'Y': k_step_var},
outputs={'Out': step_var},
attrs={'axis': -1,
'use_mkldnn': False})
# cond_var = (step_var == 0)
main_block.append_op(
type='equal',
inputs={'X': step_var,
'Y': zero_var},
outputs={'Out': cond_var})
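    # Together with the increment above, the two appended ops implement
    #     step = (step + 1) % k_steps
    #     cond = (step == 0)
    # so cond_var becomes True exactly once every k_steps iterations.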
return cond_var
def _append_gradient_merge_backward_op(
main_program,
startup_program,
params_grads: List[Tuple[Any, Any]],
cond_var_name: str) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]:
main_block = main_program.global_block()
startup_block = startup_program.global_block()
# step1: remove grad.op's op_role_var
for param, grad in params_grads:
assert (
param.type != core.VarDesc.VarType.SELECTED_ROWS
), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now"
_remove_op_role_var(param, grad)
param_to_gradient_merge = {}
new_params_to_grads = []
# step2: create gradient_merge var and init with 0
for param, grad in params_grads:
param_name = param.name
param_var = main_block.var(param_name)
assert (param_var is not None)
gradient_merge_var = main_block.create_var(
name=param_name + "@GRAD@GradientMerge",
shape=param_var.shape,
dtype=param_var.dtype,
persistable=True)
param_to_gradient_merge[param_name] = gradient_merge_var
startup_gradient_merge_var = startup_block.create_var(
name=param_name + "@GRAD@GradientMerge",
shape=param_var.shape,
dtype=param_var.dtype,
persistable=True)
startup_block.append_op(
type="fill_constant",
outputs={"Out": startup_gradient_merge_var},
attrs={
"shape": param_var.shape,
"dtype": param_var.dtype,
"value": float(0),
})
# grad_merge += grad
new_grad_op = main_block.append_op(
type="elementwise_add",
inputs={'X': grad,
'Y': gradient_merge_var},
outputs={'Out': gradient_merge_var},
attrs={'axis': -1,
'use_mkldnn': False})
new_params_to_grads.append([param, gradient_merge_var])
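
    # every parameter now has a persistable "<param>@GRAD@GradientMerge"
    # accumulator, zero-filled once in the startup program and updated with
    # accumulator += grad after each backward pass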
return new_params_to_grads, param_to_gradient_merge
def _create_cond_block_and_update_optimizer(
main_program,
cond_var,
new_params_to_grads: List[Tuple[Any, Any]],
param_to_gradient_merge: Dict[str, Any],
optimize_ops_desc: List[Any],
k_steps,
avg):
def true_apply_gradient():
cur_block_idx = main_program.current_block_idx
cur_block = main_program.current_block()
        # cur_block serves as its own forward_block and backward_block
cur_block._set_forward_block_idx(cur_block_idx)
op_maker = core.op_proto_and_checker_maker
if avg:
for param, new_grad in new_params_to_grads:
# grad /= k_steps
cur_block.append_op(
type='scale',
inputs={'X': new_grad},
outputs={'Out': new_grad},
attrs={
'scale': 1.0 / k_steps,
'bias': 0.0,
'bias_after_scale': False
})
new_grad.op._set_attr(op_maker.kOpRoleAttrName(),
op_maker.OpRole.Optimize)
# append optimizer ops
for op_desc in optimize_ops_desc:
new_op_desc = cur_block.desc.append_op()
new_op_desc.copy_from(op_desc)
            # update input/output; note: new_params_to_grads is a list of
            # (param, merged_grad) pairs, so these name lookups never match
            # and this block is effectively a no-op; the real redirection of
            # the "Grad" input happens below
            for input_name in new_op_desc.input_arg_names():
                if input_name in new_params_to_grads:
                    new_op_desc._rename_input(input_name,
                                              new_params_to_grads[input_name])

            for output_name in new_op_desc.output_arg_names():
                if output_name in new_params_to_grads:
                    new_op_desc._rename_output(output_name,
                                               new_params_to_grads[output_name])
# remove op_role_var
if new_op_desc.has_attr(op_maker.kOpRoleVarAttrName()):
new_op_desc.remove_attr(op_maker.kOpRoleVarAttrName())
            # redirect the optimizer op's "Grad" input to the merged gradient
if new_op_desc.input("Grad"):
grad_value = new_op_desc.input("Grad")[0]
# TODO FIXME(xym) support fp16
grad_merge_value = grad_value + '@GradientMerge'
new_op_desc.set_input("Grad", [grad_merge_value])
main_program.global_block()._sync_with_cpp()
cur_block._sync_with_cpp()
# clear gradient_merge_vars
for param, new_grad in new_params_to_grads:
layers.fill_constant(
shape=new_grad.shape,
dtype=new_grad.dtype,
value=0.0,
out=new_grad)
new_grad.op._set_attr(op_maker.kOpRoleAttrName(),
op_maker.OpRole.Optimize)
layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None)
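
# The conditional block above fires once every k_steps iterations: it
# optionally scales each accumulated gradient by 1 / k_steps, replays the
# saved optimizer ops against the "@GradientMerge" variables, and then
# zero-fills the accumulators for the next round.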
def parse_program(main_program, startup_program, params_grads, k_steps, avg,
dist_context):
# 1 create gradient_merge_cond
cond_var = _get_gm_cond_var(main_program, k_steps)
# 2 remove optimizer_op from main_program
optimize_ops_desc = _remove_and_get_optimizer_op(main_program, dist_context)
# back to block 0
main_program._rollback()
# 3 append gradient merge backward op to main_program
new_params_to_grads, param_to_gradient_merge = _append_gradient_merge_backward_op(
main_program, startup_program, params_grads, cond_var.name)
# 4 create ConditionalBlock and append gradient merge optimizer ops
_create_cond_block_and_update_optimizer(
main_program, cond_var, new_params_to_grads, param_to_gradient_merge,
optimize_ops_desc, k_steps, avg)
@register_pass("auto_parallel_gradient_merge_pass")
class GradientMergePass(PassBase):
def __init__(self):
super(GradientMergePass, self).__init__()
self.set_attr("k_steps", -1)
self.set_attr("avg", True)
self.set_attr("inner_optimizer", None)
def _check_self(self):
if self.get_attr("k_steps") < 1:
return False
return True
def _check_conflict(self, other_pass):
return True
def _type(self):
return PassType.COMM_OPT
def _apply_single_impl(self, main_program, startup_program, context):
k_steps = self.get_attr("k_steps", -1)
avg = self.get_attr("avg", False)
dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
with paddle.static.program_guard(main_program, startup_program):
parse_program(main_program, startup_program, params_grads, k_steps,
avg, dist_context)
main_program._sync_with_cpp()
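
For intuition, with avg=True the pass's update is arithmetically the same as one optimizer step on the mean of the k_steps micro-batch gradients. A tiny NumPy check of that equivalence (illustrative only, not part of the commit):

import numpy as np

np.random.seed(0)
lr, k_steps = 0.01, 4
micro_grads = [np.random.randn(3) for _ in range(k_steps)]

# what the pass builds: per-step elementwise_add into the accumulator,
# then a scale(1 / k_steps) inside the conditional block before SGD
g_merge = np.sum(micro_grads, axis=0)
update_merged = lr * (g_merge / k_steps)

# reference: a single SGD step on the mean gradient of the large batch
update_reference = lr * np.mean(micro_grads, axis=0)

assert np.allclose(update_merged, update_reference)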
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import unittest
import random
import numpy as np
import os
import shutil
import logging
import paddle
import paddle.nn as nn
import paddle.utils as utils
import paddle.static as static
import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.distributed.passes import new_pass, PassManager, PassContext
import paddle.distributed.fleet as fleet
from dist_pass_test_base import DistPassTestBase
logging.getLogger().setLevel(logging.INFO)
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
#np.set_printoptions(suppress=True)
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=128,
intermediate_size=4 * 128,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
np.random.seed(2021)
arr0 = np.random.normal(0, 0.02, size=(d_model, dim_feedforward))
arr1 = np.random.normal(0, 0.02, size=(dim_feedforward, d_model))
weight_attr0 = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr0))
weight_attr1 = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr1))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr0, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr1, bias_attr=bias_attr)
self.linear2 = nn.Linear(
d_model, dim_feedforward, weight_attr0, bias_attr=bias_attr)
self.linear3 = nn.Linear(
dim_feedforward, d_model, weight_attr1, bias_attr=bias_attr)
self.linear4 = nn.Linear(
d_model, dim_feedforward, weight_attr0, bias_attr=bias_attr)
self.linear5 = nn.Linear(
dim_feedforward, d_model, weight_attr1, bias_attr=bias_attr)
self.norm0 = nn.LayerNorm(d_model, epsilon=1e-5)
self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input):
out = self.norm0(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.norm1(out)
out = self.linear2(out)
out = F.gelu(out, approximate=True)
out = self.linear3(out)
out = self.norm2(out)
out = self.linear4(out)
out = F.gelu(out, approximate=True)
out = self.linear5(out)
return out
def mlp_forward(input, label, hidden_size):
if _global_parallel_strategy == "dp":
auto.shard_tensor(
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
predict = mlp(input)
error_cost = paddle.nn.functional.square_error_cost(predict, label)
loss = paddle.mean(error_cost)
return loss
class TestGradientMergePass(DistPassTestBase):
def init(self):
self._params_grads = None
self._config = {"k_steps": 4, "avg": True}
def apply_passes(self, main_prog, startup_prog):
self._config["params_grads"] = self._params_grads
pass_context = PassContext()
auto_parallel_gradient_merge_pass = new_pass(
"auto_parallel_gradient_merge_pass", self._config)
auto_parallel_gradient_merge_pass.apply([main_prog], [startup_prog],
pass_context)
def test_result(self):
no_pass_rets = self._distributed_launch(
model=None,
apply_pass=False,
gpus=[0],
gradient_merge=False,
batch_size=32,
max_step=2)
pass_rets = self._distributed_launch(
model=None,
apply_pass=True,
gpus=[0],
gradient_merge=True,
batch_size=8,
max_step=8)
        # with k_steps=4, each group of 4 micro-batch (size 8) losses
        # corresponds to one size-32 batch; average them before comparing
avg_loss = 0
pass_avg_ret_list = []
for i, pass_ret in enumerate(pass_rets[0]):
if (i + 1) % 4 == 0:
avg_loss += pass_ret[0]
pass_avg_ret_list.append([avg_loss / 4])
avg_loss = 0
else:
avg_loss += pass_ret[0]
for no_pass_ret, pass_ret in zip(no_pass_rets[0], pass_avg_ret_list):
print(f"no_pass_ret={no_pass_ret}, pass_ret={pass_ret}")
self.assertTrue(
np.isclose(
no_pass_ret,
pass_ret,
rtol=self.rtol,
atol=self.atol,
equal_nan=self.equal_nan))
def get_model(self, place, gradient_merge, batch_size, max_step):
paddle.seed(2021)
random.seed(2021)
np.random.seed(2021)
hidden_size = 128
global _global_parallel_strategy
global _global_process_mesh
world_size = paddle.distributed.get_world_size()
if world_size == 1:
_global_parallel_strategy = "dp"
_global_process_mesh = auto.ProcessMesh([0])
elif world_size == 2:
_global_parallel_strategy = "dp"
_global_process_mesh = auto.ProcessMesh([0, 1])
train_program = static.Program()
startup_program = static.Program()
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
#if gradient_merge:
# dist_strategy.gradient_merge = True
# dist_strategy.gradient_merge_configs = {"k_steps": 4, "avg": True}
fleet.init(is_collective=True, strategy=dist_strategy)
with static.program_guard(train_program, startup_program), \
utils.unique_name.guard():
input = static.data(
name="input", shape=[batch_size, hidden_size], dtype='float32')
label = static.data(
name="label", shape=[batch_size, 1], dtype='float32')
input.stop_gradient = False
loss = mlp_forward(input, label, hidden_size)
optimizer = paddle.fluid.optimizer.SGDOptimizer(learning_rate=0.01)
#optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer)
_, self._params_grads, dist_startup_prog, dist_main_prog = optimizer.minimize(
loss, startup_program)
input_data = np.random.random(size=(128, hidden_size)).astype('float32')
label_data = np.random.random(size=(128, 1)).astype('float32')
def reader():
for i in range(max_step):
x_data = input_data[i * batch_size:(i + 1) * batch_size, :]
y_data = label_data[i * batch_size:(i + 1) * batch_size, :]
yield x_data, y_data
return dist_main_prog, dist_startup_prog, [input, label], [loss], reader
if __name__ == "__main__":
unittest.main()