From fd152289fa694b99704e4821a71b0c1f160896aa Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 17 Dec 2018 22:14:11 +0800 Subject: [PATCH] clean for range in test=develop --- paddle/fluid/operators/optimizers/adam_op.h | 14 +++--- paddle/fluid/platform/for_range.h | 52 --------------------- 2 files changed, 6 insertions(+), 60 deletions(-) diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index 8fc6689ff..4f212bb69 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -227,8 +227,10 @@ struct SparseAdamFunctor { inline HOSTDEVICE void operator()(size_t i) const { auto row_idx = math::BinarySearch(rows_, row_count_, i / row_numel_); - T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; - adam_update(i, g); + if (!(lazy_mode_ && row_idx < 0)) { + T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; + adam_update(i, g); + } } }; @@ -359,19 +361,15 @@ class AdamOpKernel : public framework::OpKernel { param_out.template mutable_data(ctx.GetPlace()), rows, row_numel, grad_merge.rows().size(), lazy_mode); VLOG(3) << "lazy_mode :" << lazy_mode; - if (lazy_mode) { - std::vector id_vector; + if (lazy_mode && platform::is_cpu_place(ctx.GetPlace())) { size_t row_count = grad_merge.rows().size(); std::vector cpu_rows(grad_merge.rows()); for (size_t row_index = 0; row_index < row_count; ++row_index) { for (size_t offset = 0; offset < row_numel; ++offset) { size_t i = cpu_rows[row_index] * row_numel + offset; - id_vector.push_back(i); + functor.adam_update(i, grad_data[row_index * row_numel + offset]); } } - platform::ForRangeIn for_range_in( - static_cast(ctx.device_context()), id_vector); - for_range_in(functor); } else { platform::ForRange for_range( static_cast(ctx.device_context()), diff --git a/paddle/fluid/platform/for_range.h b/paddle/fluid/platform/for_range.h index ab00d8b8f..910d1669f 100644 --- a/paddle/fluid/platform/for_range.h +++ b/paddle/fluid/platform/for_range.h @@ -22,29 +22,6 @@ limitations under the License. */ namespace paddle { namespace platform { -template -struct ForRangeIn { - ForRangeIn(const DeviceContext& dev_ctx, std::vector range); - - template - void operator()(Function func) const; -}; - -template <> -struct ForRangeIn { - ForRangeIn(const CPUDeviceContext& dev_ctx, std::vector range) - : range_(range) {} - - template - void operator()(Function func) const { - for (auto i : range_) { - func(i); - } - } - - std::vector range_; -}; - template struct ForRange { ForRange(const DeviceContext& dev_ctx, size_t limit); @@ -106,35 +83,6 @@ struct ForRange { int limit_; }; -template -__global__ static void ForRangeInElemwiseOp(Function func, T* vector, - int vector_size) { - size_t idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx < vector_size) { - func(vector[idx]); - } -} - -template <> -struct ForRangeIn { - ForRangeIn(const CUDADeviceContext& dev_ctx, std::vector range) - : dev_ctx_(dev_ctx), range_(range) {} - - template - inline void operator()(Function func) const { - constexpr int num_threads = 1024; - int range_size = range_.size(); - int block_size = range_size <= num_threads ? range_size : num_threads; - int grid_size = (range_.size() + num_threads - 1) / num_threads; - - ForRangeInElemwiseOp<<>>( - func, range_.CUDAData(dev_ctx_.GetPlace()), range_size); - } - - const CUDADeviceContext& dev_ctx_; - framework::Vector range_; -}; - #endif } // namespace platform -- GitLab