未验证 提交 23aa7a36 编写于 作者: F furnace 提交者: GitHub

[Phi] move truncated_gaussian_random kernel (#39971)

* [Phi] move truncated_gaussian_random, copy kernels

* [Phi] move truncated_gaussian_random, kernel register

* [Phi] move truncated_gaussian_random, delete useless codes
上级 35471b1f
...@@ -23,28 +23,6 @@ limitations under the License. */ ...@@ -23,28 +23,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T>
class CPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.Attr<float>("mean");
float std = context.Attr<float>("std");
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
std::uniform_real_distribution<T> dist(std::numeric_limits<float>::min(),
1.0);
TruncatedNormal<T> truncated_normal(mean, std);
int64_t size = tensor->numel();
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
auto engine = framework::GetCPURandomEngine(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = truncated_normal(dist(*engine));
}
}
};
class TruncatedGaussianRandomOp : public framework::OperatorWithKernel { class TruncatedGaussianRandomOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -124,5 +102,3 @@ namespace ops = paddle::operators; ...@@ -124,5 +102,3 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(truncated_gaussian_random, REGISTER_OP_WITHOUT_GRADIENT(truncated_gaussian_random,
ops::TruncatedGaussianRandomOp, ops::TruncatedGaussianRandomOp,
ops::TruncatedGaussianRandomOpMaker); ops::TruncatedGaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(truncated_gaussian_random,
ops::CPUTruncatedGaussianRandomKernel<float>);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/truncated_gaussian_random_kernel.h"
#include <limits>
#include <random>
#include <vector>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/framework/generator.h"
namespace phi {
template <typename T, typename Context>
void TruncatedGaussianRandomKernel(const Context& dev_ctx,
const ScalarArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out) {
auto tensor = out;
T* data = dev_ctx.template Alloc<T>(tensor);
std::uniform_real_distribution<T> dist(std::numeric_limits<float>::min(),
1.0);
TruncatedNormal<T> truncated_normal(mean, std);
int64_t size = tensor->numel();
auto engine = paddle::framework::GetCPURandomEngine(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = truncated_normal(dist(*engine));
}
}
} // namespace phi
PD_REGISTER_KERNEL(truncated_gaussian_random,
CPU,
ALL_LAYOUT,
phi::TruncatedGaussianRandomKernel,
float) {}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
You may obtain a copy of the License at // You may obtain a copy of the License at
//
http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
//
Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
limitations under the License. */ // limitations under the License.
#include "paddle/phi/kernels/truncated_gaussian_random_kernel.h"
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include <thrust/random.h> #include <thrust/random.h>
#include <thrust/transform.h> #include <thrust/transform.h>
#include <limits> #include <limits>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h" // #include "paddle/phi/core/generator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/truncated_gaussian_random_op.h"
namespace paddle { namespace phi {
namespace operators {
template <typename T> template <typename T>
struct GPUTruncatedNormal { struct GPUTruncatedNormal {
...@@ -59,8 +63,8 @@ struct TruncatedNormalOffset { ...@@ -59,8 +63,8 @@ struct TruncatedNormalOffset {
T numeric_min; T numeric_min;
int offset_; int offset_;
__host__ __device__ TruncatedNormalOffset(T mean, T std, T numeric_min, __host__ __device__
int seed, int offset) TruncatedNormalOffset(T mean, T std, T numeric_min, int seed, int offset)
: mean(mean), : mean(mean),
std(std), std(std),
seed(seed), seed(seed),
...@@ -81,48 +85,55 @@ struct TruncatedNormalOffset { ...@@ -81,48 +85,55 @@ struct TruncatedNormalOffset {
} }
}; };
template <typename T> template <typename T, typename Context>
class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> { void TruncatedGaussianRandomKernel(const Context& dev_ctx,
public: const ScalarArray& shape,
void Compute(const framework::ExecutionContext& context) const override { float mean,
auto* tensor = context.Output<framework::Tensor>("Out"); float std,
T* data = tensor->mutable_data<T>(context.GetPlace()); int seed,
DataType dtype,
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed")); DenseTensor* out) {
bool seed_flag = false; auto tensor = out;
if (seed == 0) {
std::random_device rd; T* data = dev_ctx.template Alloc<T>(tensor);
seed = rd();
seed_flag = true; bool seed_flag = false;
} if (seed == 0) {
T mean = static_cast<T>(context.Attr<float>("mean")); std::random_device rd;
T std = static_cast<T>(context.Attr<float>("std")); seed = rd();
thrust::counting_iterator<int64_t> index_sequence_begin(0); seed_flag = true;
int64_t size = tensor->numel();
int device_id = context.GetPlace().GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(),
seed_offset.first, gen_offset));
} else {
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GPUTruncatedNormal<T>(
mean, std, std::numeric_limits<T>::min(), seed));
}
} }
};
} // namespace operators thrust::counting_iterator<int64_t> index_sequence_begin(0);
} // namespace paddle int64_t size = tensor->numel();
int device_id = dev_ctx.GetPlace().GetDeviceId();
auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
thrust::transform(index_sequence_begin,
index_sequence_begin + size,
thrust::device_ptr<T>(data),
TruncatedNormalOffset<T>(mean,
std,
std::numeric_limits<T>::min(),
seed_offset.first,
gen_offset));
} else {
thrust::transform(
index_sequence_begin,
index_sequence_begin + size,
thrust::device_ptr<T>(data),
GPUTruncatedNormal<T>(mean, std, std::numeric_limits<T>::min(), seed));
}
}
} // namespace phi
REGISTER_OP_CUDA_KERNEL( PD_REGISTER_KERNEL(truncated_gaussian_random,
truncated_gaussian_random, GPU,
paddle::operators::GPUTruncatedGaussianRandomKernel<float>); ALL_LAYOUT,
phi::TruncatedGaussianRandomKernel,
float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <limits>
#include <random>
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
// reference: https://gist.github.com/lakshayg/d80172fe5ae3c5d2c2aedb53c250320e
template <typename T>
T Erfinv(T x) {
if (x < -1 || x > 1) {
return std::numeric_limits<T>::quiet_NaN();
} else if (x == 1.0) {
return std::numeric_limits<T>::infinity();
} else if (x == -1.0) {
return -std::numeric_limits<T>::infinity();
}
const T LN2 = 6.931471805599453094172321214581e-1;
const T A0 = 1.1975323115670912564578e0;
const T A1 = 4.7072688112383978012285e1;
const T A2 = 6.9706266534389598238465e2;
const T A3 = 4.8548868893843886794648e3;
const T A4 = 1.6235862515167575384252e4;
const T A5 = 2.3782041382114385731252e4;
const T A6 = 1.1819493347062294404278e4;
const T A7 = 8.8709406962545514830200e2;
const T B0 = 1.0000000000000000000e0;
const T B1 = 4.2313330701600911252e1;
const T B2 = 6.8718700749205790830e2;
const T B3 = 5.3941960214247511077e3;
const T B4 = 2.1213794301586595867e4;
const T B5 = 3.9307895800092710610e4;
const T B6 = 2.8729085735721942674e4;
const T B7 = 5.2264952788528545610e3;
const T C0 = 1.42343711074968357734e0;
const T C1 = 4.63033784615654529590e0;
const T C2 = 5.76949722146069140550e0;
const T C3 = 3.64784832476320460504e0;
const T C4 = 1.27045825245236838258e0;
const T C5 = 2.41780725177450611770e-1;
const T C6 = 2.27238449892691845833e-2;
const T C7 = 7.74545014278341407640e-4;
const T D0 = 1.4142135623730950488016887e0;
const T D1 = 2.9036514445419946173133295e0;
const T D2 = 2.3707661626024532365971225e0;
const T D3 = 9.7547832001787427186894837e-1;
const T D4 = 2.0945065210512749128288442e-1;
const T D5 = 2.1494160384252876777097297e-2;
const T D6 = 7.7441459065157709165577218e-4;
const T D7 = 1.4859850019840355905497876e-9;
const T E0 = 6.65790464350110377720e0;
const T E1 = 5.46378491116411436990e0;
const T E2 = 1.78482653991729133580e0;
const T E3 = 2.96560571828504891230e-1;
const T E4 = 2.65321895265761230930e-2;
const T E5 = 1.24266094738807843860e-3;
const T E6 = 2.71155556874348757815e-5;
const T E7 = 2.01033439929228813265e-7;
const T F0 = 1.414213562373095048801689e0;
const T F1 = 8.482908416595164588112026e-1;
const T F2 = 1.936480946950659106176712e-1;
const T F3 = 2.103693768272068968719679e-2;
const T F4 = 1.112800997078859844711555e-3;
const T F5 = 2.611088405080593625138020e-5;
const T F6 = 2.010321207683943062279931e-7;
const T F7 = 2.891024605872965461538222e-15;
T abs_x = abs(x);
if (abs_x <= 0.85) {
T r = 0.180625 - 0.25 * x * x;
T num =
(((((((A7 * r + A6) * r + A5) * r + A4) * r + A3) * r + A2) * r + A1) *
r +
A0);
T den =
(((((((B7 * r + B6) * r + B5) * r + B4) * r + B3) * r + B2) * r + B1) *
r +
B0);
return x * num / den;
}
T r = sqrt(LN2 - log(1.0 - abs_x));
T num, den;
if (r <= 5.0) {
r = r - 1.6;
num =
(((((((C7 * r + C6) * r + C5) * r + C4) * r + C3) * r + C2) * r + C1) *
r +
C0);
den =
(((((((D7 * r + D6) * r + D5) * r + D4) * r + D3) * r + D2) * r + D1) *
r +
D0);
} else {
r = r - 5.0;
num =
(((((((E7 * r + E6) * r + E5) * r + E4) * r + E3) * r + E2) * r + E1) *
r +
E0);
den =
(((((((F7 * r + F6) * r + F5) * r + F4) * r + F3) * r + F2) * r + F1) *
r +
F0);
}
if (x < 0) {
return -num / den;
} else {
return num / den;
}
}
template <typename T>
struct TruncatedNormal {
T mean, std;
T a_normal_cdf;
T b_normal_cdf;
TruncatedNormal(T mean, T std) : mean(mean), std(std) {
auto normal_cdf = [](T x) {
return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
};
a_normal_cdf = normal_cdf(-2.0);
b_normal_cdf = normal_cdf(2.0);
}
T operator()(T value) const {
auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean;
}
};
template <typename T, typename Context>
void TruncatedGaussianRandomKernel(const Context& ctx,
const ScalarArray& shape,
float mean,
float std,
int seed,
DataType dtype,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature TruncatedGaussianRandomOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("truncated_gaussian_random",
{},
{"shape", "mean", "std", "seed", "dtype"},
{"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(truncated_gaussian_random,
phi::TruncatedGaussianRandomOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册