未验证 提交 22f84122 编写于 作者: L Leo Chen 提交者: GitHub

[phi] refine code of randint, randperm, unbind kernel (#39909)

* refine randint kernel

* refine randperm kernel

* refine unbind kernel

* support op seed
上级 44da9b42
......@@ -22,42 +22,43 @@
namespace phi {
template <typename T, typename Context>
void RandintRawKernel(const Context& ctx,
void RandintRawKernel(const Context& dev_ctx,
int low,
int high,
const ScalarArray& shape,
DataType dtype,
int seed,
DenseTensor* out) {
out->ResizeAndAllocate(phi::make_ddim(shape.GetData()));
auto size = out->numel();
out->Resize(phi::make_ddim(shape.GetData()));
T* data = dev_ctx.template Alloc<T>(out);
auto numel = out->numel();
std::shared_ptr<std::mt19937_64> engine;
if (seed) {
engine = std::make_shared<std::mt19937_64>();
engine->seed(seed);
} else {
engine = ctx.GetGenerator()->GetCPUEngine();
engine = dev_ctx.GetGenerator()->GetCPUEngine();
}
std::uniform_int_distribution<T> dist(low, high - 1);
auto data = out->data<T>();
for (int64_t i = 0; i < size; ++i) {
for (int64_t i = 0; i < numel; ++i) {
data[i] = dist(*engine);
}
}
// Convenience wrapper over RandintRawKernel with seed = 0, i.e. sampling
// from the context's global CPU generator instead of a fixed seed.
// Diff-view residue removed: stale duplicate signature/call lines
// (the old `ctx` spelling) conflicted with the post-change version.
template <typename T, typename Context>
void RandintKernel(const Context& dev_ctx,
                   int low,
                   int high,
                   const ScalarArray& shape,
                   DataType dtype,
                   DenseTensor* out) {
  RandintRawKernel<T>(dev_ctx, low, high, shape, dtype, 0, out);
}
} // namespace phi
// Register the CPU randint kernels for int and int64_t outputs.
// `randint_raw` is the variant that additionally accepts the `seed` attribute.
PD_REGISTER_KERNEL(
randint_raw, CPU, ALL_LAYOUT, phi::RandintRawKernel, int, int64_t) {}
PD_REGISTER_KERNEL(randint, CPU, ALL_LAYOUT, phi::RandintKernel, int, int64_t) {
}
......@@ -13,20 +13,23 @@
// limitations under the License.
#include "paddle/phi/kernels/randperm_kernel.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void RandpermKernel(const Context& ctx,
int n,
DataType dtype,
DenseTensor* out) {
T* out_data = ctx.template Alloc<T>(out);
auto gen_ptr = ctx.GetHostGenerator();
auto engine = gen_ptr->GetCPUEngine();
void RandpermRawKernel(
const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out) {
T* out_data = dev_ctx.template Alloc<T>(out);
std::shared_ptr<std::mt19937_64> engine;
if (seed) {
engine = std::make_shared<std::mt19937_64>();
engine->seed(seed);
} else {
engine = dev_ctx.GetGenerator()->GetCPUEngine();
}
for (int i = 0; i < n; ++i) {
out_data[i] = static_cast<T>(i);
......@@ -34,8 +37,25 @@ void RandpermKernel(const Context& ctx,
std::shuffle(out_data, out_data + n, *engine);
}
// CPU randperm entry point without an explicit seed: delegates to the raw
// kernel with seed = 0, which selects the context's global CPU generator.
template <typename T, typename Context>
void RandpermKernel(const Context& dev_ctx, int n, DataType dtype,
                    DenseTensor* out) {
  constexpr int kUseGlobalGenerator = 0;
  RandpermRawKernel<T>(dev_ctx, n, dtype, kUseGlobalGenerator, out);
}
} // namespace phi
// Register the seed-taking CPU randperm kernel for all supported dtypes.
PD_REGISTER_KERNEL(randperm_raw,
CPU,
ALL_LAYOUT,
phi::RandpermRawKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(randperm,
CPU,
ALL_LAYOUT,
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/unbind_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unbind_kernel_impl.h"
......
......@@ -12,21 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
namespace phi {
......
......@@ -12,23 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
namespace phi {
namespace funcs {
......
......@@ -13,20 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/utils/data_type.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/memcpy.h"
namespace phi {
namespace funcs {
......
......@@ -25,7 +25,7 @@
namespace phi {
template <typename T, typename Context>
void RandintRawKernel(const Context& ctx,
void RandintRawKernel(const Context& dev_ctx,
int low,
int high,
const ScalarArray& shape,
......@@ -34,21 +34,22 @@ void RandintRawKernel(const Context& ctx,
DenseTensor* out) {
DenseTensor tmp;
tmp.Resize(phi::make_ddim(shape.GetData()));
T* tmp_data = ctx.template HostAlloc<T>(&tmp);
T* tmp_data = dev_ctx.template HostAlloc<T>(&tmp);
out->ResizeAndAllocate(tmp.dims());
auto size = out->numel();
out->Resize(tmp.dims());
T* data = dev_ctx.template Alloc<T>(out);
std::shared_ptr<std::mt19937_64> engine;
if (seed) {
engine = std::make_shared<std::mt19937_64>();
engine->seed(seed);
} else {
engine = ctx.GetHostGenerator()->GetCPUEngine();
engine = dev_ctx.GetHostGenerator()->GetCPUEngine();
}
std::uniform_int_distribution<T> dist(low, high - 1);
auto data = out->data<T>();
for (int64_t i = 0; i < size; ++i) {
auto numel = out->numel();
for (int64_t i = 0; i < numel; ++i) {
tmp_data[i] = dist(*engine);
}
......@@ -57,18 +58,18 @@ void RandintRawKernel(const Context& ctx,
data,
tmp.place(),
tmp_data,
size * paddle::experimental::SizeOf(out->dtype()),
numel * paddle::experimental::SizeOf(out->dtype()),
0);
}
// GPU randint entry point without an explicit seed: forwards to
// RandintRawKernel with seed = 0 (use the context's host generator).
// Diff-view residue removed: stale old signature/call lines (the `ctx`
// spelling) duplicated and conflicted with the post-change version.
template <typename T, typename Context>
void RandintKernel(const Context& dev_ctx,
                   int low,
                   int high,
                   const ScalarArray& shape,
                   DataType dtype,
                   DenseTensor* out) {
  RandintRawKernel<T>(dev_ctx, low, high, shape, dtype, 0, out);
}
} // namespace phi
......
......@@ -12,41 +12,60 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/kernels/randperm_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void RandpermKernel(const Context& ctx,
int n,
DataType dtype,
DenseTensor* out) {
void RandpermRawKernel(
const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out) {
DenseTensor tmp;
tmp.Resize(phi::make_ddim({n}));
T* tmp_data = ctx.template HostAlloc<T>(&tmp);
T* tmp_data = dev_ctx.template HostAlloc<T>(&tmp);
auto gen_ptr = ctx.GetHostGenerator();
auto engine = gen_ptr->GetCPUEngine();
std::shared_ptr<std::mt19937_64> engine;
if (seed) {
engine = std::make_shared<std::mt19937_64>();
engine->seed(seed);
} else {
engine = dev_ctx.GetHostGenerator()->GetCPUEngine();
}
for (int i = 0; i < n; ++i) {
tmp_data[i] = static_cast<T>(i);
}
std::shuffle(tmp_data, tmp_data + n, *engine);
T* out_data = ctx.template Alloc<T>(out);
T* out_data = dev_ctx.template Alloc<T>(out);
auto size = out->numel() * paddle::experimental::SizeOf(out->dtype());
paddle::memory::Copy<phi::GPUPlace, phi::Place>(
out->place(), out_data, tmp.place(), tmp_data, size, 0);
}
// GPU randperm entry point without an explicit seed: delegates to the raw
// kernel with seed = 0, which selects the context's host generator.
template <typename T, typename Context>
void RandpermKernel(const Context& dev_ctx, int n, DataType dtype,
                    DenseTensor* out) {
  constexpr int kUseGlobalGenerator = 0;
  RandpermRawKernel<T>(dev_ctx, n, dtype, kUseGlobalGenerator, out);
}
} // namespace phi
PD_REGISTER_KERNEL(randperm_raw,
GPU,
ALL_LAYOUT,
phi::RandpermRawKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(randperm,
GPU,
ALL_LAYOUT,
......
......@@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/unbind_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unbind_kernel_impl.h"
#include "paddle/phi/kernels/unbind_kernel.h"
PD_REGISTER_KERNEL(unbind,
GPU,
......
......@@ -20,7 +20,7 @@
namespace phi {
template <typename T, typename Context>
void UnbindKernel(const Context& ctx,
void UnbindKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
std::vector<DenseTensor*> outs) {
......@@ -29,12 +29,12 @@ void UnbindKernel(const Context& ctx,
std::vector<const DenseTensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
ctx.template Alloc<T>(outs[j]);
dev_ctx.template Alloc<T>(outs[j]);
shape_refer.emplace_back(outs[j]);
}
phi::funcs::SplitFunctor<Context, T> functor;
functor(ctx, x, shape_refer, axis, &outs);
functor(dev_ctx, x, shape_refer, axis, &outs);
}
} // namespace phi
......@@ -20,7 +20,7 @@
namespace phi {
template <typename T, typename Context>
void RandintKernel(const Context& ctx,
void RandintKernel(const Context& dev_ctx,
int low,
int high,
const ScalarArray& shape,
......@@ -28,7 +28,7 @@ void RandintKernel(const Context& ctx,
DenseTensor* out);
template <typename T, typename Context>
void RandintRawKernel(const Context& ctx,
void RandintRawKernel(const Context& dev_ctx,
int low,
int high,
const ScalarArray& shape,
......
......@@ -20,7 +20,11 @@
namespace phi {
// Seed-taking variant: seed != 0 samples deterministically, seed == 0 uses
// the context's generator. Diff-view residue removed: a stale old
// `RandpermKernel(const Context& ctx,` line was interleaved mid-declaration.
template <typename T, typename Context>
void RandpermRawKernel(
    const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out);

// Convenience variant without a seed attribute (always uses the generator).
template <typename T, typename Context>
void RandpermKernel(const Context& dev_ctx,
                    int n,
                    DataType dtype,
                    DenseTensor* out);
......
......@@ -17,7 +17,12 @@
namespace phi {
// Maps the fluid `randperm` op onto a phi kernel signature.
// Bug fix: when the op carries a non-zero `seed` attribute it must map to
// "randperm_raw" — that is the registered kernel whose attribute list is
// {"n", "dtype", "seed"}; the plain "randperm" kernel takes no seed, so the
// original mapping named a kernel/attribute combination that does not exist.
KernelSignature RandpermOpArgumentMapping(const ArgumentMappingContext& ctx) {
  int seed = paddle::any_cast<int>(ctx.Attr("seed"));
  if (seed) {
    return KernelSignature("randperm_raw", {}, {"n", "dtype", "seed"}, {"Out"});
  } else {
    return KernelSignature("randperm", {}, {"n", "dtype"}, {"Out"});
  }
}
} // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册