diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index 6b794e0d3e0d0d66132883e954057076e978d600..6ff2a2bb6fcf76e214e232c1fd56535aa2ba2bef 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -473,10 +473,19 @@ class AdamOpKernel : public framework::OpKernel { lr.template data(), grad_data, param.template data(), param_out.template mutable_data(ctx.GetPlace()), rows, row_numel, grad_merge.rows().size(), lazy_mode); - // multi thread speedup - if (FLAGS_inner_op_parallelism > 1 && - FLAGS_min_param_size_to_use_multithread > 0 && - param.numel() > FLAGS_min_param_size_to_use_multithread) { + if (lazy_mode) { + VLOG(3) << "run cpu lazy mode"; + size_t row_count = grad_merge.rows().size(); + std::vector cpu_rows(grad_merge.rows()); + for (size_t row_index = 0; row_index < row_count; ++row_index) { + for (size_t offset = 0; offset < row_numel; ++offset) { + size_t i = cpu_rows[row_index] * row_numel + offset; + functor.adam_update(i, grad_data[row_index * row_numel + offset]); + } + } + } else if (FLAGS_inner_op_parallelism > 1 && + FLAGS_min_param_size_to_use_multithread > 0 && + param.numel() > FLAGS_min_param_size_to_use_multithread) { VLOG(3) << "use multi thread, inner_op_parallelism=" << FLAGS_inner_op_parallelism << " min_param_size_to_use_multithread=" @@ -508,20 +517,7 @@ class AdamOpKernel : public framework::OpKernel { } for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); } else { - if (lazy_mode) { - VLOG(3) << "run cpu lazy mode"; - size_t row_count = grad_merge.rows().size(); - std::vector cpu_rows(grad_merge.rows()); - for (size_t row_index = 0; row_index < row_count; ++row_index) { - for (size_t offset = 0; offset < row_numel; ++offset) { - size_t i = cpu_rows[row_index] * row_numel + offset; - functor.adam_update(i, - grad_data[row_index * row_numel + offset]); - } - } - } else { - functor(param.numel()); - } + functor(param.numel()); } } else if (platform::is_gpu_place(ctx.GetPlace())) { SparseAdamFunctor functor(