提交 1177b0bc 编写于 作者: Q Qiao Longfei

update multi thread adam

上级 3b294e2e
...@@ -465,14 +465,14 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -465,14 +465,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
if (platform::is_cpu_place(ctx.GetPlace())) { if (platform::is_cpu_place(ctx.GetPlace())) {
SparseAdamFunctor<T, CPUAdam> functor( SparseAdamFunctor<T, CPUAdam> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(), beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(), mom1.template data<T>(), beta2_pow.template data<T>(), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()), mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(), mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad_data, param.template data<T>(), lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel, param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size(), lazy_mode); grad_merge.rows().size(), lazy_mode);
// multi thread speedup // multi thread speedup
if (FLAGS_inner_op_parallelism > 1 && if (FLAGS_inner_op_parallelism > 1 &&
FLAGS_min_param_size_to_use_multithread > 0 && FLAGS_min_param_size_to_use_multithread > 0 &&
...@@ -491,17 +491,20 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -491,17 +491,20 @@ class AdamOpKernel : public framework::OpKernel<T> {
row_id_to_grad_row_offset[grad_rows[i]] = i; row_id_to_grad_row_offset[grad_rows[i]] = i;
} }
std::vector<std::future<void>> fs; std::vector<std::future<void>> fs;
int64_t line_in_each_thread = param_row_count / FLAGS_inner_op_parallelism; int64_t line_in_each_thread =
param_row_count / FLAGS_inner_op_parallelism;
for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) { for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
int64_t start = i * line_in_each_thread; int64_t start = i * line_in_each_thread;
int64_t end = (i + 1) * line_in_each_thread; int64_t end = (i + 1) * line_in_each_thread;
if (end > param_row_count) { if (end > param_row_count) {
end = param_row_count; end = param_row_count;
} }
fs.push_back(framework::Async([&functor, &row_id_to_grad_row_offset, start, end]() { fs.push_back(framework::Async(
for (int64_t i = start; i < end; ++i) { [&functor, &row_id_to_grad_row_offset, start, end]() {
functor.update_row(i, row_id_to_grad_row_offset[i]); for (int64_t i = start; i < end; ++i) {
}})); functor.update_row(i, row_id_to_grad_row_offset[i]);
}
}));
} }
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
} else { } else {
...@@ -511,7 +514,8 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -511,7 +514,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
for (size_t row_index = 0; row_index < row_count; ++row_index) { for (size_t row_index = 0; row_index < row_count; ++row_index) {
for (size_t offset = 0; offset < row_numel; ++offset) { for (size_t offset = 0; offset < row_numel; ++offset) {
size_t i = cpu_rows[row_index] * row_numel + offset; size_t i = cpu_rows[row_index] * row_numel + offset;
functor.adam_update(i, grad_data[row_index * row_numel + offset]); functor.adam_update(i,
grad_data[row_index * row_numel + offset]);
} }
} }
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册