未验证 提交 00bbb8c5 编写于 作者: F furnace 提交者: GitHub

[Phi] move gaussian_random (#39932)

[Phi] move gaussian_random kernel
上级 815f7a67
......@@ -26,27 +26,6 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class CPUGaussianRandomKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.Attr<float>("mean");
float std = context.Attr<float>("std");
auto* tensor = context.Output<framework::Tensor>("Out");
std::normal_distribution<T> dist(mean, std);
auto shape = GetShape(context);
tensor->Resize(shape);
int64_t size = tensor->numel();
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
auto engine = framework::GetCPURandomEngine(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
}
}; // namespace operators
template <typename T>
class CPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
......@@ -194,8 +173,6 @@ Used to initialize tensors with gaussian random generator.
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
ops::GaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::CPUGaussianRandomKernel<float>,
ops::CPUGaussianRandomKernel<double>);
REGISTER_OP_CPU_KERNEL(gaussian_random_batch_size_like,
ops::CPUGaussianRandomBatchSizeLikeKernel<float>,
ops::CPUGaussianRandomBatchSizeLikeKernel<double>);
......
......@@ -52,53 +52,6 @@ struct GaussianGenerator {
}
};
template <typename T>
class GPUGaussianRandomKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
auto shape = GetShape(context);
tensor->Resize(shape);
auto& dev_cxt =
context.template device_context<platform::CUDADeviceContext>();
T* data = tensor->mutable_data<T>(dev_cxt.GetPlace());
int64_t size = tensor->numel();
int device_id = context.GetPlace().GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
if (FLAGS_use_curand) {
using MT = typename details::MPTypeTrait<T>::Type;
distribution::normal_distribution<MT> dist;
distribution::normal_transform<MT> trans(mean, std);
distribution::distribution_and_transform<T>(dev_cxt, tensor, dist,
trans);
} else {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
auto func =
GaussianGenerator<T>(mean, std, seed_offset.first, gen_offset);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
}
} else {
auto func = GaussianGenerator<T>(mean, std, seed);
IndexKernel<T, GaussianGenerator<T>>(dev_cxt, tensor, func);
}
}
};
template <typename T>
class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
public:
......@@ -136,11 +89,6 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
gaussian_random,
paddle::operators::GPUGaussianRandomKernel<paddle::platform::float16>,
paddle::operators::GPUGaussianRandomKernel<float>,
paddle::operators::GPUGaussianRandomKernel<double>);
REGISTER_OP_CUDA_KERNEL(
gaussian_random_batch_size_like,
paddle::operators::GPUGaussianRandomBatchSizeLikeKernel<
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gaussian_random_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/framework/generator.h"
namespace phi {
template <typename T, typename Context>
void GaussianRandomKernel(const Context& dev_ctx,
const ScalarArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out) {
auto tensor = out;
std::normal_distribution<T> dist(mean, std);
tensor->Resize(phi::make_ddim(shape.GetData()));
int64_t size = tensor->numel();
T* data = dev_ctx.template Alloc<T>(tensor);
auto engine = paddle::framework::GetCPURandomEngine(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
}
} // namespace phi
PD_REGISTER_KERNEL(gaussian_random,
CPU,
ALL_LAYOUT,
phi::GaussianRandomKernel,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void GaussianRandomKernel(const Context& ctx,
const ScalarArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gaussian_random_kernel.h"
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/fluid/framework/generator.h"
DECLARE_bool(use_curand);
namespace phi {
template <typename T>
struct GaussianGenerator {
T mean_, std_;
unsigned int seed_;
unsigned int offset_ = 0;
__host__ __device__ GaussianGenerator(T mean, T std, int seed)
: mean_(mean), std_(std), seed_(seed) {}
__host__ __device__ GaussianGenerator(T mean, T std, int seed, int offset)
: mean_(mean), std_(std), seed_(seed), offset_(offset) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
thrust::normal_distribution<MT> dist(mean_, std_);
unsigned int new_n = n + offset_;
rng.discard(new_n);
MT out = dist(rng);
return static_cast<T>(out);
}
};
template <typename T, typename Context>
void GaussianRandomKernel(const Context& dev_ctx,
const ScalarArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out) {
auto tensor = out;
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
tensor->Resize(phi::make_ddim(shape.GetData()));
T* data = dev_ctx.template Alloc<T>(tensor);
int64_t size = tensor->numel();
int device_id = dev_ctx.GetPlace().GetDeviceId();
auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id);
using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
if (gen_cuda->GetIsInitPy() && seed_flag) {
if (FLAGS_use_curand) {
funcs::normal_distribution<MT> dist;
funcs::normal_transform<MT> trans(mean, std);
funcs::distribution_and_transform<T>(dev_ctx, tensor, dist, trans);
} else {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
auto func =
GaussianGenerator<MT>(mean, std, seed_offset.first, gen_offset);
IndexKernel<T, GaussianGenerator<MT>>(dev_ctx, tensor, func);
}
} else {
auto func = GaussianGenerator<MT>(mean, std, seed);
IndexKernel<T, GaussianGenerator<MT>>(dev_ctx, tensor, func);
}
}
} // namespace phi
PD_REGISTER_KERNEL(gaussian_random,
GPU,
ALL_LAYOUT,
phi::GaussianRandomKernel,
phi::dtype::float16,
float,
double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature GaussianRandomOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.InputSize("ShapeTensorList") > 0) {
return KernelSignature("gaussian_random",
{},
{"ShapeTensorList", "mean", "std", "seed", "dtype"},
{"Out"});
}
const auto& shape = paddle::any_cast<std::vector<int64_t>>(ctx.Attr("shape"));
if (ctx.HasInput("ShapeTensor") && shape.empty()) {
return KernelSignature("gaussian_random",
{},
{"ShapeTensor", "mean", "std", "seed", "dtype"},
{"Out"});
}
return KernelSignature("gaussian_random",
{},
{"shape", "mean", "std", "seed", "dtype"},
{"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(gaussian_random,
phi::GaussianRandomOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册