Unverified commit 22f84122, authored by Leo Chen, committed by GitHub

[phi] refine code of randint, randperm, unbind kernel (#39909)

* refine randint kernel

* refine randperm kernel

* refine unbind kernel

* support op seed
Parent commit: 44da9b42
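The "support op seed" item is the thread running through the randint and randperm kernels below: when the op carries a non-zero seed attribute, the kernel seeds its own std::mt19937_64; otherwise it reuses the engine owned by the device context's generator. A minimal standalone sketch of that selection logic (SelectEngine and its fallback parameter are illustrative names, not Paddle APIs):

#include <memory>
#include <random>

// Hypothetical helper mirroring the pattern used in the kernels below:
// a non-zero seed gets its own freshly seeded engine, otherwise the
// engine supplied by the caller (e.g. the device context's generator) is reused.
std::shared_ptr<std::mt19937_64> SelectEngine(
    int seed, std::shared_ptr<std::mt19937_64> fallback) {
  if (seed) {
    auto engine = std::make_shared<std::mt19937_64>();
    engine->seed(seed);
    return engine;
  }
  return fallback;
}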
CPU randint kernel:

@@ -22,42 +22,43 @@
 namespace phi {
 
 template <typename T, typename Context>
-void RandintRawKernel(const Context& ctx,
+void RandintRawKernel(const Context& dev_ctx,
                       int low,
                       int high,
                       const ScalarArray& shape,
                       DataType dtype,
                       int seed,
                       DenseTensor* out) {
-  out->ResizeAndAllocate(phi::make_ddim(shape.GetData()));
-  auto size = out->numel();
+  out->Resize(phi::make_ddim(shape.GetData()));
+  T* data = dev_ctx.template Alloc<T>(out);
+
+  auto numel = out->numel();
   std::shared_ptr<std::mt19937_64> engine;
   if (seed) {
     engine = std::make_shared<std::mt19937_64>();
     engine->seed(seed);
   } else {
-    engine = ctx.GetGenerator()->GetCPUEngine();
+    engine = dev_ctx.GetGenerator()->GetCPUEngine();
   }
   std::uniform_int_distribution<T> dist(low, high - 1);
-  auto data = out->data<T>();
-  for (int64_t i = 0; i < size; ++i) {
+  for (int64_t i = 0; i < numel; ++i) {
     data[i] = dist(*engine);
   }
 }
 
 template <typename T, typename Context>
-void RandintKernel(const Context& ctx,
+void RandintKernel(const Context& dev_ctx,
                    int low,
                    int high,
                    const ScalarArray& shape,
                    DataType dtype,
                    DenseTensor* out) {
-  RandintRawKernel<T>(ctx, low, high, shape, dtype, 0, out);
+  RandintRawKernel<T>(dev_ctx, low, high, shape, dtype, 0, out);
 }
 
 }  // namespace phi
 
 PD_REGISTER_KERNEL(
     randint_raw, CPU, ALL_LAYOUT, phi::RandintRawKernel, int, int64_t) {}
 
 PD_REGISTER_KERNEL(randint, CPU, ALL_LAYOUT, phi::RandintKernel, int, int64_t) {
 }
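One detail worth calling out in the loop above: std::uniform_int_distribution treats both bounds as inclusive, so constructing it as dist(low, high - 1) yields values in the half-open range [low, high), the usual randint contract. A standalone illustration (plain C++, not Paddle code):

#include <iostream>
#include <random>

int main() {
  std::mt19937_64 engine(42);
  // Both bounds are inclusive, so this samples from [0, 10).
  std::uniform_int_distribution<int> dist(0, 10 - 1);
  for (int i = 0; i < 5; ++i) {
    std::cout << dist(engine) << ' ';
  }
  std::cout << '\n';
  return 0;
}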
CPU randperm kernel:

@@ -13,20 +13,23 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/randperm_kernel.h"
-#include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/core/device_context.h"
+
 #include "paddle/phi/core/kernel_registry.h"
 
 namespace phi {
 
 template <typename T, typename Context>
-void RandpermKernel(const Context& ctx,
-                    int n,
-                    DataType dtype,
-                    DenseTensor* out) {
-  T* out_data = ctx.template Alloc<T>(out);
-  auto gen_ptr = ctx.GetHostGenerator();
-  auto engine = gen_ptr->GetCPUEngine();
+void RandpermRawKernel(
+    const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out) {
+  T* out_data = dev_ctx.template Alloc<T>(out);
+
+  std::shared_ptr<std::mt19937_64> engine;
+  if (seed) {
+    engine = std::make_shared<std::mt19937_64>();
+    engine->seed(seed);
+  } else {
+    engine = dev_ctx.GetGenerator()->GetCPUEngine();
+  }
 
   for (int i = 0; i < n; ++i) {
     out_data[i] = static_cast<T>(i);
@@ -34,8 +37,25 @@ void RandpermKernel(const Context& ctx,
   std::shuffle(out_data, out_data + n, *engine);
 }
 
+template <typename T, typename Context>
+void RandpermKernel(const Context& dev_ctx,
+                    int n,
+                    DataType dtype,
+                    DenseTensor* out) {
+  RandpermRawKernel<T>(dev_ctx, n, dtype, 0, out);
+}
+
 }  // namespace phi
 
+PD_REGISTER_KERNEL(randperm_raw,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::RandpermRawKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
+
 PD_REGISTER_KERNEL(randperm,
                    CPU,
                    ALL_LAYOUT,
...
CPU unbind kernel:

@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/unbind_kernel.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/unbind_kernel_impl.h"
...
concat_and_split functor, CPU source:

@@ -12,21 +12,6 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <cmath>
-#include <memory>
-#include <vector>
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/framework/tensor_util.h"
-#include "paddle/fluid/platform/device_context.h"
-#include "paddle/fluid/platform/enforce.h"
-#include "paddle/phi/core/utils/data_type.h"
-#include "paddle/phi/backends/cpu/cpu_context.h"
-#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
 
 namespace phi {
...
concat_and_split functor, GPU source:

@@ -12,23 +12,11 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <cmath>
-#include <memory>
-#include <vector>
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/framework/tensor_util.h"
-#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
-#include "paddle/fluid/platform/device_context.h"
-#include "paddle/fluid/platform/enforce.h"
-#include "paddle/phi/backends/cpu/cpu_context.h"
-#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
+
+#include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
 
 namespace phi {
 namespace funcs {
...
concat_and_split functor, shared header:

@@ -13,20 +13,18 @@
 limitations under the License. */
 
 #pragma once
 
-#include <cmath>
-#include <memory>
 #include <vector>
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/framework/tensor_util.h"
-#include "paddle/fluid/platform/device_context.h"
-#include "paddle/fluid/platform/enforce.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
 #include "paddle/phi/core/utils/data_type.h"
+
+// See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/memory/memcpy.h"
 
 namespace phi {
 namespace funcs {
...
GPU randint kernel:

@@ -25,7 +25,7 @@
 namespace phi {
 
 template <typename T, typename Context>
-void RandintRawKernel(const Context& ctx,
+void RandintRawKernel(const Context& dev_ctx,
                       int low,
                       int high,
                       const ScalarArray& shape,
@@ -34,21 +34,22 @@ void RandintRawKernel(const Context& ctx,
                       DenseTensor* out) {
   DenseTensor tmp;
   tmp.Resize(phi::make_ddim(shape.GetData()));
-  T* tmp_data = ctx.template HostAlloc<T>(&tmp);
+  T* tmp_data = dev_ctx.template HostAlloc<T>(&tmp);
 
-  out->ResizeAndAllocate(tmp.dims());
-  auto size = out->numel();
+  out->Resize(tmp.dims());
+  T* data = dev_ctx.template Alloc<T>(out);
+
   std::shared_ptr<std::mt19937_64> engine;
   if (seed) {
     engine = std::make_shared<std::mt19937_64>();
     engine->seed(seed);
   } else {
-    engine = ctx.GetHostGenerator()->GetCPUEngine();
+    engine = dev_ctx.GetHostGenerator()->GetCPUEngine();
   }
   std::uniform_int_distribution<T> dist(low, high - 1);
-  auto data = out->data<T>();
-  for (int64_t i = 0; i < size; ++i) {
+  auto numel = out->numel();
+  for (int64_t i = 0; i < numel; ++i) {
     tmp_data[i] = dist(*engine);
   }
@@ -57,18 +58,18 @@ void RandintRawKernel(const Context& ctx,
       data,
       tmp.place(),
       tmp_data,
-      size * paddle::experimental::SizeOf(out->dtype()),
+      numel * paddle::experimental::SizeOf(out->dtype()),
       0);
 }
 
 template <typename T, typename Context>
-void RandintKernel(const Context& ctx,
+void RandintKernel(const Context& dev_ctx,
                    int low,
                    int high,
                    const ScalarArray& shape,
                    DataType dtype,
                    DenseTensor* out) {
-  RandintRawKernel<T>(ctx, low, high, shape, dtype, 0, out);
+  RandintRawKernel<T>(dev_ctx, low, high, shape, dtype, 0, out);
 }
 
 }  // namespace phi
...
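Note the structure of the GPU path above: the integers are drawn on the host (the engine is a host-side std::mt19937_64) into a staging tensor, then copied into the device allocation. A standalone CUDA sketch of the same host-generate-then-copy idea (illustrative only; FillRandintOnDevice is not a Paddle function):

#include <cuda_runtime.h>
#include <random>
#include <vector>

// Fill a device buffer with n random ints in [low, high) by generating on the
// host and copying, mirroring the staging-tensor approach used above.
cudaError_t FillRandintOnDevice(int* d_out, int n, int low, int high,
                                unsigned long long seed) {
  std::vector<int> host(n);
  std::mt19937_64 engine(seed);
  std::uniform_int_distribution<int> dist(low, high - 1);
  for (int i = 0; i < n; ++i) {
    host[i] = dist(engine);
  }
  return cudaMemcpy(d_out, host.data(), n * sizeof(int),
                    cudaMemcpyHostToDevice);
}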
GPU randperm kernel:

@@ -12,41 +12,60 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/core/device_context.h"
 #include "paddle/phi/kernels/randperm_kernel.h"
+
+#include "paddle/phi/core/kernel_registry.h"
 
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/memory/memcpy.h"
-#include "paddle/phi/core/kernel_registry.h"
 
 namespace phi {
 
 template <typename T, typename Context>
-void RandpermKernel(const Context& ctx,
-                    int n,
-                    DataType dtype,
-                    DenseTensor* out) {
+void RandpermRawKernel(
+    const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out) {
   DenseTensor tmp;
   tmp.Resize(phi::make_ddim({n}));
-  T* tmp_data = ctx.template HostAlloc<T>(&tmp);
+  T* tmp_data = dev_ctx.template HostAlloc<T>(&tmp);
 
-  auto gen_ptr = ctx.GetHostGenerator();
-  auto engine = gen_ptr->GetCPUEngine();
+  std::shared_ptr<std::mt19937_64> engine;
+  if (seed) {
+    engine = std::make_shared<std::mt19937_64>();
+    engine->seed(seed);
+  } else {
+    engine = dev_ctx.GetHostGenerator()->GetCPUEngine();
+  }
 
   for (int i = 0; i < n; ++i) {
     tmp_data[i] = static_cast<T>(i);
   }
   std::shuffle(tmp_data, tmp_data + n, *engine);
 
-  T* out_data = ctx.template Alloc<T>(out);
+  T* out_data = dev_ctx.template Alloc<T>(out);
   auto size = out->numel() * paddle::experimental::SizeOf(out->dtype());
   paddle::memory::Copy<phi::GPUPlace, phi::Place>(
       out->place(), out_data, tmp.place(), tmp_data, size, 0);
 }
 
+template <typename T, typename Context>
+void RandpermKernel(const Context& dev_ctx,
+                    int n,
+                    DataType dtype,
+                    DenseTensor* out) {
+  RandpermRawKernel<T>(dev_ctx, n, dtype, 0, out);
+}
+
 }  // namespace phi
 
+PD_REGISTER_KERNEL(randperm_raw,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::RandpermRawKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
+
 PD_REGISTER_KERNEL(randperm,
                    GPU,
                    ALL_LAYOUT,
...
GPU unbind kernel:

@@ -12,9 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/phi/kernels/unbind_kernel.h"
+
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/unbind_kernel_impl.h"
-#include "paddle/phi/kernels/unbind_kernel.h"
 
 PD_REGISTER_KERNEL(unbind,
                    GPU,
...
unbind kernel implementation header:

@@ -20,7 +20,7 @@
 namespace phi {
 
 template <typename T, typename Context>
-void UnbindKernel(const Context& ctx,
+void UnbindKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   int axis,
                   std::vector<DenseTensor*> outs) {
@@ -29,12 +29,12 @@ void UnbindKernel(const Context& ctx,
   std::vector<const DenseTensor*> shape_refer;
   for (size_t j = 0; j < outs.size(); ++j) {
-    ctx.template Alloc<T>(outs[j]);
+    dev_ctx.template Alloc<T>(outs[j]);
     shape_refer.emplace_back(outs[j]);
   }
 
   phi::funcs::SplitFunctor<Context, T> functor;
-  functor(ctx, x, shape_refer, axis, &outs);
+  functor(dev_ctx, x, shape_refer, axis, &outs);
 }
 
 }  // namespace phi
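For context, unbind splits x along axis into outs.size() tensors, one slice per index of that dimension; the kernel above only allocates the outputs and hands the actual copying to phi::funcs::SplitFunctor. A standalone sketch of the semantics on a row-major 2-D buffer (illustration only, not the Paddle implementation):

#include <cassert>
#include <vector>

// Unbind a row-major (rows x cols) matrix along axis 0:
// the result is `rows` vectors, each holding one row of `cols` elements.
std::vector<std::vector<float>> UnbindAxis0(const std::vector<float>& data,
                                            int rows, int cols) {
  assert(static_cast<int>(data.size()) == rows * cols);
  std::vector<std::vector<float>> outs(rows);
  for (int r = 0; r < rows; ++r) {
    outs[r].assign(data.begin() + r * cols, data.begin() + (r + 1) * cols);
  }
  return outs;
}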
randint kernel declarations:

@@ -20,7 +20,7 @@
 namespace phi {
 
 template <typename T, typename Context>
-void RandintKernel(const Context& ctx,
+void RandintKernel(const Context& dev_ctx,
                    int low,
                    int high,
                    const ScalarArray& shape,
@@ -28,7 +28,7 @@ void RandintKernel(const Context& ctx,
                    DenseTensor* out);
 
 template <typename T, typename Context>
-void RandintRawKernel(const Context& ctx,
+void RandintRawKernel(const Context& dev_ctx,
                       int low,
                       int high,
                       const ScalarArray& shape,
...
randperm kernel declarations:

@@ -20,7 +20,11 @@
 namespace phi {
 
 template <typename T, typename Context>
-void RandpermKernel(const Context& ctx,
+void RandpermRawKernel(
+    const Context& dev_ctx, int n, DataType dtype, int seed, DenseTensor* out);
+
+template <typename T, typename Context>
+void RandpermKernel(const Context& dev_ctx,
                     int n,
                     DataType dtype,
                     DenseTensor* out);
...
randperm op argument mapping:

@@ -17,7 +17,12 @@
 namespace phi {
 
 KernelSignature RandpermOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  return KernelSignature("randperm", {}, {"n", "dtype"}, {"Out"});
+  int seed = paddle::any_cast<int>(ctx.Attr("seed"));
+  if (seed) {
+    return KernelSignature("randperm", {}, {"n", "dtype", "seed"}, {"Out"});
+  } else {
+    return KernelSignature("randperm", {}, {"n", "dtype"}, {"Out"});
+  }
 }
 
 }  // namespace phi
...