Commit fd152289 authored by Qiao Longfei

clean for range in test=develop

Parent 1141db81
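In short: this change deletes the ForRangeIn utility (which applied a functor over an explicit list of indices, with CPU and CUDA specializations) and simplifies the Adam operator's lazy mode. SparseAdamFunctor now skips rows that are absent from the sparse gradient whenever lazy_mode_ is set, and the CPU kernel calls adam_update directly in a nested loop over the rows present in the merged gradient. Judging by the hunk context, the touched files are adam_op.h and the platform-layer for_range.h; a hedged sketch after each file's diff illustrates the behavior.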
adam_op.h:

@@ -227,9 +227,11 @@ struct SparseAdamFunctor {
   inline HOSTDEVICE void operator()(size_t i) const {
     auto row_idx =
         math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
-    T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
-    adam_update(i, g);
+    if (!(lazy_mode_ && row_idx < 0)) {
+      T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
+      adam_update(i, g);
+    }
   }
 };

 template <typename DeviceContext, typename T>

@@ -359,19 +361,15 @@ class AdamOpKernel : public framework::OpKernel<T> {
           param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
           grad_merge.rows().size(), lazy_mode);
       VLOG(3) << "lazy_mode :" << lazy_mode;
-      if (lazy_mode) {
-        std::vector<int64_t> id_vector;
+      if (lazy_mode && platform::is_cpu_place(ctx.GetPlace())) {
         size_t row_count = grad_merge.rows().size();
         std::vector<int64_t> cpu_rows(grad_merge.rows());
         for (size_t row_index = 0; row_index < row_count; ++row_index) {
           for (size_t offset = 0; offset < row_numel; ++offset) {
             size_t i = cpu_rows[row_index] * row_numel + offset;
-            id_vector.push_back(i);
+            functor.adam_update(i, grad_data[row_index * row_numel + offset]);
           }
         }
-        platform::ForRangeIn<DeviceContext> for_range_in(
-            static_cast<const DeviceContext&>(ctx.device_context()), id_vector);
-        for_range_in(functor);
       } else {
         platform::ForRange<DeviceContext> for_range(
             static_cast<const DeviceContext&>(ctx.device_context()),
...
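To make the lazy-mode semantics concrete, here is a minimal, self-contained C++ sketch, not PaddlePaddle code: ToyAdam, its hyper-parameters, and the data in main are illustrative assumptions. It reproduces the shape of the new CPU path above: a nested loop over the rows present in the merged sparse gradient, calling an element-wise adam_update only at the flat indices those rows cover, so rows with no gradient keep their stale moments and parameters.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

struct ToyAdam {
  std::vector<double> param, mom1, mom2;
  double lr = 0.01, beta1 = 0.9, beta2 = 0.999, eps = 1e-8;

  explicit ToyAdam(size_t n) : param(n, 1.0), mom1(n, 0.0), mom2(n, 0.0) {}

  // One flat-element update, mirroring the role of
  // SparseAdamFunctor::adam_update(i, g) in the diff above.
  void adam_update(size_t i, double g) {
    mom1[i] = beta1 * mom1[i] + (1 - beta1) * g;
    mom2[i] = beta2 * mom2[i] + (1 - beta2) * g * g;
    param[i] -= lr * mom1[i] / (std::sqrt(mom2[i]) + eps);
  }
};

int main() {
  const size_t row_numel = 3;   // elements per parameter row
  ToyAdam adam(4 * row_numel);  // a 4-row parameter, flattened

  // Merged sparse gradient: only rows 1 and 3 carry values.
  std::vector<int64_t> rows = {1, 3};
  std::vector<double> grad_data = {0.5, 0.5, 0.5, -1.0, -1.0, -1.0};

  // The plain loop that replaces ForRangeIn: visit only covered indices.
  for (size_t row_index = 0; row_index < rows.size(); ++row_index) {
    for (size_t offset = 0; offset < row_numel; ++offset) {
      size_t i = rows[row_index] * row_numel + offset;
      adam.adam_update(i, grad_data[row_index * row_numel + offset]);
    }
  }

  // Rows 0 and 2 were never touched, which is exactly what lazy mode is for.
  std::printf("param[0]=%.6f (stale)  param[%zu]=%.6f (updated)\n",
              adam.param[0], 3 * row_numel, adam.param[3 * row_numel]);
}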
for_range.h:

@@ -22,29 +22,6 @@ limitations under the License. */
 namespace paddle {
 namespace platform {

-template <typename DeviceContext>
-struct ForRangeIn {
-  ForRangeIn(const DeviceContext& dev_ctx, std::vector<int64_t> range);
-
-  template <typename Function>
-  void operator()(Function func) const;
-};
-
-template <>
-struct ForRangeIn<CPUDeviceContext> {
-  ForRangeIn(const CPUDeviceContext& dev_ctx, std::vector<int64_t> range)
-      : range_(range) {}
-
-  template <typename Function>
-  void operator()(Function func) const {
-    for (auto i : range_) {
-      func(i);
-    }
-  }
-
-  std::vector<int64_t> range_;
-};
-
 template <typename DeviceContext>
 struct ForRange {
   ForRange(const DeviceContext& dev_ctx, size_t limit);

@@ -106,35 +83,6 @@ struct ForRange<CUDADeviceContext> {
   int limit_;
 };

-template <typename T, typename Function>
-__global__ static void ForRangeInElemwiseOp(Function func, T* vector,
-                                            int vector_size) {
-  size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
-  if (idx < vector_size) {
-    func(vector[idx]);
-  }
-}
-
-template <>
-struct ForRangeIn<CUDADeviceContext> {
-  ForRangeIn(const CUDADeviceContext& dev_ctx, std::vector<int64_t> range)
-      : dev_ctx_(dev_ctx), range_(range) {}
-
-  template <typename Function>
-  inline void operator()(Function func) const {
-    constexpr int num_threads = 1024;
-    int range_size = range_.size();
-    int block_size = range_size <= num_threads ? range_size : num_threads;
-    int grid_size = (range_.size() + num_threads - 1) / num_threads;
-
-    ForRangeInElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
-        func, range_.CUDAData(dev_ctx_.GetPlace()), range_size);
-  }
-
-  const CUDADeviceContext& dev_ctx_;
-  framework::Vector<int64_t> range_;
-};
-
 #endif

 } // namespace platform
...
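For context on what was removed: ForRange applies a functor to every index in [0, limit), while the deleted ForRangeIn applied it only to an explicitly supplied index list (a std::vector<int64_t> on CPU, a framework::Vector<int64_t> shipped to the device on CUDA). Below is a minimal CPU-only sketch of the two dispatch patterns under assumed names; CpuForRange and CpuForRangeIn are stand-ins, not Paddle APIs.

#include <cstdint>
#include <cstdio>
#include <vector>

// Rough stand-in for ForRange<CPUDeviceContext>: every index in [0, limit).
template <typename Function>
void CpuForRange(size_t limit, Function func) {
  for (size_t i = 0; i < limit; ++i) func(i);
}

// Rough stand-in for the removed ForRangeIn<CPUDeviceContext>: listed indices only.
template <typename Function>
void CpuForRangeIn(const std::vector<int64_t>& range, Function func) {
  for (auto i : range) func(i);
}

int main() {
  std::vector<int> hits(6, 0);
  auto mark = [&](int64_t i) { ++hits[static_cast<size_t>(i)]; };

  CpuForRange(hits.size(), mark);  // visits 0..5
  CpuForRangeIn({1, 4}, mark);     // visits only 1 and 4

  for (size_t i = 0; i < hits.size(); ++i)
    std::printf("index %zu visited %d times\n", i, hits[i]);
}

Dropping ForRangeIn removes a CUDA kernel launch and a device copy of the index vector for a case the plain CPU loop handles directly; per the adam_op.h hunk, lazy mode on GPU now falls through to ForRange, with the row-skipping guard inside SparseAdamFunctor doing the filtering.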