提交 7a58ad5c 编写于 作者: Q Qiao Longfei

lazy mode has higher priority than multithread

test=develop
上级 d0572bf0
...@@ -473,10 +473,19 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -473,10 +473,19 @@ class AdamOpKernel : public framework::OpKernel<T> {
lr.template data<T>(), grad_data, param.template data<T>(), lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel, param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size(), lazy_mode); grad_merge.rows().size(), lazy_mode);
// multi thread speedup if (lazy_mode) {
if (FLAGS_inner_op_parallelism > 1 && VLOG(3) << "run cpu lazy mode";
FLAGS_min_param_size_to_use_multithread > 0 && size_t row_count = grad_merge.rows().size();
param.numel() > FLAGS_min_param_size_to_use_multithread) { std::vector<int64_t> cpu_rows(grad_merge.rows());
for (size_t row_index = 0; row_index < row_count; ++row_index) {
for (size_t offset = 0; offset < row_numel; ++offset) {
size_t i = cpu_rows[row_index] * row_numel + offset;
functor.adam_update(i, grad_data[row_index * row_numel + offset]);
}
}
} else if (FLAGS_inner_op_parallelism > 1 &&
FLAGS_min_param_size_to_use_multithread > 0 &&
param.numel() > FLAGS_min_param_size_to_use_multithread) {
VLOG(3) << "use multi thread, inner_op_parallelism=" VLOG(3) << "use multi thread, inner_op_parallelism="
<< FLAGS_inner_op_parallelism << FLAGS_inner_op_parallelism
<< " min_param_size_to_use_multithread=" << " min_param_size_to_use_multithread="
...@@ -508,20 +517,7 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -508,20 +517,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
} }
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
} else { } else {
if (lazy_mode) { functor(param.numel());
VLOG(3) << "run cpu lazy mode";
size_t row_count = grad_merge.rows().size();
std::vector<int64_t> cpu_rows(grad_merge.rows());
for (size_t row_index = 0; row_index < row_count; ++row_index) {
for (size_t offset = 0; offset < row_numel; ++offset) {
size_t i = cpu_rows[row_index] * row_numel + offset;
functor.adam_update(i,
grad_data[row_index * row_numel + offset]);
}
}
} else {
functor(param.numel());
}
} }
} else if (platform::is_gpu_place(ctx.GetPlace())) { } else if (platform::is_gpu_place(ctx.GetPlace())) {
SparseAdamFunctor<T, GPUAdam> functor( SparseAdamFunctor<T, GPUAdam> functor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册