Commit 4094c468 authored by kswang

Use multiple threads for sparse gradient reduction

Parent: faa1084b
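The heart of the change is the work-partitioning arithmetic in the new RunMultiThreadReduceSparseGradient: the sorted gradient slices are split across at most 24 threads using a ceiling-division stride, the thread count is then recomputed so no thread receives an empty range, and the last thread absorbs the remainder. Below is a minimal, self-contained sketch of just that partition logic; PartitionSlices is a hypothetical helper written for illustration, whereas the commit inlines this arithmetic directly in the function.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Returns [start, end) ranges over slice_count items, one range per thread.
std::vector<std::pair<size_t, size_t>> PartitionSlices(size_t slice_count, size_t max_threads = 24) {
  std::vector<std::pair<size_t, size_t>> ranges;
  size_t thread_num = std::min(slice_count, max_threads);
  if (thread_num == 0) {
    return ranges;
  }
  size_t stride = (slice_count + thread_num - 1) / thread_num;  // ceiling division
  thread_num = (slice_count + stride - 1) / stride;             // drop would-be-empty tail threads
  for (size_t i = 0; i < thread_num; ++i) {
    size_t start = i * stride;
    // The last thread takes whatever remainder is left.
    size_t end = (i == thread_num - 1) ? slice_count : start + stride;
    ranges.emplace_back(start, end);
  }
  return ranges;
}

int main() {
  // 100 slices across at most 24 threads -> stride 5, 20 threads, no empty range.
  for (const auto &r : PartitionSlices(100)) {
    std::cout << "[" << r.first << ", " << r.second << ")\n";
  }
  return 0;
}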
@@ -579,8 +579,40 @@ void WorkerForReduceSparseGradient(WorkerParamsForReduceSparseGradient param) {
   }
 }
 
+void RunMultiThreadReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad,
+                                        size_t outer_dim, std::vector<std::pair<int, size_t>> *sorted_indices,
+                                        std::vector<size_t> *slice_positions) {
+  MS_LOG(DEBUG) << "Start";
+  size_t thread_num = 24;
+  if (slice_positions->size() < thread_num) {
+    thread_num = slice_positions->size();
+  }
+  size_t stride = (slice_positions->size() + thread_num - 1) / thread_num;
+  thread_num = (slice_positions->size() + stride - 1) / stride;
+  std::vector<std::thread> threads;
+  size_t max_length = sorted_indices->size() * outer_dim;
+  for (size_t i = 0; i < thread_num; ++i) {
+    size_t slice_start = i * stride;
+    size_t slice_end = 0;
+    if (i == thread_num - 1) {
+      slice_end = slice_positions->size();
+    } else {
+      slice_end = slice_start + stride;
+    }
+    WorkerParamsForReduceSparseGradient params{
+      slice_start, slice_end, max_length, outer_dim, sorted_indices, slice_positions, origin_sparse_grad.value_,
+      unique_grad};
+    threads.emplace_back(std::thread(WorkerForReduceSparseGradient, params));
+  }
+  for (size_t i = 0; i < thread_num; ++i) {
+    threads[i].join();
+  }
+  MS_LOG(DEBUG) << "End";
+}
+
 void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim,
-                          size_t outer_dim) {
+                          size_t outer_dim, bool use_multi_threads) {
+  MS_LOG(DEBUG) << "Start";
   MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_);
   MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_);
   MS_EXCEPTION_IF_NULL(unique_grad);
@@ -599,42 +631,35 @@ void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradie
     [](const std::pair<int, size_t> &left, const std::pair<int, size_t> &right) { return left.first < right.first; });
   int last_index = 0;
   std::vector<size_t> slice_positions;
+  slice_positions.reserve(sorted_indices.size());
   for (size_t i = 0; i < sorted_indices.size(); ++i) {
     if (i == 0 || last_index != sorted_indices[i].first) {
       slice_positions.emplace_back(i);
     }
     last_index = sorted_indices[i].first;
   }
-  size_t thread_num = 8;
-  if (slice_positions.size() < thread_num) {
-    thread_num = slice_positions.size();
-  }
-  size_t stride = (slice_positions.size() + thread_num - 1) / thread_num;
-  thread_num = (slice_positions.size() + stride - 1) / stride;
-  std::vector<std::thread> threads;
-  size_t max_length = sorted_indices.size() * outer_dim;
-  for (size_t i = 0; i < thread_num; ++i) {
-    size_t slice_start = i * stride;
-    size_t slice_end = 0;
-    if (i == thread_num - 1) {
-      slice_end = slice_positions.size();
-    } else {
-      slice_end = slice_start + stride;
-    }
-    WorkerParamsForReduceSparseGradient params{
-      slice_start, slice_end, max_length, outer_dim, &sorted_indices, &slice_positions, origin_sparse_grad.value_,
-      unique_grad};
-    threads.emplace_back(std::thread(WorkerForReduceSparseGradient, params));
-  }
-  for (size_t i = 0; i < thread_num; ++i) {
-    threads[i].join();
+  if (use_multi_threads) {
+    RunMultiThreadReduceSparseGradient(origin_sparse_grad, unique_grad, outer_dim, &sorted_indices, &slice_positions);
+  } else {
+    size_t max_length = sorted_indices.size() * outer_dim;
+    WorkerParamsForReduceSparseGradient params{0,
+                                               slice_positions.size(),
+                                               max_length,
+                                               outer_dim,
+                                               &sorted_indices,
+                                               &slice_positions,
+                                               origin_sparse_grad.value_,
+                                               unique_grad};
+    WorkerForReduceSparseGradient(params);
   }
   unique_grad->indices_size_ = slice_positions.size();
+  MS_LOG(DEBUG) << "End";
 }
 
 void ReduceMultiSparseGradient(const std::vector<std::shared_ptr<SparseGradient>> &unique_slice_grads,
                                SparseGradient *tmp_grad, SparseGradient *unique_grad, size_t first_dim,
                                size_t outer_dim) {
+  MS_LOG(DEBUG) << "Start";
   if (unique_slice_grads.empty()) {
     return;
   }
@@ -658,10 +683,12 @@ void ReduceMultiSparseGradient(const std::vector<std::shared_ptr<SparseGradient>
   }
   tmp_grad->indices_size_ = unique_indices_size;
   ReduceSparseGradient(*tmp_grad, unique_grad, first_dim, outer_dim);
+  MS_LOG(DEBUG) << "End";
 }
 
 void TwoLevelReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *tmp_grad,
                                   SparseGradient *unique_grad, size_t first_dim, size_t outer_dim) {
+  MS_LOG(DEBUG) << "Start";
   MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_);
   MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_);
   MS_EXCEPTION_IF_NULL(unique_grad);
@@ -693,12 +720,13 @@ void TwoLevelReduceSparseGradient(const SparseGradient &origin_sparse_grad, Spar
     unique_slice_grads[i]->indices_ = unique_grad->indices_ + indices_offset;
     unique_slice_grads[i]->indices_size_ = indices_size;
     threads.emplace_back(
-      std::thread(ReduceSparseGradient, slice_grad, unique_slice_grads[i].get(), first_dim, outer_dim));
+      std::thread(ReduceSparseGradient, slice_grad, unique_slice_grads[i].get(), first_dim, outer_dim, false));
   }
   for (size_t i = 0; i < thread_num; ++i) {
     threads[i].join();
   }
   ReduceMultiSparseGradient(unique_slice_grads, tmp_grad, unique_grad, first_dim, outer_dim);
+  MS_LOG(DEBUG) << "End";
 }
 
 std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
...
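A note on the new use_multi_threads flag visible above: TwoLevelReduceSparseGradient already launches one ReduceSparseGradient per slice on its own thread, so those inner calls now pass false to keep each worker single-threaded, while the final merge in ReduceMultiSparseGradient calls ReduceSparseGradient with the default of true so the last pass still parallelizes. A simplified sketch of that call pattern follows; the signatures are reduced for illustration only, since the real functions take SparseGradient pointers and dimension arguments.

#include <cstddef>
#include <thread>
#include <vector>

// Simplified stand-in for ReduceSparseGradient(..., use_multi_threads).
void ReduceSlice(bool use_multi_threads) {
  if (use_multi_threads) {
    // Top-level call: fan out internally across up to 24 threads.
  } else {
    // Called from a worker thread: stay single-threaded, avoiding nested
    // fan-out (outer threads times 24 inner threads each).
  }
}

void TwoLevelReduce(size_t slice_count) {
  std::vector<std::thread> threads;
  threads.reserve(slice_count);
  for (size_t i = 0; i < slice_count; ++i) {
    // Level 1: each slice is reduced on its own thread, single-threaded inside.
    threads.emplace_back(ReduceSlice, /*use_multi_threads=*/false);
  }
  for (auto &t : threads) {
    t.join();
  }
  // Level 2: the merged result is reduced once more; the default
  // use_multi_threads = true lets this final pass parallelize.
  ReduceSlice(/*use_multi_threads=*/true);
}

int main() {
  TwoLevelReduce(4);
  return 0;
}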
@@ -115,7 +115,7 @@ int Sign(float x);
 void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim,
                               size_t outer_dim);
 void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim,
-                          size_t outer_dim);
+                          size_t outer_dim, bool use_multi_threads = true);
 std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index);
 std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
                                                                             const std::vector<AnfNodePtr> &input_list);
@@ -130,6 +130,9 @@ void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<An
 bool IsWeightBoundary(const AnfNodePtr &node);
 void MultiThreadCompute(const MultiThreadComputeFunc &func, MultiThreadComputeParams *params,
                         size_t total_compute_size);
+void RunMultiThreadReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad,
+                                        size_t outer_dim, std::vector<std::pair<int, size_t>> *sorted_indices,
+                                        std::vector<size_t> *slice_positions);
 void ReduceMultiSparseGradient(const std::vector<std::shared_ptr<SparseGradient>> &unique_slice_grads,
                                SparseGradient *tmp_grad, SparseGradient *unique_grad, size_t first_dim,
                                size_t outer_dim);
...