未验证 提交 9f94821b 编写于 作者: N niuliling123 提交者: GitHub

Modified RandomKernel with Kernel Primitive API (#39666)

* Modified RandomKernel with Kernel Primitive API

* update pten.h to phi.h

* update

* update fullKernel
上级 f4e74887
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/distribution_helper.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/operators/index_impl.cu.h"
DECLARE_bool(use_curand);
......@@ -65,7 +66,6 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
thrust::counting_iterator<int64_t> index_sequence_begin(0);
auto shape = GetShape(context);
tensor->Resize(shape);
......@@ -88,15 +88,13 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
} else {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed_offset.first, gen_offset));
auto func =
GaussianGenerator<T>(mean, std, seed_offset.first, gen_offset);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
}
} else {
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
auto func = GaussianGenerator<T>(mean, std, seed);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
}
}
};
......@@ -116,23 +114,22 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
thrust::counting_iterator<int64_t> index_sequence_begin(0);
int64_t size = tensor->numel();
int device_id = context.GetPlace().GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
auto& dev_cxt =
context.template device_context<platform::CUDADeviceContext>();
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed_offset.first,
seed_offset.second));
auto func = GaussianGenerator<T>(mean, std, seed_offset.first,
seed_offset.second);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
} else {
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
auto func = GaussianGenerator<T>(mean, std, seed);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
}
}
};
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/distribution_helper.h"
#include "paddle/fluid/operators/fill_constant_op.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
namespace paddle {
namespace operators {
namespace kps = phi::kps;
template <typename T, typename Functor, int VecSize>
__global__ void VectorizedIndexKernel(T *out, int numel, int main_offset,
Functor func) {
int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
int args[VecSize];
T result[VecSize];
for (; data_offset < main_offset; data_offset += stride) {
kps::InitWithDataIndex<int, VecSize, 1, 1>(&args[0], data_offset);
kps::ElementwiseUnary<int, T, VecSize, 1, 1, Functor>(&result[0], &args[0],
func);
kps::WriteData<T, VecSize, 1, 1, false>(out + data_offset, &result[0],
BLOCK_NUM_X * VecSize);
}
int num = numel - data_offset;
if (numel > 0) {
kps::InitWithDataIndex<int, VecSize, 1, 1>(&args[0], data_offset);
kps::ElementwiseUnary<int, T, VecSize, 1, 1, Functor>(&result[0], &args[0],
func);
kps::WriteData<T, VecSize, 1, 1, true>(out + data_offset, &result[0], num);
}
}
template <typename T, typename Functor>
void IndexKernel(const KPDevice &dev_ctx, Tensor *out, Functor func) {
int numel = out->numel();
T *out_data = out->mutable_data<T>(dev_ctx.GetPlace());
if (numel <= 0) return;
int vec_size = paddle::platform::GetVectorizedSize((out->data<T>()));
#ifdef PADDLE_WITH_XPU_KP
int block = 64;
int grid = 8;
auto stream = dev_ctx.x_context()->xpu_stream;
#else
auto config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, vec_size);
int grid = config.block_per_grid.x;
int block = config.thread_per_block.x;
auto stream = dev_ctx.stream();
#endif
int main_offset = (numel / (vec_size * block)) * vec_size * block;
switch (vec_size) {
case 4:
VectorizedIndexKernel<T, Functor, 4><<<grid, block, 0, stream>>>(
out_data, numel, main_offset, func);
break;
case 2:
VectorizedIndexKernel<T, Functor, 2><<<grid, block, 0, stream>>>(
out_data, numel, main_offset, func);
break;
case 1:
VectorizedIndexKernel<T, Functor, 1><<<grid, block, 0, stream>>>(
out_data, numel, main_offset, func);
break;
default: {
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));
break;
}
}
}
} // namespace operators
} // namespace paddle
......@@ -12,130 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/uniform_random_op.h"
#include "paddle/phi/kernels/full_kernel.h"
namespace paddle {
namespace operators {
template <typename T>
struct UniformGenerator {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
__host__ __device__ UniformGenerator(T min, T max, int seed, int diag_num,
int diag_step, T diag_val)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};
template <typename T>
struct UniformGeneratorOffset {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
int offset_;
__host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
int diag_num, int diag_step,
T diag_val, int offset)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val),
offset_(offset) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n + offset_);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};
template <typename T>
__global__ void fill_value(int64_t size, T* data, float value) {
for (int idx = threadIdx.x; idx < size; idx += blockDim.x) {
data[idx] = static_cast<T>(value);
}
}
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random as uniform_random_op.cu.
template <typename T>
class GPUUniformRandomInplaceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto out_var = ctx.OutputVar("Out");
auto* tensor = out_var->GetMutable<framework::LoDTensor>();
T* data = tensor->mutable_data<T>(ctx.GetPlace());
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T min = static_cast<T>(ctx.Attr<float>("min"));
T max = static_cast<T>(ctx.Attr<float>("max"));
unsigned int diag_num =
static_cast<unsigned int>(ctx.Attr<int>("diag_num"));
unsigned int diag_step =
static_cast<unsigned int>(ctx.Attr<int>("diag_step"));
T diag_val = static_cast<T>(ctx.Attr<float>("diag_val"));
thrust::counting_iterator<int64_t> index_sequence_begin(0);
int64_t size = tensor->numel();
int device_id = ctx.GetPlace().GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
diag_step, diag_val, gen_offset));
} else {
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
}
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
UniformRandom<T>(context, tensor);
}
};
......@@ -143,17 +30,15 @@ template <typename T>
class GPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#ifdef __HIPCC__
const int64_t kMaxBlockDim = 256;
#else
const int64_t kMaxBlockDim = 512;
#endif
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* data = dx->mutable_data<T>(ctx.GetPlace());
auto size = dx->numel();
int64_t kBlockDim = std::min(size, kMaxBlockDim);
fill_value<T><<<1, kBlockDim, 0>>>(size, data, static_cast<float>(0));
auto dims = vectorize(dx->dims());
const auto& dev_cxt =
ctx.template device_context<platform::CUDADeviceContext>();
float value = static_cast<float>(0.0f);
phi::FullKernel<T>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
paddle::platform::CUDADeviceContext>::TYPE&>(dev_cxt),
dims, value, phi::DataType::UNDEFINED, dx);
}
};
......
......@@ -11,88 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/distribution_helper.h"
#include "paddle/fluid/operators/uniform_random_op.h"
DECLARE_bool(use_curand);
namespace paddle {
namespace operators {
template <typename T>
struct UniformGenerator {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
__host__ __device__ UniformGenerator(T min, T max, int seed, int diag_num,
int diag_step, T diag_val)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};
template <typename T>
struct UniformGeneratorOffset {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
int offset_;
__host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
int diag_num, int diag_step,
T diag_val, int offset)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val),
offset_(offset) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n + offset_);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};
// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename T>
class GPUUniformRandomKernel : public framework::OpKernel<T> {
public:
......@@ -128,50 +51,7 @@ class GPUUniformRandomKernel : public framework::OpKernel<T> {
"unsupport type: %s.",
framework::ToTypeName(out_var->Type())));
}
auto& dev_cxt =
context.template device_context<platform::CUDADeviceContext>();
T* data = tensor->mutable_data<T>(dev_cxt.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T min = static_cast<T>(context.Attr<float>("min"));
T max = static_cast<T>(context.Attr<float>("max"));
unsigned int diag_num =
static_cast<unsigned int>(context.Attr<int>("diag_num"));
unsigned int diag_step =
static_cast<unsigned int>(context.Attr<int>("diag_step"));
T diag_val = static_cast<T>(context.Attr<float>("diag_val"));
thrust::counting_iterator<int64_t> index_sequence_begin(0);
int64_t size = tensor->numel();
int device_id = context.GetPlace().GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
if (FLAGS_use_curand) {
using MT = typename details::MPTypeTrait<T>::Type;
distribution::uniform_distribution<MT> dist;
distribution::uniform_transform<MT> trans(min, max);
distribution::distribution_and_transform<T>(dev_cxt, tensor, dist,
trans);
} else {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
diag_step, diag_val, gen_offset));
}
} else {
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
}
UniformRandom<T>(context, tensor);
}
};
......
......@@ -18,6 +18,16 @@
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#if defined(__NVCC__) || defined(__HIPCC__)
DECLARE_bool(use_curand);
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/index_impl.cu.h"
#include "paddle/phi/kernels/full_kernel.h"
#endif
namespace paddle {
namespace operators {
......@@ -102,5 +112,117 @@ inline std::vector<int64_t> GetNewDataFromShapeTensorList(
return vec_new_shape;
}
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
struct UniformGenerator {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
__host__ __device__ UniformGenerator(T min, T max, int seed, int diag_num,
int diag_step, T diag_val)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};
template <typename T>
struct UniformGeneratorOffset {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
int offset_;
__host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
int diag_num, int diag_step,
T diag_val, int offset)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val),
offset_(offset) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n + offset_);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};
template <typename T>
void UniformRandom(const framework::ExecutionContext& context,
framework::Tensor* tensor) {
int64_t size = tensor->numel();
auto& dev_cxt =
context.template device_context<platform::CUDADeviceContext>();
T* data = tensor->mutable_data<T>(dev_cxt.GetPlace());
if (size <= 0) return;
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T min = static_cast<T>(context.Attr<float>("min"));
T max = static_cast<T>(context.Attr<float>("max"));
unsigned int diag_num =
static_cast<unsigned int>(context.Attr<int>("diag_num"));
unsigned int diag_step =
static_cast<unsigned int>(context.Attr<int>("diag_step"));
T diag_val = static_cast<T>(context.Attr<float>("diag_val"));
int device_id = context.GetPlace().GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
if (FLAGS_use_curand) {
using MT = typename details::MPTypeTrait<T>::Type;
distribution::uniform_distribution<MT> dist;
distribution::uniform_transform<MT> trans(min, max);
distribution::distribution_and_transform<T>(dev_cxt, tensor, dist, trans);
} else {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
auto func =
UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
diag_step, diag_val, gen_offset);
IndexKernel<T, UniformGeneratorOffset<T>>(dev_cxt, tensor, func);
}
} else {
auto func =
UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val);
IndexKernel<T, UniformGenerator<T>>(dev_cxt, tensor, func);
}
}
#endif
} // namespace operators
} // namespace paddle
......@@ -714,5 +714,14 @@ __device__ __forceinline__ void ReadDataBc(
}
}
template <typename T, int NX, int NY, int BlockSize>
__device__ __forceinline__ void InitWithDataIndex(T* dst, int block_offset) {
int thread_offset = block_offset + threadIdx.x * NX;
#pragma unroll
for (int nx = 0; nx < NX; ++nx) {
dst[nx] = static_cast<T>(thread_offset + nx);
}
}
} // namespace kps
} // namespace phi
......@@ -21,7 +21,7 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor
from op_test import OpTest
from paddle.fluid.tests.unittests.op_test import OpTest
import paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册