From f1c973b0141b4396596bccaace1848ddec6faa24 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 28 Dec 2018 18:33:02 +0800 Subject: [PATCH] adam op should not create tmp var in compute --- paddle/fluid/operators/optimizers/adam_op.h | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index 1138bb740..de18edcd4 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -423,6 +423,7 @@ class AdamOpKernel : public framework::OpKernel { } } + framework::SelectedRows cpu_grad_merge; const framework::SelectedRows* grad_merge_ptr; if (is_strict_sorted) { grad_merge_ptr = &grad; @@ -430,12 +431,16 @@ class AdamOpKernel : public framework::OpKernel { // merge duplicated rows if any. // The rows of grad_merge have been sorted inside MergeAdd functor scatter::MergeAdd merge_func; - auto* grad_merge_var = const_cast(ctx.scope()) - .Var() - ->GetMutable(); + if (platform::is_cpu_place(ctx.GetPlace())) { + grad_merge_ptr = &cpu_grad_merge; + } else { + // FIXME(qiao): GPU also need to fix this + auto* grad_merge_var = const_cast(ctx.scope()) + .Var() + ->GetMutable(); + } merge_func(ctx.template device_context(), grad, - grad_merge_var, true); - grad_merge_ptr = grad_merge_var; + grad_merge_ptr, true); } auto& grad_merge = *grad_merge_ptr; -- GitLab