diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index 5870557bb7b265672aae514eab07e55fd5428075..e8b977e2d962224dd8b753ec70b7d345847bfa42 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -359,14 +359,17 @@ class AdamOpKernel : public framework::OpKernel<T> {
             param_out.template mutable_data<T>(ctx.GetPlace()), rows,
             row_numel, grad_merge.rows().size(), lazy_mode);
         if (lazy_mode) {
+          std::vector<int64_t> id_vector;
           size_t row_count = grad_merge.rows().size();
           for (size_t row_index = 0; row_index < row_count; ++row_index) {
             for (size_t offset = 0; offset < row_numel; ++offset) {
               size_t i = rows[row_index] * row_numel + offset;
-              T g = grad_data[row_index * row_numel + offset];
-              functor.adam_update(i, g);
+              id_vector.push_back(i);
             }
           }
+          platform::ForRangeIn<DeviceContext> for_range_in(
+              static_cast<const DeviceContext&>(ctx.device_context()), id_vector);
+          for_range_in(functor);
         } else {
           platform::ForRange<DeviceContext> for_range(
               static_cast<const DeviceContext&>(ctx.device_context()),
diff --git a/paddle/fluid/platform/for_range.h b/paddle/fluid/platform/for_range.h
index c153e80fe42aecb33d3aa97874d2881bce9029be..9fbaa36723bcfa6a4273a24ba550c225a4daf582 100644
--- a/paddle/fluid/platform/for_range.h
+++ b/paddle/fluid/platform/for_range.h
@@ -13,11 +13,38 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+
+#include <vector>
+
+#include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
 namespace platform {
 
+template <typename DeviceContext>
+struct ForRangeIn {
+  ForRangeIn(const DeviceContext& dev_ctx, std::vector<int64_t> range);
+
+  template <typename Function>
+  void operator()(Function func) const;
+};
+
+template <>
+struct ForRangeIn<CPUDeviceContext> {
+  ForRangeIn(const CPUDeviceContext& dev_ctx, std::vector<int64_t> range)
+      : range_(range) {}
+
+  template <typename Function>
+  void operator()(Function func) const {
+    for (auto i : range_) {
+      func(i);
+    }
+  }
+
+  std::vector<int64_t> range_;
+};
+
 template <typename DeviceContext>
 struct ForRange {
   ForRange(const DeviceContext& dev_ctx, size_t limit);
@@ -79,6 +106,34 @@ struct ForRange<CUDADeviceContext> {
   int limit_;
 };
 
+template <typename T, typename Function>
+__global__ static void ForRangeInElemwiseOp(Function func, T* vector,
+                                            int vector_size) {
+  size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
+  if (idx < vector_size) {
+    func(vector[idx]);
+  }
+}
+
+template <>
+struct ForRangeIn<CUDADeviceContext> {
+  ForRangeIn(const CUDADeviceContext& dev_ctx, std::vector<int64_t> range)
+      : dev_ctx_(dev_ctx), range_(range) {}
+
+  template <typename Function>
+  inline void operator()(Function func) const {
+    constexpr int num_threads = 1024;
+    int block_size = range_.size() <= num_threads ? range_.size() : num_threads;
+    int grid_size = (range_.size() + num_threads - 1) / num_threads;
+
+    ForRangeInElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
+        func, range_.CUDAData(dev_ctx_.GetPlace()), range_.size());
+  }
+
+  const CUDADeviceContext& dev_ctx_;
+  framework::Vector<int64_t> range_;
+};
+
 #endif
 
 }  // namespace platform
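Note on the CUDA launch configuration (not part of the diff): the grid size follows the usual ceiling-division pattern. For example, with id_vector.size() == 3000, block_size is capped at num_threads = 1024 and grid_size = (3000 + 1024 - 1) / 1024 = 3, so 3072 threads are launched and the trailing 72 threads with idx >= vector_size return without calling func.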
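A minimal usage sketch of the new ForRangeIn, assuming the CPU specialization above; ScaleFunctor and ApplySparseScale are hypothetical names invented for this illustration, not part of the diff:

// Illustrative only. A functor with the same calling shape as
// SparseAdamFunctor: invoked once per index in the supplied range,
// and only for those indices.
#include <vector>

#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/hostdevice.h"

struct ScaleFunctor {
  ScaleFunctor(float* data, float scale) : data_(data), scale_(scale) {}

  // HOSTDEVICE so the same functor also works with the CUDA specialization.
  HOSTDEVICE void operator()(int64_t i) const { data_[i] *= scale_; }

  float* data_;
  float scale_;
};

void ApplySparseScale(float* data, const std::vector<int64_t>& ids,
                      float scale) {
  paddle::platform::CPUDeviceContext cpu_ctx;
  paddle::platform::ForRangeIn<paddle::platform::CPUDeviceContext> range_in(
      cpu_ctx, ids);
  // Runs func(i) only for i in ids; rows outside the range are never
  // touched, which is the point of lazy_mode in the Adam kernel.
  range_in(ScaleFunctor(data, scale));
}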