diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index 5c559484ec95e794ebbbe0e713cb9e26b5c01b98..61b9384f8422cb531a94096875434ffe36ecdbce 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -424,16 +424,23 @@ class AdamOpKernel : public framework::OpKernel { } } + framework::SelectedRows cpu_grad_merge; const framework::SelectedRows* grad_merge_ptr; if (is_strict_sorted) { grad_merge_ptr = &grad; } else { // merge duplicated rows if any. // The rows of grad_merge have been sorted inside MergeAdd functor + framework::SelectedRows* grad_merge_var; scatter::MergeAdd merge_func; - auto* grad_merge_var = const_cast(ctx.scope()) - .Var() - ->GetMutable(); + if (platform::is_cpu_place(ctx.GetPlace())) { + grad_merge_var = &cpu_grad_merge; + } else { + // FIXME(qiao): GPU also need to fix this + grad_merge_var = const_cast(ctx.scope()) + .Var() + ->GetMutable(); + } merge_func(ctx.template device_context(), grad, grad_merge_var, true); grad_merge_ptr = grad_merge_var;