diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index e8b977e2d962224dd8b753ec70b7d345847bfa42..01d3d600540607df6182288a3613109a5d76db2a 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -361,9 +361,10 @@ class AdamOpKernel : public framework::OpKernel { if (lazy_mode) { std::vector id_vector; size_t row_count = grad_merge.rows().size(); + std::vector cpu_rows(grad_merge.rows()); for (size_t row_index = 0; row_index < row_count; ++row_index) { for (size_t offset = 0; offset < row_numel; ++offset) { - size_t i = rows[row_index] * row_numel + offset; + size_t i = cpu_rows[row_index] * row_numel + offset; id_vector.push_back(i); } } diff --git a/paddle/fluid/platform/for_range.h b/paddle/fluid/platform/for_range.h index a767bf92993b901015d3dc447746777d7d52d70d..ab00d8b8f578ca9e00b23de5228e00ce204b67c7 100644 --- a/paddle/fluid/platform/for_range.h +++ b/paddle/fluid/platform/for_range.h @@ -128,7 +128,7 @@ struct ForRangeIn { int grid_size = (range_.size() + num_threads - 1) / num_threads; ForRangeInElemwiseOp<<>>( - func, range_.data(), range_size); + func, range_.CUDAData(dev_ctx_.GetPlace()), range_size); } const CUDADeviceContext& dev_ctx_;