提交 be63f8b8 编写于 作者: C cristoval 提交者: ZPaC

bugfix for server shard range computation

上级 7371cedd
......@@ -149,10 +149,12 @@ void SparseOptimInfo::ComputeMean(const std::shared_ptr<std::vector<std::shared_
size_t original_row_count = input_shapes->front();
if (original_row_count > 0) {
size_t offset = 0;
if ((original_row_count % server_num) == 0) {
offset = original_row_count / server_num * rank_id;
} else {
offset = std::round((static_cast<float>(original_row_count)) / server_num) * rank_id;
std::map<int, int> rank_dims = Util::AllRankLocalShard(original_row_count, rank_id, server_num);
for (size_t i = 0; i < rank_id; i++) {
if (rank_dims.count(i) == 0) {
MS_LOG(EXCEPTION) << "No local shard number for rank " << i;
}
offset += rank_dims[i];
}
for (size_t i = 0; i < indices_size; i++) {
indices_data[i] -= offset;
......
......@@ -134,13 +134,33 @@ std::string Util::optimizer_node_name(int id) {
bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; }
int Util::LocalShard(int first_dim, int rank_id, int server_num) {
int shard_size = std::round((static_cast<float>(first_dim)) / server_num);
int remain_size = first_dim % server_num;
if (remain_size == 0 || rank_id < server_num - 1) {
return shard_size;
} else {
return first_dim - (shard_size * (server_num - 1));
std::map<int, int> shard_dims = AllRankLocalShard(first_dim, rank_id, server_num);
if (shard_dims.count(rank_id) == 0) {
MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id;
}
return shard_dims[rank_id];
}
std::map<int, int> Util::AllRankLocalShard(int first_dim, int rank_id, int server_num) {
if (rank_id >= server_num) {
MS_LOG(EXCEPTION) << "The rank ID " << rank_id << " should be less than the number of servers " << server_num;
}
std::map<int, int> shard_dims;
for (int i = 0; i < server_num; i++) {
shard_dims[i] = 0;
}
if (server_num != static_cast<int>(shard_dims.size())) {
MS_LOG(EXCEPTION) << "Inconsistent server num " << server_num << " shard dims counter size " << shard_dims.size();
}
int server_index = -1;
for (int i = 0; i < first_dim; i++) {
server_index = (server_index + 1) % server_num;
shard_dims[server_index] = shard_dims[server_index] + 1;
}
if (shard_dims.count(rank_id) == 0) {
MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id << ", total server num " << server_num;
}
return shard_dims;
}
void Util::SetRankId(int rank_id) { rank_id_ = rank_id; }
......
......@@ -39,6 +39,7 @@ class Util {
static std::string optimizer_node_name(int id);
static bool is_optimizer(std::string name);
static int LocalShard(int first_dim, int rank_id, int server_num);
static std::map<int, int> AllRankLocalShard(int first_dim, int rank_id, int server_num);
static void SetRankId(int rank_id);
static int GetRankId();
static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册