提交 0d208e00 编写于 作者: Z Ziyan 提交者: Xiaoda Zhang

Model ALLTOALL as a single operator in cost model; scale the ALLTOALL,

ALLGATHER, and REDUCESCATTER with different factors; change the BETA and
GAMMA value in cost model.
上级 ae7556ff
......@@ -34,7 +34,7 @@ namespace parallel {
#define OPERATOR_TO_OPERATOR_CONNECTOR "-"
#define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0)
#define DEFAULT_COST_MODEL_ALPHA 1.0
#define DEFAULT_COST_MODEL_BETA 260.0
#define DEFAULT_COST_MODEL_BETA 400.0
#define DEFAULT_COST_MODEL_GAMMA 0.001
#define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true
#define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0
......
......@@ -23,7 +23,7 @@
namespace mindspore {
namespace parallel {
Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, const Map& out_tensor_map,
RankList dev_list) {
RankList dev_list, bool is_cost_model) {
in_tensor_map_ = tensor_layout.tensor_map();
dev_mat_ = tensor_layout.device_arrangement();
......@@ -51,6 +51,8 @@ Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, cons
for (int32_t item : map) {
map_[key++] = item;
}
is_cost_model_ = is_cost_model;
return Status::SUCCESS;
}
......@@ -130,8 +132,18 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() {
std::any_of(map_.begin(), map_.end(),
[out_dim](const RedistributionOperatorMap::value_type& a) { return a.second == out_dim; })) {
int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim);
Args args_allconcat = {cat_dim, out_dim, dev_mat_.GetDimByReverseIdx(IntToUint(out_dim))};
Args args_allsplit = {dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)), UintToInt(index), out_dim};
int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim));
if (is_cost_model_) {
int32_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim));
Args args_alltoall = {dev_mat_.GetDimByReverseIdx(IntToUint(dev_dim)), UintToInt(index), cat_dim, dev_dim,
dev_num};
if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) {
MS_LOG(ERROR) << "Insert PermuteByAxis Error!";
return Status::FAILED;
}
} else {
Args args_allconcat = {cat_dim, out_dim, dev_num};
Args args_allsplit = {dev_num, UintToInt(index), out_dim};
if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) {
MS_LOG(ERROR) << "Insert ConcatByAxis Error!";
return Status::FAILED;
......@@ -140,6 +152,7 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() {
MS_LOG(ERROR) << "Insert SplitByAxis Error!";
return Status::FAILED;
}
}
(void)map_.erase(iter++);
map_[IntToSize(cat_dim)] = NONE;
} else {
......
......@@ -40,7 +40,8 @@ class RedistributionOperatorInfer {
public:
const int NONE = -1;
explicit RedistributionOperatorInfer(bool construct_op_flag = true) : construct_op_flag_(construct_op_flag) {}
Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list);
Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list,
bool is_cost_model = false);
~RedistributionOperatorInfer() = default;
OperatorList operator_list() const { return operator_list_; }
OperatorVector operator_vector() const { return operator_vector_; }
......@@ -67,6 +68,7 @@ class RedistributionOperatorInfer {
ConstructOperator constructor_;
RankList dev_list_;
bool construct_op_flag_;
bool is_cost_model_;
};
} // namespace parallel
} // namespace mindspore
......
......@@ -40,7 +40,7 @@ Status TensorRedistribution::Init(const TensorLayout& from, const TensorLayout&
return Status::SUCCESS;
}
RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList() {
RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) {
// Step 1: Match device arrangement between from_ and to_
RedistributionLayoutTransfer layout_transfer;
Status status = layout_transfer.Init(from_, to_);
......@@ -62,7 +62,7 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
// Step 2: Infer redistribution and insert operators
RedistributionOperatorInfer operator_infer(construct_op_flag_);
if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_) == Status::FAILED) {
if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) {
MS_LOG(ERROR) << "Init operatorInfer failed!";
return nullptr;
}
......@@ -138,7 +138,7 @@ Status TensorRedistribution::InferReshape(const TensorLayout& from_layout, const
}
Status TensorRedistribution::ComputeCost() {
RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList();
RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true);
if (redistribution_oplist_ptr == nullptr) {
MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed";
return Status::FAILED;
......@@ -151,14 +151,22 @@ Status TensorRedistribution::ComputeCost() {
std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
std::string str = op.first;
if (str == PERMUTE_BY_AXIS) {
// The shape does not change after PermuteByAxis operation.
// communication cost = all_to_all + all_to_all = 2 * slice_shape
// computation cost = slice_shape
forward_comm_cost_ += prod;
backward_comm_cost_ += prod;
comm_cost_ += 2.0 * prod;
// Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
// communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR;
int32_t concat_dim = op.second[2];
if (concat_dim == 0) {
// memory cost = all_gather
computation_cost_ += prod;
memory_cost_ += prod;
} else {
// memory cost = all_gather + split + concat
int32_t dev_num = op.second[4];
computation_cost_ += (prod + prod * dev_num + prod * dev_num);
memory_cost_ += (prod * dev_num + prod * dev_num + prod);
}
} else if (str == CONCAT_BY_AXIS) {
// communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
// computation cost = before_slice_shape
......@@ -168,9 +176,9 @@ Status TensorRedistribution::ComputeCost() {
}
double dev_num = op.second[2];
// here, communication cost = all_gather + reduce_scatter
forward_comm_cost_ += prod * dev_num;
backward_comm_cost_ += prod;
comm_cost_ += prod * (dev_num + 1.0);
forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
int32_t concat_dim = op.second[0];
if (concat_dim == 0) {
// computation cost = all_gather
......
......@@ -33,6 +33,8 @@
namespace mindspore {
namespace parallel {
constexpr double ALLTOALL_SCALE_FACTOR = 2.0;
constexpr double ALLGATHER_REDUCESCATTER_SCALE_FACTOR = 0.5;
class TensorRedistribution {
public:
explicit TensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false)
......@@ -46,7 +48,7 @@ class TensorRedistribution {
keep_reshape_(keep_reshape) {}
Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list);
~TensorRedistribution() = default;
RedistributionOpListPtr InferTensorRedistributionOperatorList();
RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false);
OperatorList operator_list() const { return operator_list_; }
bool reshape_flag() const { return reshape_flag_; }
Status ComputeCost();
......
......@@ -304,7 +304,7 @@ def train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768):
def test_train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192
cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0)
cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
......@@ -651,7 +651,7 @@ def test_train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768): #
def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): #1048576 #131072 #32768 #8192
dev_num = 8
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0)
cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
set_algo_parameters(elementwise_op_strategy_follow=True)
resset_op_id()
np.random.seed(6)
......
......@@ -86,7 +86,7 @@ def test_two_matmul():
costmodel_alpha = cost_model_context.get_cost_model_context("costmodel_alpha")
assert costmodel_alpha == 1.0
costmodel_beta = cost_model_context.get_cost_model_context("costmodel_beta")
assert costmodel_beta == 260.0
assert costmodel_beta == 400.0
costmodel_gamma = cost_model_context.get_cost_model_context("costmodel_gamma")
assert costmodel_gamma == 0.001
costmodel_communi_threshold = cost_model_context.get_cost_model_context("costmodel_communi_threshold")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册