diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index 1138bb7400e0e7a00983e7bfaad2b2d9704b77ab..de18edcd4468b42b4f7e301c457e68ffb4e9c02c 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -423,6 +423,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
         }
       }
 
+      framework::SelectedRows cpu_grad_merge;
       const framework::SelectedRows* grad_merge_ptr;
       if (is_strict_sorted) {
         grad_merge_ptr = &grad;
@@ -430,12 +431,16 @@ class AdamOpKernel : public framework::OpKernel<T> {
         // merge duplicated rows if any.
         // The rows of grad_merge have been sorted inside MergeAdd functor
         scatter::MergeAdd<DeviceContext, T> merge_func;
-        auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
-                                   .Var()
-                                   ->GetMutable<framework::SelectedRows>();
+        if (platform::is_cpu_place(ctx.GetPlace())) {
+          grad_merge_ptr = &cpu_grad_merge;
+        } else {
+          // FIXME(qiao): GPU also need to fix this
+          grad_merge_ptr = const_cast<framework::Scope&>(ctx.scope())
+                               .Var()
+                               ->GetMutable<framework::SelectedRows>();
+        }
         merge_func(ctx.template device_context<DeviceContext>(), grad,
-                   grad_merge_var, true);
-        grad_merge_ptr = grad_merge_var;
+                   const_cast<framework::SelectedRows*>(grad_merge_ptr), true);
       }
       auto& grad_merge = *grad_merge_ptr;
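
Note: the patch leans on the contract of Paddle's scatter::MergeAdd functor: value slices that share a row index are summed, and the merged rows come out sorted. Below is a minimal, self-contained C++ sketch of that semantics; the names SparseGrad, merge_rows, and the layout are invented for illustration and are not Paddle's implementation.

#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

// Hypothetical stand-in for a SelectedRows gradient: one row index per
// value slice, where indices may repeat and may be unsorted.
struct SparseGrad {
  std::vector<int64_t> rows;
  std::vector<std::vector<float>> values;
};

// Sum slices sharing a row index; emit rows in ascending order. This
// mirrors the "rows of grad_merge have been sorted inside MergeAdd"
// comment in the patch, not Paddle's actual implementation.
SparseGrad merge_rows(const SparseGrad& in) {
  std::map<int64_t, std::vector<float>> acc;  // std::map keeps keys sorted
  for (std::size_t i = 0; i < in.rows.size(); ++i) {
    std::vector<float>& slot = acc[in.rows[i]];
    if (slot.empty()) slot.assign(in.values[i].size(), 0.0f);
    for (std::size_t j = 0; j < in.values[i].size(); ++j) {
      slot[j] += in.values[i][j];
    }
  }
  SparseGrad out;
  for (const auto& kv : acc) {
    out.rows.push_back(kv.first);
    out.values.push_back(kv.second);
  }
  return out;
}

int main() {
  // Row 3 appears twice, so its slices must be summed before a sparse
  // Adam update can walk each merged row exactly once.
  SparseGrad g{{3, 1, 3}, {{1.0f, 1.0f}, {2.0f, 2.0f}, {0.5f, 0.5f}}};
  SparseGrad m = merge_rows(g);
  for (std::size_t i = 0; i < m.rows.size(); ++i) {
    std::printf("row %lld: %.1f %.1f\n", static_cast<long long>(m.rows[i]),
                m.values[i][0], m.values[i][1]);
  }
  // Prints: row 1: 2.0 2.0
  //         row 3: 1.5 1.5
  return 0;
}

Using an ordered std::map gives the sorted-rows guarantee for free, which is why the kernel can skip the merge entirely when is_strict_sorted already holds; the CPU branch then only differs in where the merged output lives (a stack-local SelectedRows instead of a variable created in the scope on every run).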