提交 1177b0bc 编写于 作者: Q Qiao Longfei

update multi thread adam

上级 3b294e2e
...@@ -491,17 +491,20 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -491,17 +491,20 @@ class AdamOpKernel : public framework::OpKernel<T> {
row_id_to_grad_row_offset[grad_rows[i]] = i; row_id_to_grad_row_offset[grad_rows[i]] = i;
} }
std::vector<std::future<void>> fs; std::vector<std::future<void>> fs;
int64_t line_in_each_thread = param_row_count / FLAGS_inner_op_parallelism; int64_t line_in_each_thread =
param_row_count / FLAGS_inner_op_parallelism;
for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) { for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
int64_t start = i * line_in_each_thread; int64_t start = i * line_in_each_thread;
int64_t end = (i + 1) * line_in_each_thread; int64_t end = (i + 1) * line_in_each_thread;
if (end > param_row_count) { if (end > param_row_count) {
end = param_row_count; end = param_row_count;
} }
fs.push_back(framework::Async([&functor, &row_id_to_grad_row_offset, start, end]() { fs.push_back(framework::Async(
[&functor, &row_id_to_grad_row_offset, start, end]() {
for (int64_t i = start; i < end; ++i) { for (int64_t i = start; i < end; ++i) {
functor.update_row(i, row_id_to_grad_row_offset[i]); functor.update_row(i, row_id_to_grad_row_offset[i]);
}})); }
}));
} }
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
} else { } else {
...@@ -511,7 +514,8 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -511,7 +514,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
for (size_t row_index = 0; row_index < row_count; ++row_index) { for (size_t row_index = 0; row_index < row_count; ++row_index) {
for (size_t offset = 0; offset < row_numel; ++offset) { for (size_t offset = 0; offset < row_numel; ++offset) {
size_t i = cpu_rows[row_index] * row_numel + offset; size_t i = cpu_rows[row_index] * row_numel + offset;
functor.adam_update(i, grad_data[row_index * row_numel + offset]); functor.adam_update(i,
grad_data[row_index * row_numel + offset]);
} }
} }
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册