未验证 提交 13e2e10c 编写于 作者: Z zhangyuqin1998 提交者: GitHub

move UniformRawKernel to legacy (#53158)

* move UniformRawKernel to legacy

* Update uniform_kernel.cc

* Update uniform_kernel.cu

* Update uniform_kernel.cc

* Update uniform_kernel.cu

* Update uniform_kernel.h

* Update uniform_kernel.cc

* Empty Commit to setup deployments
上级 d463f8ee
......@@ -112,6 +112,7 @@ file(
"gpudnn/*.cu"
"kps/*.cu"
"legacy/kps/*.cu"
"legacy/gpu/*.cu"
"selected_rows/gpu/*.cu"
"sparse/gpu/*.cu"
"strings/gpu/*.cu"
......
......@@ -20,16 +20,13 @@
namespace phi {
template <typename T, typename Context>
void UniformRawKernel(const Context &dev_ctx,
const IntArray &shape,
DataType dtype,
const Scalar &min,
const Scalar &max,
int seed,
int diag_num,
int diag_step,
float diag_val,
DenseTensor *out) {
void UniformKernel(const Context &dev_ctx,
const IntArray &shape,
DataType dtype,
const Scalar &min,
const Scalar &max,
int seed,
DenseTensor *out) {
out->Resize(phi::make_ddim(shape.GetData()));
T *data = dev_ctx.template Alloc<T>(out);
auto size = out->numel();
......@@ -42,31 +39,14 @@ void UniformRawKernel(const Context &dev_ctx,
}
UniformRealDistribution<T>(
data, size, min.to<float>(), max.to<float>(), engine);
if (diag_num > 0) {
PADDLE_ENFORCE_GT(
size,
(diag_num - 1) * (diag_step + 1),
phi::errors::InvalidArgument(
"ShapeInvalid: the diagonal's elements is equal (num-1) "
"* (step-1) with num %d, step %d,"
"It should be smaller than %d, but received %d",
diag_num,
diag_step,
(diag_num - 1) * (diag_step + 1),
size));
for (int64_t i = 0; i < diag_num; ++i) {
int64_t pos = i * diag_step + i;
data[pos] = diag_val;
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(uniform_raw,
PD_REGISTER_KERNEL(uniform,
CPU,
ALL_LAYOUT,
phi::UniformRawKernel,
phi::UniformKernel,
float,
double,
phi::dtype::bfloat16) {}
......@@ -54,16 +54,13 @@ struct UniformGenerator {
};
template <typename T, typename Context>
void UniformRawKernel(const Context& dev_ctx,
const IntArray& shape,
DataType dtype,
const Scalar& min,
const Scalar& max,
int seed,
int diag_num,
int diag_step,
float diag_val,
DenseTensor* out) {
void UniformKernel(const Context& dev_ctx,
const IntArray& shape,
DataType dtype,
const Scalar& min,
const Scalar& max,
int seed,
DenseTensor* out) {
out->Resize(phi::make_ddim(shape.GetData()));
dev_ctx.template Alloc<T>(out);
if (seed == 0) {
......@@ -77,19 +74,19 @@ void UniformRawKernel(const Context& dev_ctx,
auto func = UniformGenerator<T>(static_cast<T>(min.to<float>()),
static_cast<T>(max.to<float>()),
seed,
diag_num,
diag_step,
static_cast<T>(diag_val));
0,
0,
static_cast<T>(0.0));
IndexKernel<T, UniformGenerator<T>>(dev_ctx, out, func);
}
}
} // namespace phi
PD_REGISTER_KERNEL(uniform_raw,
PD_REGISTER_KERNEL(uniform,
GPU,
ALL_LAYOUT,
phi::UniformRawKernel,
phi::UniformKernel,
float,
double,
phi::dtype::float16,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/uniform_real_distribution.h"
namespace phi {
template <typename T, typename Context>
void UniformRawKernel(const Context &dev_ctx,
const IntArray &shape,
DataType dtype,
const Scalar &min,
const Scalar &max,
int seed,
int diag_num,
int diag_step,
float diag_val,
DenseTensor *out) {
out->Resize(phi::make_ddim(shape.GetData()));
T *data = dev_ctx.template Alloc<T>(out);
auto size = out->numel();
std::shared_ptr<std::mt19937_64> engine;
if (seed) {
engine = std::make_shared<std::mt19937_64>();
engine->seed(seed);
} else {
engine = dev_ctx.GetGenerator()->GetCPUEngine();
}
UniformRealDistribution<T>(
data, size, min.to<float>(), max.to<float>(), engine);
if (diag_num > 0) {
PADDLE_ENFORCE_GT(
size,
(diag_num - 1) * (diag_step + 1),
phi::errors::InvalidArgument(
"ShapeInvalid: the diagonal's elements is equal (num-1) "
"* (step-1) with num %d, step %d,"
"It should be smaller than %d, but received %d",
diag_num,
diag_step,
(diag_num - 1) * (diag_step + 1),
size));
for (int64_t i = 0; i < diag_num; ++i) {
int64_t pos = i * diag_step + i;
data[pos] = diag_val;
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(uniform_raw,
CPU,
ALL_LAYOUT,
phi::UniformRawKernel,
float,
double,
phi::dtype::bfloat16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/uniform_kernel.h"
#include <thrust/random.h>
#include "gflags/gflags.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
namespace phi {
template <typename T>
struct UniformGenerator {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
__host__ __device__ UniformGenerator(
T min, T max, int seed, int diag_num, int diag_step, T diag_val)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};
template <typename T, typename Context>
void UniformRawKernel(const Context& dev_ctx,
const IntArray& shape,
DataType dtype,
const Scalar& min,
const Scalar& max,
int seed,
int diag_num,
int diag_step,
float diag_val,
DenseTensor* out) {
out->Resize(phi::make_ddim(shape.GetData()));
dev_ctx.template Alloc<T>(out);
if (seed == 0) {
// Use global Generator seed
using MT = typename kps::details::MPTypeTrait<T>::Type;
funcs::uniform_distribution<MT> dist;
funcs::uniform_real_transform<MT> trans(min.to<float>(), max.to<float>());
funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
} else {
// Use OP seed
auto func = UniformGenerator<T>(static_cast<T>(min.to<float>()),
static_cast<T>(max.to<float>()),
seed,
diag_num,
diag_step,
static_cast<T>(diag_val));
IndexKernel<T, UniformGenerator<T>>(dev_ctx, out, func);
}
}
} // namespace phi
PD_REGISTER_KERNEL(uniform_raw,
GPU,
ALL_LAYOUT,
phi::UniformRawKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,54 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/uniform_kernel.h"
#pragma once
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/phi/backends/gpu/gpu_context.h"
#endif
#ifdef PADDLE_WITH_XPU
#include "paddle/phi/backends/xpu/xpu_context.h"
#endif
namespace phi {
template <typename T, typename Context>
void UniformKernel(const Context& dev_ctx,
const IntArray& shape,
DataType dtype,
const Scalar& min,
const Scalar& max,
int seed,
DenseTensor* out) {
UniformRawKernel<T>(dev_ctx, shape, dtype, min, max, seed, 0, 0, 0.0f, out);
}
void UniformRawKernel(const Context& dev_ctx,
const IntArray& shape,
DataType dtype,
const Scalar& min,
const Scalar& max,
int seed,
int diag_num,
int diag_step,
float diag_val,
DenseTensor* out);
} // namespace phi
PD_REGISTER_KERNEL(uniform,
CPU,
ALL_LAYOUT,
phi::UniformKernel,
float,
double,
phi::dtype::bfloat16) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(uniform,
GPU,
ALL_LAYOUT,
phi::UniformKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(uniform, XPU, ALL_LAYOUT, phi::UniformKernel, float) {}
#endif
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/uniform_kernel.h"
#include <string>
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/uniform_real_distribution.h"
namespace phi {
template <typename T, typename Context>
void UniformRawKernel(const Context &dev_ctx,
const IntArray &shape,
DataType dtype,
const Scalar &min,
const Scalar &max,
int seed,
int diag_num,
int diag_step,
float diag_val,
DenseTensor *out) {
out->Resize(phi::make_ddim(shape.GetData()));
T *data = dev_ctx.template Alloc<T>(out);
int64_t size = out->numel();
std::unique_ptr<T[]> data_cpu(new T[size]);
std::shared_ptr<std::mt19937_64> engine;
if (seed) {
engine = std::make_shared<std::mt19937_64>();
engine->seed(seed);
} else {
engine = dev_ctx.GetGenerator()->GetCPUEngine();
}
UniformRealDistribution<T>(
data_cpu.get(), size, min.to<float>(), max.to<float>(), engine);
if (diag_num > 0) {
PADDLE_ENFORCE_GT(
size,
(diag_num - 1) * (diag_step + 1),
phi::errors::InvalidArgument(
"ShapeInvalid: the diagonal's elements is equal (num-1) "
"* (step-1) with num %d, step %d,"
"It should be smaller than %d, but received %d",
diag_num,
diag_step,
(diag_num - 1) * (diag_step + 1),
size));
for (int64_t i = 0; i < diag_num; ++i) {
int64_t pos = i * diag_step + i;
data_cpu[pos] = diag_val;
}
}
memory_utils::Copy(dev_ctx.GetPlace(),
data,
phi::CPUPlace(),
reinterpret_cast<void *>(data_cpu.get()),
size * sizeof(T));
}
} // namespace phi
PD_REGISTER_KERNEL(uniform_raw, XPU, ALL_LAYOUT, phi::UniformRawKernel, float) {
}
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/selected_rows/uniform_kernel.h"
#include "paddle/phi/kernels/legacy/uniform_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
......
......@@ -21,18 +21,6 @@
namespace phi {
template <typename T, typename Context>
void UniformRawKernel(const Context& dev_ctx,
const IntArray& shape,
DataType dtype,
const Scalar& min,
const Scalar& max,
int seed,
int diag_num,
int diag_step,
float diag_val,
DenseTensor* out);
template <typename T, typename Context>
void UniformKernel(const Context& dev_ctx,
const IntArray& shape,
......
......@@ -24,16 +24,16 @@ limitations under the License. */
namespace phi {
template <typename T, typename Context>
void UniformRawKernel(const Context &dev_ctx,
const IntArray &shape,
DataType dtype,
const Scalar &min,
const Scalar &max,
int seed,
int diag_num,
int diag_step,
float diag_val,
DenseTensor *out) {
void UniformKernel(const Context &dev_ctx,
const IntArray &shape,
DataType dtype,
const Scalar &min,
const Scalar &max,
int seed,
DenseTensor *out) {
int diag_num = 0;
int diag_step = 0;
float diag_val = 0.0f;
out->Resize(phi::make_ddim(shape.GetData()));
T *data = dev_ctx.template Alloc<T>(out);
int64_t size = out->numel();
......@@ -76,5 +76,4 @@ void UniformRawKernel(const Context &dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(uniform_raw, XPU, ALL_LAYOUT, phi::UniformRawKernel, float) {
}
PD_REGISTER_KERNEL(uniform, XPU, ALL_LAYOUT, phi::UniformKernel, float) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册