提交 1177b0bc 编写于 作者: Q Qiao Longfei

update multi thread adam

上级 3b294e2e
......@@ -491,17 +491,20 @@ class AdamOpKernel : public framework::OpKernel<T> {
row_id_to_grad_row_offset[grad_rows[i]] = i;
}
std::vector<std::future<void>> fs;
int64_t line_in_each_thread = param_row_count / FLAGS_inner_op_parallelism;
int64_t line_in_each_thread =
param_row_count / FLAGS_inner_op_parallelism;
for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
int64_t start = i * line_in_each_thread;
int64_t end = (i + 1) * line_in_each_thread;
if (end > param_row_count) {
end = param_row_count;
}
fs.push_back(framework::Async([&functor, &row_id_to_grad_row_offset, start, end]() {
fs.push_back(framework::Async(
[&functor, &row_id_to_grad_row_offset, start, end]() {
for (int64_t i = start; i < end; ++i) {
functor.update_row(i, row_id_to_grad_row_offset[i]);
}}));
}
}));
}
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
} else {
......@@ -511,7 +514,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
for (size_t row_index = 0; row_index < row_count; ++row_index) {
for (size_t offset = 0; offset < row_numel; ++offset) {
size_t i = cpu_rows[row_index] * row_numel + offset;
functor.adam_update(i, grad_data[row_index * row_numel + offset]);
functor.adam_update(i,
grad_data[row_index * row_numel + offset]);
}
}
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册