Unverified · Commit e0e0c0fa authored by Yuang Liu, committed by GitHub

add sync calc stream and add ut for fuse on gpu (#33580)

Parent 2d7ef7ad
......@@ -177,6 +177,7 @@ message DistributedStrategy {
optional bool tensor_parallel = 29 [ default = false ];
optional bool without_graph_optimization = 30 [ default = false ];
optional int32 fuse_grad_size_in_num = 31 [ default = 1 ];
optional bool calc_comm_same_stream = 32 [ default = false ];
optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
......
......@@ -853,6 +853,30 @@ class DistributedStrategy(object):
"WARNING: without_graph_optimization should have value of bool type"
)
@property
def _calc_comm_same_stream(self):
"""
This based on raw_program_optimizer program
Set whether use same stream for calc and comm when fuse allreduce
The default value for the calc_comm_same_stream is False
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.calc_comm_same_stream = True
"""
return self.strategy.calc_comm_same_stream
@_calc_comm_same_stream.setter
@is_strict_auto
def _calc_comm_same_stream(self, same):
if isinstance(same, bool):
self.strategy.calc_comm_same_stream = same
else:
print(
"WARNING: calc_comm_same_stream should have value of boolean type"
)
@property
def fuse_grad_size_in_num(self):
"""
......
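For reference, the following is a minimal end-to-end sketch of how the new switch is meant to be used, based on the usage comment and the unit test in this change (the network and loss are assumed to come from the user's own program):

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.without_graph_optimization = True   # use the raw program optimizer path
strategy.fuse_all_reduce_ops = True          # fuse gradients before allreduce
strategy._calc_comm_same_stream = False      # keep calc and comm on separate streams
strategy.fuse_grad_size_in_num = 8           # fuse at most 8 grads per coalesced tensor

optimizer = paddle.fluid.optimizer.Adam(0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
# optimizer.minimize(avg_cost)  # avg_cost is the loss of the user's network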
......@@ -44,6 +44,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
self.fuse_all_reduce_ops = user_defined_strategy.fuse_all_reduce_ops
if self.fuse_all_reduce_ops:
self.fuse_grad_size_in_num = user_defined_strategy.fuse_grad_size_in_num
self.calc_comm_same_stream = user_defined_strategy._calc_comm_same_stream
def _can_apply(self):
if not self.role_maker._is_collective:
......@@ -130,8 +131,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
def _transpile_main_program(self, loss):
self._insert_loss_grad_ops(loss)
if self.fuse_all_reduce_ops and core.is_compiled_with_npu():
self._calc_stream = True
if self.fuse_all_reduce_ops:
self._allreduce_fusion_program()
else:
self._insert_allreduce_ops()
......@@ -206,22 +206,30 @@ class RawProgramOptimizer(MetaOptimizerBase):
OP_ROLE_KEY: OpRole.Backward})
break
# TODO(Liu yuang): ADD CUDA allreduce_fusion function.
# This function helps reduce the input of allreduce by integrating can save communication time.
# This function helps reduce the number of allreduce by integrating op, which can save communication time.
# to use allreduce fuse, follow these codes:
# strategy = paddle.distributed.fleet.DistributedStrategy()
# strategy.without_graph_optimization = True
# strategy.fuse_all_reduce_ops = True
# strategy.calc_comm_same_stream = False
# strategy.fuse_grad_size_in_num = 8
def _allreduce_fusion_program(self):
block = self.main_program.global_block()
ring_id = self.global_ring_id
record_idx, allreduce_input_vars, allreduce_output_vars = [], [], []
block_ops = len(list(enumerate(block.ops)))
ops = list(enumerate(block.ops))
for idx, op in reversed(list(enumerate(block.ops))):
for idx, op in reversed(ops):
# we traverse the ops in reverse order
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.attr(OP_ROLE_VAR_KEY)
if len(op_role_var) == 0:
continue
assert len(op_role_var) % 2 == 0
assert len(op_role_var) % 2 == 0, "vars need to be one param var followed by one grad var, " \
"but got odd number of vars"
for i in range(0, len(op_role_var), 2):
# handle the vars of each op, one (param, grad) pair at a time
param_name = op_role_var[i]
param = block.var(param_name)
grad_name = op_role_var[i + 1]
......@@ -229,6 +237,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
if param.is_distributed:
continue
if ".cast_fp16@GRAD" in grad_name:
# when amp=True, get the fp16 param
param_name = param_name + ".cast_fp16"
if not block.has_var(param_name):
raise ValueError("op cast name error {}".format(
......@@ -236,154 +245,102 @@ class RawProgramOptimizer(MetaOptimizerBase):
else:
param = block.var(param_name)
if len(allreduce_output_vars) == 0:
allreduce_output_vars.append([grad])
allreduce_input_vars.append([param])
if self.fuse_grad_size_in_num == 1:
record_idx.append([idx, idx])
continue
record_idx.append([-2, idx])
elif len(allreduce_output_vars[
-1]) == self.fuse_grad_size_in_num:
if len(allreduce_output_vars) == 0 or \
len(allreduce_output_vars[-1]) == \
self.fuse_grad_size_in_num:
# start of the fusion, or the last group has reached the configured size
allreduce_output_vars.append([grad])
allreduce_input_vars.append([param])
if self.fuse_grad_size_in_num == 1:
record_idx.append([idx, idx])
continue
if idx != block_ops - 1:
record_idx.append([-2, idx])
# add the start and end idx to the record idx
record_idx.append([idx, idx])
else:
# Current group's size is below the config size
# append grad and param to the last group (current group)
# update the start idx to current op's idx
# Since we traverse the ops in reverse order, the idx is descending,
# so we update the first (start) element of the last entry in record_idx
allreduce_output_vars[-1].append(grad)
allreduce_input_vars[-1].append(param)
record_idx[-1][0] = idx
if record_idx[-1][0] == -2:
record_idx[-1][0] = record_idx[-1][1]
assert len(allreduce_output_vars) == len(
record_idx
), "It has different lens between the allreduce_output_vars and record_idx."
if not allreduce_output_vars or not allreduce_input_vars:
# nothing needs to be allreduced
return
self.vars = collections.OrderedDict()
index, offset_pos, pos, offset = 0, 0, 0, 0
index, pos, offset = 0, 0, 0
start, end = record_idx[index]
men_list = [end, start]
# Here we need to explain the flag. When integrating OP, we will encounter different groups of the same Op.
# Because we insert coalesce tensor in reverse ops,
# we need to use flag to record whether the current OP has been inserted into coalesce tensor。
# For example:
# [(3, 2), (2, 2), (1, 0)], (3, 2), (2, 2) using same op, but in different groups.
for idx, op in reversed(list(enumerate(block.ops))):
for idx, op in reversed(ops):
if idx == start:
pos = 0
flag = True if end == men_list[-1] else False
offset = offset_pos if flag else 0
done_output_vars, done_input_vars = self._split_fuction(
allreduce_output_vars[index], allreduce_input_vars[index])
allreduce_output_vars[index], # grad
allreduce_input_vars[index] # param
)
for id_, done_output_var in enumerate(done_output_vars):
if flag:
tmp_var = block.create_var(
name=unique_name.generate(
'FusedOutput_{}_{}'.format(start, id_ +
offset)),
dtype=done_output_var[0].dtype,
persistable=False,
stop_gradient=True)
self.vars['FusedOutput_{}_{}'.format(start, id_ +
offset)] = tmp_var
tmp_var = block.create_var(
name=unique_name.generate('FusedOutput_{}'.format(
done_output_var[0].name)),
dtype=done_output_var[0].dtype,
persistable=False,
stop_gradient=True)
self.vars['FusedOutput_{}'.format(done_output_var[0]
.name)] = tmp_var
block._insert_op(
idx + id_ + offset,
type="coalesce_tensor",
inputs={"Input": done_input_vars[id_]},
outputs={
"Output": done_output_var,
"FusedOutput": tmp_var
},
attrs={
"copy_data": False,
"use_align": True,
"dtype": done_output_var[0].dtype
})
pos += 1
else:
tmp_var = block.create_var(
name=unique_name.generate(
'FusedOutput_{}_{}'.format(start, id_)),
dtype=done_output_var[0].dtype,
persistable=False,
stop_gradient=True)
self.vars['FusedOutput_{}_{}'.format(start,
id_)] = tmp_var
block._insert_op(
idx + id_,
type="coalesce_tensor",
inputs={"Input": done_input_vars[id_]},
outputs={
"Output": done_output_var,
"FusedOutput": tmp_var
},
attrs={
"copy_data": False,
"use_align": True,
"dtype": done_output_var[0].dtype,
OP_ROLE_KEY: OpRole.Backward
})
pos += 1
block._insert_op(
idx + id_,
type="coalesce_tensor",
inputs={"Input": done_input_vars[id_]},
outputs={
"Output": done_output_var,
"FusedOutput": tmp_var
},
attrs={
"copy_data": False,
"use_align": True,
"dtype": done_output_var[0].dtype
})
pos += 1
offset_pos = pos
# TODO(Liu yuang): ADD CUDA and NPU's EVENT and c_allreduce_sum.
for id_ in range(len(done_output_vars)):
if flag:
block._insert_op(
end + id_ + pos + 1,
type='c_allreduce_sum',
inputs={
'X': self.vars['FusedOutput_{}_{}'.format(
start, id_ + offset)]
},
outputs={
'Out': self.vars['FusedOutput_{}_{}'.format(
start, id_ + offset)]
},
attrs={
'ring_id': ring_id,
'use_calc_stream': True
if self._calc_stream else False,
OP_ROLE_KEY: OpRole.Backward
})
else:
x = self.vars['FusedOutput_{}'.format(done_output_vars[id_][
0].name)]
out = x
# NOTE: there is still room for optimization if an EVENT is used instead of a sync op
if not self.calc_comm_same_stream:
# need a sync if the calc and comm streams are not the same
block._insert_op(
end + id_ + pos + 1,
type='c_allreduce_sum',
inputs={
'X': self.vars['FusedOutput_{}_{}'.format(start,
id_)]
},
outputs={
'Out': self.vars['FusedOutput_{}_{}'.format(
start, id_)]
},
attrs={
'ring_id': ring_id,
'use_calc_stream': True
if self._calc_stream else False,
OP_ROLE_KEY: OpRole.Backward
})
type='c_sync_calc_stream',
inputs={'X': x},
outputs={'Out': out},
attrs={OP_ROLE_KEY: OpRole.Backward})
block._insert_op(
end + id_ + pos + 1
if self.calc_comm_same_stream else end + id_ + pos + 2,
type='c_allreduce_sum',
inputs={'X': x},
outputs={'Out': out},
attrs={
'ring_id': ring_id,
'use_calc_stream': self.calc_comm_same_stream,
OP_ROLE_KEY: OpRole.Backward
})
index += 1
men_list.append(end)
men_list.append(start)
if len(record_idx) == index:
start = end = -1
continue
break
start, end = record_idx[index]
if not self._calc_stream:
if not self.calc_comm_same_stream:
# need a sync if the calc and comm streams are not the same
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
block._insert_op(
......@@ -397,34 +354,50 @@ class RawProgramOptimizer(MetaOptimizerBase):
})
break
# Integrate grads of the same type to form a combination. If skip_comb is selected, will return grads of the same group.
# Integrate grads of the same type to form a combination.
# If combination is selected, grads of the same type will be returned grouped together.
# For example: [(fp16, fp16), (fp32), (fp16)] -> [(fp16, fp16, fp16), (fp32)]
def _split_fuction(self,
allreduce_output_vars,
allreduce_input_vars,
skip_comb=True):
combination=True):
input_vars, final_input_vars, output_vars, final_output_vars = [], [], [], []
if len(allreduce_output_vars) - 1 == 0:
if len(allreduce_output_vars) == 1:
# only one var to handle
final_output_vars.append(allreduce_output_vars)
final_input_vars.append(allreduce_input_vars)
return final_output_vars, final_input_vars
for idx in range(len(allreduce_input_vars) - 1):
# the last var needs to be handled differently
if allreduce_input_vars[idx].dtype == allreduce_input_vars[idx +
1].dtype:
# if the current var and the next var are of the same type,
# append current var to input_vars
input_vars.append(allreduce_input_vars[idx])
if idx == len(allreduce_input_vars) - 2:
# if the current var is the second-to-last var,
# append the last var to input_vars
# and update the final_input_vars
input_vars.append(allreduce_input_vars[idx + 1])
final_input_vars.append(input_vars)
else:
# the current var and the next var are of different types,
# append current var to input_vars
# update the final_input_vars
# reset input_vars to receive a new type
input_vars.append(allreduce_input_vars[idx])
final_input_vars.append(input_vars)
input_vars = []
if idx == len(allreduce_input_vars) - 2:
# if the current var is the second-to-last var,
# append the last var to a fresh input_vars since they are of different types
# and update the final_input_vars
input_vars.append(allreduce_input_vars[idx + 1])
final_input_vars.append(input_vars)
for idx in range(len(allreduce_output_vars) - 1):
# the procedure for the output vars is the same as that for the input vars
if allreduce_output_vars[idx].dtype == allreduce_output_vars[
idx + 1].dtype:
output_vars.append(allreduce_output_vars[idx])
......@@ -438,10 +411,14 @@ class RawProgramOptimizer(MetaOptimizerBase):
if idx == len(allreduce_output_vars) - 2:
output_vars.append(allreduce_output_vars[idx + 1])
final_output_vars.append(output_vars)
if skip_comb:
# at this point, all vars within each group of final_input_vars and final_output_vars are of the same type
if combination:
input_fp16_vars, input_fp32_vars, output_fp16_vars, output_fp32_vars = [], [], [], []
for final_input_var in final_input_vars:
if final_input_var[0].dtype == core.VarDesc.VarType.FP16:
# extend the group
input_fp16_vars.extend(final_input_var)
else:
input_fp32_vars.extend(final_input_var)
......@@ -451,6 +428,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
output_fp16_vars.extend(final_output_var)
else:
output_fp32_vars.extend(final_output_var)
final_output_vars, final_input_vars = [], []
if output_fp16_vars:
final_output_vars.append(output_fp16_vars)
......
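To make the grouping performed by _allreduce_fusion_program and _split_fuction easier to follow, here is a small standalone sketch (a hypothetical helper, not part of this change): it first cuts the gradient list into groups of at most fuse_grad_size_in_num, then regroups each group by dtype so fp16 and fp32 gradients are coalesced separately.

# Standalone sketch (not Paddle code): grads are (name, dtype) pairs.
def group_for_fusion(grads, fuse_grad_size_in_num):
    # step 1: cut into consecutive groups of at most fuse_grad_size_in_num vars
    groups = []
    for g in grads:
        if not groups or len(groups[-1]) == fuse_grad_size_in_num:
            groups.append([g])
        else:
            groups[-1].append(g)
    # step 2: inside each group, gather grads of the same dtype together,
    # e.g. [(fp16, fp16), (fp32), (fp16)] -> [(fp16, fp16, fp16), (fp32)]
    fused = []
    for group in groups:
        fp16 = [g for g in group if g[1] == "fp16"]
        fp32 = [g for g in group if g[1] == "fp32"]
        fused.append([sub for sub in (fp16, fp32) if sub])
    return fused

grads = [("w0@GRAD", "fp16"), ("w1@GRAD", "fp16"),
         ("w2@GRAD", "fp32"), ("w3@GRAD", "fp16")]
print(group_for_fusion(grads, fuse_grad_size_in_num=8))
# one group of four grads, split into an fp16 sub-group and an fp32 sub-group

Each coalesce_tensor op then packs one such dtype-homogeneous sub-group into a single FusedOutput buffer, and one c_allreduce_sum is issued per fused buffer.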
......@@ -718,6 +718,7 @@ if (WITH_DISTRIBUTE)
set_tests_properties(test_dist_fleet_sparse_embedding_ctr PROPERTIES TIMEOUT 200)
set_tests_properties(test_dist_fleet_infer PROPERTIES TIMEOUT 200)
set_tests_properties(test_dist_fleet_raw_program_optimizer PROPERTIES TIMEOUT 120)
set_tests_properties(test_dist_fleet_raw_program_optimizer_fuse_allreduce PROPERTIES TIMEOUT 60)
endif()
if (WITH_DISTRIBUTE AND NOT APPLE)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from test_dist_base import TestDistRunnerBase, runtime_main
import unittest
import paddle
import os
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import numpy as np
from functools import reduce
import paddle.fluid as fluid
paddle.enable_static()
DTYPE = "float32"
paddle.dataset.mnist.fetch()
# Fix seed for test
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
def cnn_model(data):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=data,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
act="relu",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.01)))
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
act="relu",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.01)))
SIZE = 10
input_shape = conv_pool_2.shape
param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE]
scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5
predict = fluid.layers.fc(
input=conv_pool_2,
size=SIZE,
act="softmax",
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01)))
return predict
class TestFleetMetaOptimizerFuseAllReducePrecision(TestDistRunnerBase):
def get_model(self, batch_size=2, single_device=False):
# Input data
images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# Train program
predict = cnn_model(images)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
# Evaluator
batch_size_tensor = fluid.layers.create_tensor(dtype='int64')
batch_acc = fluid.layers.accuracy(
input=predict, label=label, total=batch_size_tensor)
test_program = fluid.default_main_program().clone(for_test=True)
# Reader
train_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
optimizer = paddle.fluid.optimizer.Adam(0.01)
if single_device:
optimizer.minimize(avg_cost)
else:
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.without_graph_optimization = True
strategy.fuse_all_reduce_ops = True
strategy._calc_comm_same_stream = False
strategy.fuse_grad_size_in_num = 8
optimizer = fleet.distributed_optimizer(
optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
return test_program, avg_cost, train_reader, test_reader, batch_acc, predict
if __name__ == "__main__":
runtime_main(TestFleetMetaOptimizerFuseAllReducePrecision)
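As a rough sanity check (a sketch, not part of the test), one can inspect the ops that the fusion pass inserts into the main program after minimize() when fuse_all_reduce_ops=True and _calc_comm_same_stream=False:

# Sketch: the transpiled main program should now contain the fused allreduce
# pattern inserted by _allreduce_fusion_program.
op_types = [op.type for op in fluid.default_main_program().global_block().ops]
assert "coalesce_tensor" in op_types       # grads packed into fused buffers
assert "c_allreduce_sum" in op_types       # one allreduce per fused buffer
assert "c_sync_calc_stream" in op_types    # inserted because the streams differ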
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from test_dist_base import TestDistBase
import paddle
import os
paddle.enable_static()
flag_name = os.path.splitext(__file__)[0]
class TestFleetMetaOptimizerAllReduceFusePrecision(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
self._use_reader_alloc = False
self._nccl2_mode = True
self._nccl2_reduce_layer = True
self._use_fleet_api = True
self._use_fleet_api_20 = True
def test_dist_train(self):
import paddle.fluid as fluid
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"dist_fleet_raw_program_optimizer_fuse_allreduce.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == '__main__':
unittest.main()