diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index e9fbe15cbeb43da6d3029868c8fe5af1c4721f8d..f8c7b82053a113ad9c869a54efe955e81f8b1f66 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -465,14 +465,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
 
       if (platform::is_cpu_place(ctx.GetPlace())) {
         SparseAdamFunctor<T, CPUAdam> functor(
-          beta1, beta2, epsilon, beta1_pow.template data<T>(),
-          beta2_pow.template data<T>(), mom1.template data<T>(),
-          mom1_out.template mutable_data<T>(ctx.GetPlace()),
-          mom2.template data<T>(),
-          mom2_out.template mutable_data<T>(ctx.GetPlace()),
-          lr.template data<T>(), grad_data, param.template data<T>(),
-          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
-          grad_merge.rows().size(), lazy_mode);
+            beta1, beta2, epsilon, beta1_pow.template data<T>(),
+            beta2_pow.template data<T>(), mom1.template data<T>(),
+            mom1_out.template mutable_data<T>(ctx.GetPlace()),
+            mom2.template data<T>(),
+            mom2_out.template mutable_data<T>(ctx.GetPlace()),
+            lr.template data<T>(), grad_data, param.template data<T>(),
+            param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
+            grad_merge.rows().size(), lazy_mode);
         // multi thread speedup
         if (FLAGS_inner_op_parallelism > 1 &&
             FLAGS_min_param_size_to_use_multithread > 0 &&
@@ -491,17 +491,20 @@ class AdamOpKernel : public framework::OpKernel<T> {
             row_id_to_grad_row_offset[grad_rows[i]] = i;
           }
           std::vector<std::future<void>> fs;
-          int64_t line_in_each_thread = param_row_count / FLAGS_inner_op_parallelism;
+          int64_t line_in_each_thread =
+              param_row_count / FLAGS_inner_op_parallelism;
           for (int i = 0; i < FLAGS_inner_op_parallelism; ++i) {
             int64_t start = i * line_in_each_thread;
             int64_t end = (i + 1) * line_in_each_thread;
             if (end > param_row_count) {
               end = param_row_count;
             }
-            fs.push_back(framework::Async([&functor, &row_id_to_grad_row_offset, start, end]() {
-              for (int64_t i = start; i < end; ++i) {
-                functor.update_row(i, row_id_to_grad_row_offset[i]);
-              }}));
+            fs.push_back(framework::Async(
+                [&functor, &row_id_to_grad_row_offset, start, end]() {
+                  for (int64_t i = start; i < end; ++i) {
+                    functor.update_row(i, row_id_to_grad_row_offset[i]);
+                  }
+                }));
           }
           for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
         } else {
@@ -511,7 +514,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
           for (size_t row_index = 0; row_index < row_count; ++row_index) {
             for (size_t offset = 0; offset < row_numel; ++offset) {
               size_t i = cpu_rows[row_index] * row_numel + offset;
-              functor.adam_update(i, grad_data[row_index * row_numel + offset]);
+              functor.adam_update(i,
+                                  grad_data[row_index * row_numel + offset]);
             }
           }
         } else {
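
For reference, below is a minimal standalone sketch of the row-partitioning pattern in the second hunk. It is an illustration only, not part of the patch: std::async/std::future<void> stand in for Paddle's framework::Async, a plain lambda that marks rows stands in for SparseAdamFunctor::update_row, and the thread count and row count are made-up values. It also makes visible what the truncating division behind line_in_each_thread does when param_row_count is not a multiple of the thread count: the trailing rows fall outside every [start, end) slice.

    // sketch.cc -- illustrative only, not part of the patch
    #include <cstdint>
    #include <future>
    #include <iostream>
    #include <vector>

    int main() {
      const int num_threads = 4;           // stands in for FLAGS_inner_op_parallelism
      const int64_t param_row_count = 10;  // number of parameter rows to update
      std::vector<int> updated(param_row_count, 0);

      // Same split as the patch: each worker takes a contiguous slice of rows.
      int64_t line_in_each_thread = param_row_count / num_threads;
      std::vector<std::future<void>> fs;
      for (int i = 0; i < num_threads; ++i) {
        int64_t start = i * line_in_each_thread;
        int64_t end = (i + 1) * line_in_each_thread;
        if (end > param_row_count) {
          end = param_row_count;
        }
        fs.push_back(std::async(std::launch::async, [&updated, start, end]() {
          for (int64_t row = start; row < end; ++row) {
            updated[row] = 1;  // the patch calls functor.update_row(row, grad_row_offset) here
          }
        }));
      }
      for (auto& f : fs) f.wait();

      int64_t covered = 0;
      for (int v : updated) covered += v;
      // Prints "8 of 10 rows updated": floor(10 / 4) = 2 rows per slice,
      // so rows 8 and 9 are never assigned to any worker by this split.
      std::cout << covered << " of " << param_row_count << " rows updated\n";
      return 0;
    }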