提交 0cdb1171 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!87 Take AllToAll as a virtual operator in cost model

Merge pull request !87 from gziyan/dev_alltoall
...@@ -34,7 +34,7 @@ namespace parallel { ...@@ -34,7 +34,7 @@ namespace parallel {
#define OPERATOR_TO_OPERATOR_CONNECTOR "-" #define OPERATOR_TO_OPERATOR_CONNECTOR "-"
#define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0) #define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0)
#define DEFAULT_COST_MODEL_ALPHA 1.0 #define DEFAULT_COST_MODEL_ALPHA 1.0
#define DEFAULT_COST_MODEL_BETA 260.0 #define DEFAULT_COST_MODEL_BETA 400.0
#define DEFAULT_COST_MODEL_GAMMA 0.001 #define DEFAULT_COST_MODEL_GAMMA 0.001
#define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true #define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true
#define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0 #define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, const Map& out_tensor_map,
RankList dev_list) { RankList dev_list, bool is_cost_model) {
in_tensor_map_ = tensor_layout.tensor_map(); in_tensor_map_ = tensor_layout.tensor_map();
dev_mat_ = tensor_layout.device_arrangement(); dev_mat_ = tensor_layout.device_arrangement();
...@@ -51,6 +51,8 @@ Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, cons ...@@ -51,6 +51,8 @@ Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, cons
for (int32_t item : map) { for (int32_t item : map) {
map_[key++] = item; map_[key++] = item;
} }
is_cost_model_ = is_cost_model;
return Status::SUCCESS; return Status::SUCCESS;
} }
...@@ -130,15 +132,26 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() { ...@@ -130,15 +132,26 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() {
std::any_of(map_.begin(), map_.end(), std::any_of(map_.begin(), map_.end(),
[out_dim](const RedistributionOperatorMap::value_type& a) { return a.second == out_dim; })) { [out_dim](const RedistributionOperatorMap::value_type& a) { return a.second == out_dim; })) {
int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim);
Args args_allconcat = {cat_dim, out_dim, dev_mat_.GetDimByReverseIdx(IntToUint(out_dim))}; int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim));
Args args_allsplit = {dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)), UintToInt(index), out_dim}; if (is_cost_model_) {
if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) { int32_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim));
MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; Args args_alltoall = {dev_mat_.GetDimByReverseIdx(IntToUint(dev_dim)), UintToInt(index), cat_dim, dev_dim,
return Status::FAILED; dev_num};
} if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) {
if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) { MS_LOG(ERROR) << "Insert PermuteByAxis Error!";
MS_LOG(ERROR) << "Insert SplitByAxis Error!"; return Status::FAILED;
return Status::FAILED; }
} else {
Args args_allconcat = {cat_dim, out_dim, dev_num};
Args args_allsplit = {dev_num, UintToInt(index), out_dim};
if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) {
MS_LOG(ERROR) << "Insert ConcatByAxis Error!";
return Status::FAILED;
}
if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) {
MS_LOG(ERROR) << "Insert SplitByAxis Error!";
return Status::FAILED;
}
} }
(void)map_.erase(iter++); (void)map_.erase(iter++);
map_[IntToSize(cat_dim)] = NONE; map_[IntToSize(cat_dim)] = NONE;
......
...@@ -40,7 +40,8 @@ class RedistributionOperatorInfer { ...@@ -40,7 +40,8 @@ class RedistributionOperatorInfer {
public: public:
const int NONE = -1; const int NONE = -1;
explicit RedistributionOperatorInfer(bool construct_op_flag = true) : construct_op_flag_(construct_op_flag) {} explicit RedistributionOperatorInfer(bool construct_op_flag = true) : construct_op_flag_(construct_op_flag) {}
Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list); Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list,
bool is_cost_model = false);
~RedistributionOperatorInfer() = default; ~RedistributionOperatorInfer() = default;
OperatorList operator_list() const { return operator_list_; } OperatorList operator_list() const { return operator_list_; }
OperatorVector operator_vector() const { return operator_vector_; } OperatorVector operator_vector() const { return operator_vector_; }
...@@ -67,6 +68,7 @@ class RedistributionOperatorInfer { ...@@ -67,6 +68,7 @@ class RedistributionOperatorInfer {
ConstructOperator constructor_; ConstructOperator constructor_;
RankList dev_list_; RankList dev_list_;
bool construct_op_flag_; bool construct_op_flag_;
bool is_cost_model_;
}; };
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
......
...@@ -40,7 +40,7 @@ Status TensorRedistribution::Init(const TensorLayout& from, const TensorLayout& ...@@ -40,7 +40,7 @@ Status TensorRedistribution::Init(const TensorLayout& from, const TensorLayout&
return Status::SUCCESS; return Status::SUCCESS;
} }
RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList() { RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) {
// Step 1: Match device arrangement between from_ and to_ // Step 1: Match device arrangement between from_ and to_
RedistributionLayoutTransfer layout_transfer; RedistributionLayoutTransfer layout_transfer;
Status status = layout_transfer.Init(from_, to_); Status status = layout_transfer.Init(from_, to_);
...@@ -62,7 +62,7 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL ...@@ -62,7 +62,7 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
MS_LOG(DEBUG) << "reshape to_ " << to_.ToString(); MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
// Step 2: Infer redistribution and insert operators // Step 2: Infer redistribution and insert operators
RedistributionOperatorInfer operator_infer(construct_op_flag_); RedistributionOperatorInfer operator_infer(construct_op_flag_);
if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_) == Status::FAILED) { if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) {
MS_LOG(ERROR) << "Init operatorInfer failed!"; MS_LOG(ERROR) << "Init operatorInfer failed!";
return nullptr; return nullptr;
} }
...@@ -138,7 +138,7 @@ Status TensorRedistribution::InferReshape(const TensorLayout& from_layout, const ...@@ -138,7 +138,7 @@ Status TensorRedistribution::InferReshape(const TensorLayout& from_layout, const
} }
Status TensorRedistribution::ComputeCost() { Status TensorRedistribution::ComputeCost() {
RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(); RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true);
if (redistribution_oplist_ptr == nullptr) { if (redistribution_oplist_ptr == nullptr) {
MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed"; MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed";
return Status::FAILED; return Status::FAILED;
...@@ -151,14 +151,22 @@ Status TensorRedistribution::ComputeCost() { ...@@ -151,14 +151,22 @@ Status TensorRedistribution::ComputeCost() {
std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>()); std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
std::string str = op.first; std::string str = op.first;
if (str == PERMUTE_BY_AXIS) { if (str == PERMUTE_BY_AXIS) {
// The shape does not change after PermuteByAxis operation. // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
// communication cost = all_to_all + all_to_all = 2 * slice_shape // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
// computation cost = slice_shape forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
forward_comm_cost_ += prod; backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
backward_comm_cost_ += prod; comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR;
comm_cost_ += 2.0 * prod; int32_t concat_dim = op.second[2];
computation_cost_ += prod; if (concat_dim == 0) {
memory_cost_ += prod; // memory cost = all_gather
computation_cost_ += prod;
memory_cost_ += prod;
} else {
// memory cost = all_gather + split + concat
int32_t dev_num = op.second[4];
computation_cost_ += (prod + prod * dev_num + prod * dev_num);
memory_cost_ += (prod * dev_num + prod * dev_num + prod);
}
} else if (str == CONCAT_BY_AXIS) { } else if (str == CONCAT_BY_AXIS) {
// communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
// computation cost = before_slice_shape // computation cost = before_slice_shape
...@@ -168,9 +176,9 @@ Status TensorRedistribution::ComputeCost() { ...@@ -168,9 +176,9 @@ Status TensorRedistribution::ComputeCost() {
} }
double dev_num = op.second[2]; double dev_num = op.second[2];
// here, communication cost = all_gather + reduce_scatter // here, communication cost = all_gather + reduce_scatter
forward_comm_cost_ += prod * dev_num; forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
backward_comm_cost_ += prod; backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
comm_cost_ += prod * (dev_num + 1.0); comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
int32_t concat_dim = op.second[0]; int32_t concat_dim = op.second[0];
if (concat_dim == 0) { if (concat_dim == 0) {
// computation cost = all_gather // computation cost = all_gather
......
...@@ -33,6 +33,8 @@ ...@@ -33,6 +33,8 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
constexpr double ALLTOALL_SCALE_FACTOR = 2.0;
constexpr double ALLGATHER_REDUCESCATTER_SCALE_FACTOR = 0.5;
class TensorRedistribution { class TensorRedistribution {
public: public:
explicit TensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false) explicit TensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false)
...@@ -46,7 +48,7 @@ class TensorRedistribution { ...@@ -46,7 +48,7 @@ class TensorRedistribution {
keep_reshape_(keep_reshape) {} keep_reshape_(keep_reshape) {}
Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list); Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list);
~TensorRedistribution() = default; ~TensorRedistribution() = default;
RedistributionOpListPtr InferTensorRedistributionOperatorList(); RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false);
OperatorList operator_list() const { return operator_list_; } OperatorList operator_list() const { return operator_list_; }
bool reshape_flag() const { return reshape_flag_; } bool reshape_flag() const { return reshape_flag_; }
Status ComputeCost(); Status ComputeCost();
......
...@@ -304,7 +304,7 @@ def train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): ...@@ -304,7 +304,7 @@ def train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768):
def test_train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192 def test_train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192
cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0) cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2)
cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5)
...@@ -651,7 +651,7 @@ def test_train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768): # ...@@ -651,7 +651,7 @@ def test_train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768): #
def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): #1048576 #131072 #32768 #8192 def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): #1048576 #131072 #32768 #8192
dev_num = 8 dev_num = 8
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num) context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0) cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0)
set_algo_parameters(elementwise_op_strategy_follow=True) set_algo_parameters(elementwise_op_strategy_follow=True)
resset_op_id() resset_op_id()
np.random.seed(6) np.random.seed(6)
......
...@@ -86,7 +86,7 @@ def test_two_matmul(): ...@@ -86,7 +86,7 @@ def test_two_matmul():
costmodel_alpha = cost_model_context.get_cost_model_context("costmodel_alpha") costmodel_alpha = cost_model_context.get_cost_model_context("costmodel_alpha")
assert costmodel_alpha == 1.0 assert costmodel_alpha == 1.0
costmodel_beta = cost_model_context.get_cost_model_context("costmodel_beta") costmodel_beta = cost_model_context.get_cost_model_context("costmodel_beta")
assert costmodel_beta == 260.0 assert costmodel_beta == 400.0
costmodel_gamma = cost_model_context.get_cost_model_context("costmodel_gamma") costmodel_gamma = cost_model_context.get_cost_model_context("costmodel_gamma")
assert costmodel_gamma == 0.001 assert costmodel_gamma == 0.001
costmodel_communi_threshold = cost_model_context.get_cost_model_context("costmodel_communi_threshold") costmodel_communi_threshold = cost_model_context.get_cost_model_context("costmodel_communi_threshold")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册