提交 fcde2b27 编写于 作者: Q Qiao Longfei

Add ForRangeIn: a ForRange variant that applies a functor to an explicit list of indices instead of the contiguous range [0, limit)

上级 cf526462
...@@ -359,14 +359,17 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -359,14 +359,17 @@ class AdamOpKernel : public framework::OpKernel<T> {
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel, param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size(), lazy_mode); grad_merge.rows().size(), lazy_mode);
if (lazy_mode) { if (lazy_mode) {
std::vector<int64_t> id_vector;
size_t row_count = grad_merge.rows().size(); size_t row_count = grad_merge.rows().size();
for (size_t row_index = 0; row_index < row_count; ++row_index) { for (size_t row_index = 0; row_index < row_count; ++row_index) {
for (size_t offset = 0; offset < row_numel; ++offset) { for (size_t offset = 0; offset < row_numel; ++offset) {
size_t i = rows[row_index] * row_numel + offset; size_t i = rows[row_index] * row_numel + offset;
T g = grad_data[row_index * row_numel + offset]; id_vector.push_back(i);
functor.adam_update(i, g);
} }
} }
platform::ForRangeIn<DeviceContext> for_range_in(
static_cast<const DeviceContext&>(ctx.device_context()), id_vector);
for_range_in(functor);
} else { } else {
platform::ForRange<DeviceContext> for_range( platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()), static_cast<const DeviceContext&>(ctx.device_context()),
......
...@@ -13,11 +13,38 @@ See the License for the specific language governing permissions and ...@@ -13,11 +13,38 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
// Primary template: applies a functor to each index in an explicit list of
// indices (unlike ForRange below, which covers the contiguous range
// [0, limit)). Intentionally left undefined; only the per-device
// specializations provide an implementation.
template <typename DeviceContext>
struct ForRangeIn {
  // `range` holds the indices that operator() will visit, in order.
  // dev_ctx selects the device/stream on which the functor runs.
  ForRangeIn(const DeviceContext& dev_ctx, std::vector<int64_t> range);

  // Invokes func(i) once for every i in the stored index list.
  template <typename Function>
  void operator()(Function func) const;
};
template <>
struct ForRangeIn<CPUDeviceContext> {
ForRangeIn(const CPUDeviceContext& dev_ctx, std::vector<int64_t> range)
: range_(range) {}
template <typename Function>
void operator()(Function func) const {
for (auto i : range_) {
func(i);
}
}
std::vector<int64_t> range_;
};
template <typename DeviceContext> template <typename DeviceContext>
struct ForRange { struct ForRange {
ForRange(const DeviceContext& dev_ctx, size_t limit); ForRange(const DeviceContext& dev_ctx, size_t limit);
...@@ -79,6 +106,34 @@ struct ForRange<CUDADeviceContext> { ...@@ -79,6 +106,34 @@ struct ForRange<CUDADeviceContext> {
int limit_; int limit_;
}; };
// CUDA kernel: one thread per slot of `vector`; each in-bounds thread applies
// `func` to the index value stored at its slot. `vector` must point to
// device-accessible memory holding `vector_size` elements.
template <typename T, typename Function>
__global__ static void ForRangeInElemwiseOp(Function func, T* vector,
                                            int vector_size) {
  size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
  // Cast avoids a signed/unsigned comparison between size_t and int; surplus
  // threads in the last block fall through without touching memory.
  if (idx < static_cast<size_t>(vector_size)) {
    func(vector[idx]);
  }
}
// GPU specialization: launches one CUDA thread per stored index on the
// context's stream and applies the functor to each index.
template <>
struct ForRangeIn<CUDADeviceContext> {
  // NOTE(review): the constructor was misnamed `ForRange` in the original,
  // which does not compile inside a struct named ForRangeIn — fixed.
  ForRangeIn(const CUDADeviceContext& dev_ctx, std::vector<int64_t> range)
      : dev_ctx_(dev_ctx), range_(range) {}

  template <typename Function>
  inline void operator()(Function func) const {
    constexpr int num_threads = 1024;
    int range_size = static_cast<int>(range_.size());
    // An empty index list would otherwise produce a <<<0, 0>>> launch,
    // which is an invalid execution configuration.
    if (range_size == 0) return;
    // Small inputs get a single block sized to the input; the original
    // referenced a non-existent member `limit_` here.
    int block_size = range_size <= num_threads ? range_size : num_threads;
    int grid_size = (range_size + num_threads - 1) / num_threads;

    // NOTE(review): Vector<T>::data() returns the host-side buffer, but the
    // kernel needs a device pointer, so hand it the CUDA-side copy.
    // Confirm CUDAData() is the correct accessor for this version of
    // framework::Vector.
    ForRangeInElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
        func, range_.CUDAData(dev_ctx_.GetPlace()), range_size);
  }

  const CUDADeviceContext& dev_ctx_;
  framework::Vector<int64_t> range_;  // mixed host/device index storage
};
#endif #endif
} // namespace platform } // namespace platform
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册