Unverified commit 30708028, authored by Sonder and committed by GitHub

[Phi] move shuffle_batch to phi (#56547)

* move shuffle_batch to phi

* remove useless codes

* add test_shuffle_batch_op to STATIC_BUILD_TESTS

* move shuffle_batch_kernel.cc to cpu folder

* move shuffle_batch_grad to phi

* rm shuffle_batch_op.h

* update year in file headers
Parent 7995a389
@@ -12,12 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/shuffle_batch_op.h"
#include <atomic>
#include <cstring>
#include <ctime>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/phi/core/mixed_vector.h"
namespace paddle {
namespace operators {
@@ -158,20 +170,3 @@ REGISTER_OPERATOR(shuffle_batch,
ops::ShuffleBatchGradOpMaker<paddle::framework::OpDesc>,
ops::ShuffleBatchGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(shuffle_batch_grad, ops::ShuffleBatchOpGrad);
PD_REGISTER_STRUCT_KERNEL(shuffle_batch,
CPU,
ALL_LAYOUT,
ops::ShuffleBatchKernel,
float,
double,
int32_t,
int64_t) {}
PD_REGISTER_STRUCT_KERNEL(shuffle_batch_grad,
CPU,
ALL_LAYOUT,
ops::ShuffleBatchGradKernel,
float,
double,
int32_t,
int64_t) {}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifndef _MSC_VER
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/shuffle.h>
#endif
#include "paddle/fluid/operators/shuffle_batch_op.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
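// CacheAllocator adapts Paddle's pooled allocator to the allocator interface
// expected by thrust execution policies (value_type / allocate / deallocate),
// so temporary storage requested by thrust algorithms is served from Paddle's
// memory pool instead of raw cudaMalloc/cudaFree.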
struct CacheAllocator {
typedef char value_type;
explicit CacheAllocator(platform::Place place) {
VLOG(2) << "construct allocator";
place_ = place;
}
  ~CacheAllocator() { VLOG(2) << "destroy allocator"; }
char *allocate(std::ptrdiff_t num_bytes) {
VLOG(2) << "allocate " << num_bytes << " bytes";
auto storage = memory::AllocShared(place_, num_bytes);
char *ptr = reinterpret_cast<char *>(storage->ptr());
busy_allocation_.emplace(std::make_pair(ptr, storage));
return ptr;
}
void deallocate(char *ptr, size_t) {
VLOG(2) << "deallocate ";
allocation_map_type::iterator iter = busy_allocation_.find(ptr);
CHECK(iter != busy_allocation_.end());
busy_allocation_.erase(iter);
}
private:
typedef std::unordered_map<char *, std::shared_ptr<phi::Allocation>>
allocation_map_type;
allocation_map_type busy_allocation_;
platform::Place place_;
};
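// ReorderFunctor permutes rows of width `stride` according to shuffle_idx:
// in the forward pass (kIsForward == true) it gathers y[i] = x[shuffle_idx[i]]
// row-wise; in the backward pass it scatters y[shuffle_idx[i]] = x[i], which
// applies the inverse permutation to the gradient.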
template <typename T, bool kIsForward>
struct ReorderFunctor {
ReorderFunctor(const T *x, const int64_t *shuffle_idx, T *y, int64_t stride)
: x_(x), shuffle_idx_(shuffle_idx), y_(y), stride_(stride) {}
HOSTDEVICE void operator()(int64_t idx) {
auto reorder_idx = shuffle_idx_[idx / stride_] * stride_ + idx % stride_;
if (kIsForward) {
y_[idx] = x_[reorder_idx];
} else {
y_[reorder_idx] = x_[idx];
}
}
private:
const T *x_;
const int64_t *shuffle_idx_;
T *y_;
int64_t stride_;
};
template <typename T, typename DeviceContext>
class ShuffleBatchCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
#ifdef _MSC_VER
PADDLE_THROW(platform::errors::Unimplemented(
"GPU shuffle_batch is not supported on Windows yet"));
#else
auto *x = ctx.Input<phi::DenseTensor>("X");
auto *seed = ctx.Input<phi::DenseTensor>("Seed");
auto *out = ctx.Output<phi::DenseTensor>("Out");
auto *shuffleidx = ctx.Output<phi::DenseTensor>("ShuffleIdx");
auto *seed_out = ctx.Output<phi::DenseTensor>("SeedOut");
int64_t x_embed_size = x->dims()[x->dims().size() - 1];
int64_t elem_size = 1;
for (int i = 0; i < x->dims().size() - 1; i++) {
elem_size *= x->dims()[i];
}
shuffleidx->Resize(phi::make_ddim({elem_size}));
int64_t seed_int = 0;
if (seed->IsInitialized()) {
const auto &seed_place = seed->place();
if (platform::is_gpu_place(seed_place)) {
        // NOTE: We have overridden GetKernelTypeForVar, so seed_place would
        // not be CUDAPlace in practice. This case would only happen in the
        // Python op_test framework.
phi::DenseTensor tmp_tensor;
framework::TensorCopySync(*seed, platform::CPUPlace(), &tmp_tensor);
seed_int = *(tmp_tensor.data<int64_t>());
} else {
seed_int = *(seed->data<int64_t>());
}
} else {
seed_int = ctx.Attr<int>("startup_seed");
}
auto *shuffleidx_data = shuffleidx->mutable_data<int64_t>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
#ifdef PADDLE_WITH_CUDA
CacheAllocator allocator(ctx.GetPlace());
const auto &exec_policy = thrust::cuda::par(allocator).on(dev_ctx.stream());
#else
const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream());
#endif
thrust::random::default_random_engine engine(seed_int);
thrust::counting_iterator<int64_t> cnt_iter(0);
thrust::shuffle_copy(exec_policy,
cnt_iter,
cnt_iter + elem_size,
thrust::device_pointer_cast(shuffleidx_data),
engine);
// TODO(zengjinle): for small data, direct cudaMemcpy may be better
auto *x_data = x->data<T>();
auto *out_data = out->mutable_data<T>(ctx.GetPlace());
ReorderFunctor<T, true> functor(
x_data, shuffleidx_data, out_data, x_embed_size);
platform::ForRange<phi::GPUContext> for_range(dev_ctx,
elem_size * x_embed_size);
for_range(functor);
auto *seed_out_data = seed_out->mutable_data<int64_t>(phi::make_ddim({1}),
platform::CPUPlace());
*seed_out_data = engine();
#endif
}
};
template <typename T, typename DeviceContext>
class ShuffleBatchGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
#ifdef _MSC_VER
PADDLE_THROW(platform::errors::Unimplemented(
"GPU shuffle_batch_grad is not supported on Windows yet"));
#else
const auto *out_grad =
ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
const auto *shuffleidx = ctx.Input<phi::DenseTensor>("ShuffleIdx");
auto *x_grad = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
const auto *out_grad_data = out_grad->data<T>();
const auto *shuffleidx_data = shuffleidx->data<int64_t>();
auto *x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
auto x_embed_size = x_grad->dims()[x_grad->dims().size() - 1];
ReorderFunctor<T, false> functor(
out_grad_data, shuffleidx_data, x_grad_data, x_embed_size);
auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
// TODO(zengjinle): for small data, direct cudaMemcpy may be better
platform::ForRange<phi::GPUContext> for_range(dev_ctx, x_grad->numel());
for_range(functor);
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
PD_REGISTER_STRUCT_KERNEL(shuffle_batch,
GPU,
ALL_LAYOUT,
ops::ShuffleBatchCUDAKernel,
float,
double,
int32_t,
int64_t) {}
PD_REGISTER_STRUCT_KERNEL(shuffle_batch_grad,
GPU,
ALL_LAYOUT,
ops::ShuffleBatchGradCUDAKernel,
float,
double,
int32_t,
int64_t) {}
#endif
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <cstring>
#include <ctime>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/phi/core/mixed_vector.h"
namespace paddle {
namespace operators {
template <typename T>
using Vector = phi::Vector<T>;
template <typename T, typename DeviceContext>
class ShuffleBatchKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *x = context.Input<phi::DenseTensor>("X");
auto *seed = context.Input<phi::DenseTensor>("Seed");
auto *out = context.Output<phi::DenseTensor>("Out");
auto *shuffleidx = context.Output<phi::DenseTensor>("ShuffleIdx");
auto *seed_out = context.Output<phi::DenseTensor>("SeedOut");
auto x_embed_size = x->dims()[x->dims().size() - 1];
auto elem_size = 1;
for (auto i = 0; i < x->dims().size() - 1; i++) elem_size *= x->dims()[i];
std::vector<int64_t> idx_vec; // record shuffled order
idx_vec.reserve(elem_size);
for (auto i = 0; i < elem_size; i++) {
idx_vec.push_back(i);
}
int64_t seed_int = 0;
if (seed->IsInitialized()) {
seed_int = *seed->data<int64_t>();
} else {
seed_int = context.Attr<int>("startup_seed");
}
std::default_random_engine engine;
engine.seed(seed_int);
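    // custom_random_shuffle builds a random permutation of idx_vec out of
    // random cycles of length >= 2, so every row is mapped to a different
    // position; plain std::shuffle (commented out below) may leave rows in
    // place.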
auto custom_random_shuffle = [&idx_vec]() {
std::random_device rnd;
int64_t seed_tmp = rnd();
std::default_random_engine rng(seed_tmp);
const int n = idx_vec.size();
std::vector<int> v(n);
std::iota(v.begin(), v.end(), 0);
std::vector<bool> visit(n, false);
while (!v.empty()) {
std::shuffle(v.begin(), v.end(), rng);
int tmp = v.back();
v.pop_back();
if (v.empty()) {
std::uniform_int_distribution<int> distr(0, n - 2);
idx_vec[tmp] = tmp;
std::swap(idx_vec[tmp], idx_vec[(distr(rng) + tmp + 1) % n]);
return;
}
visit[tmp] = true;
std::shuffle(v.begin(), v.end(), rng);
int curr = v.back();
v.pop_back();
v.push_back(tmp);
idx_vec[tmp] = curr;
while (!visit[curr]) {
visit[curr] = true;
std::shuffle(v.begin(), v.end(), rng);
idx_vec[curr] = v.back();
v.pop_back();
curr = idx_vec[curr];
}
}
};
custom_random_shuffle();
    // use custom_random_shuffle instead of std::shuffle
    // std::shuffle(idx_vec.begin(), idx_vec.end(), engine);
    // ShuffleIdx records the shuffle order
shuffleidx->Resize(phi::make_ddim({(int64_t)idx_vec.size()}));
auto *shuffleidx_data =
shuffleidx->mutable_data<int64_t>(context.GetPlace());
for (size_t i = 0; i < idx_vec.size(); i++) {
shuffleidx_data[i] = idx_vec[i];
}
// copy data according to idx_vec
auto *x_data = x->data<T>();
auto *out_data = out->mutable_data<T>(context.GetPlace());
for (auto i = 0; i < elem_size; i++) {
memcpy(out_data + idx_vec[i] * x_embed_size,
x_data + i * x_embed_size,
x_embed_size * sizeof(T));
}
// set new seed
*seed_out->mutable_data<int64_t>(phi::make_ddim({1}), context.GetPlace()) =
engine();
}
};
template <typename T, typename DeviceContext>
class ShuffleBatchGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *out_grad =
context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto *shuffleidx = context.Input<phi::DenseTensor>("ShuffleIdx");
auto *x_grad =
context.Output<phi::DenseTensor>(framework::GradVarName("X"));
auto embed_size = out_grad->dims()[out_grad->dims().size() - 1];
auto elem_size = 1;
for (auto i = 0; i < out_grad->dims().size() - 1; i++)
elem_size *= out_grad->dims()[i];
std::vector<int> idx_vec_grad(elem_size);
auto *shuffleidx_data = shuffleidx->data<int64_t>();
for (size_t i = 0; i < idx_vec_grad.size(); i++) {
idx_vec_grad[shuffleidx_data[i]] = i;
}
// copy data according to idx_vec_grad
auto *out_grad_data = out_grad->data<T>();
auto *x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
for (auto i = 0; i < elem_size; i++) {
memcpy(x_grad_data + idx_vec_grad[i] * embed_size,
out_grad_data + i * embed_size,
embed_size * sizeof(T));
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/shuffle_batch_kernel.h"
namespace phi {
template <typename T, typename Context>
void ShuffleBatchGradKernel(const Context& dev_ctx,
const DenseTensor& shuffleidx,
const DenseTensor& out_grad,
int startup_seed,
DenseTensor* x_grad) {
auto embed_size = out_grad.dims()[out_grad.dims().size() - 1];
auto elem_size = 1;
for (auto i = 0; i < out_grad.dims().size() - 1; i++)
elem_size *= out_grad.dims()[i];
std::vector<int> idx_vec_grad(elem_size);
auto* shuffleidx_data = shuffleidx.data<int64_t>();
for (size_t i = 0; i < idx_vec_grad.size(); i++) {
idx_vec_grad[shuffleidx_data[i]] = i;
}
// copy data according to idx_vec_grad
auto* out_grad_data = out_grad.data<T>();
auto* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
for (auto i = 0; i < elem_size; i++) {
memcpy(x_grad_data + idx_vec_grad[i] * embed_size,
out_grad_data + i * embed_size,
embed_size * sizeof(T));
}
}
} // namespace phi
PD_REGISTER_KERNEL(shuffle_batch_grad,
CPU,
ALL_LAYOUT,
phi::ShuffleBatchGradKernel,
float,
double,
int32_t,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/shuffle_batch_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
namespace phi {
template <typename T, typename Context>
void ShuffleBatchKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& seed,
int startup_seed,
DenseTensor* out,
DenseTensor* shuffleidx,
DenseTensor* seed_out) {
auto x_embed_size = x.dims()[x.dims().size() - 1];
auto elem_size = 1;
for (auto i = 0; i < x.dims().size() - 1; i++) elem_size *= x.dims()[i];
std::vector<int64_t> idx_vec; // record shuffled order
idx_vec.reserve(elem_size);
for (auto i = 0; i < elem_size; i++) {
idx_vec.push_back(i);
}
int64_t seed_int = 0;
if (seed.initialized()) {
seed_int = *seed.data<int64_t>();
} else {
seed_int = startup_seed;
}
std::default_random_engine engine;
engine.seed(seed_int);
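  // As in the fluid kernel above: build a random permutation from cycles of
  // length >= 2 so that no row keeps its original position (std::shuffle
  // could leave rows in place).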
auto custom_random_shuffle = [&idx_vec]() {
std::random_device rnd;
int64_t seed_tmp = rnd();
std::default_random_engine rng(seed_tmp);
const int n = idx_vec.size();
std::vector<int> v(n);
std::iota(v.begin(), v.end(), 0);
std::vector<bool> visit(n, false);
while (!v.empty()) {
std::shuffle(v.begin(), v.end(), rng);
int tmp = v.back();
v.pop_back();
if (v.empty()) {
std::uniform_int_distribution<int> distr(0, n - 2);
idx_vec[tmp] = tmp;
std::swap(idx_vec[tmp], idx_vec[(distr(rng) + tmp + 1) % n]);
return;
}
visit[tmp] = true;
std::shuffle(v.begin(), v.end(), rng);
int curr = v.back();
v.pop_back();
v.push_back(tmp);
idx_vec[tmp] = curr;
while (!visit[curr]) {
visit[curr] = true;
std::shuffle(v.begin(), v.end(), rng);
idx_vec[curr] = v.back();
v.pop_back();
curr = idx_vec[curr];
}
}
};
custom_random_shuffle();
  // use custom_random_shuffle instead of std::shuffle
  // std::shuffle(idx_vec.begin(), idx_vec.end(), engine);
  // ShuffleIdx records the shuffle order
shuffleidx->Resize(phi::make_ddim({(int64_t)idx_vec.size()}));
auto* shuffleidx_data = dev_ctx.template HostAlloc<int64_t>(shuffleidx);
for (size_t i = 0; i < idx_vec.size(); i++) {
shuffleidx_data[i] = idx_vec[i];
}
// copy data according to idx_vec
auto* x_data = x.data<T>();
auto* out_data = dev_ctx.template HostAlloc<T>(out);
for (auto i = 0; i < elem_size; i++) {
memcpy(out_data + idx_vec[i] * x_embed_size,
x_data + i * x_embed_size,
x_embed_size * sizeof(T));
}
// set new seed
seed_out->Resize(phi::make_ddim({1}));
auto* seed_out_data = dev_ctx.template HostAlloc<int64_t>(seed_out);
*seed_out_data = engine();
}
} // namespace phi
PD_REGISTER_KERNEL(shuffle_batch,
CPU,
ALL_LAYOUT,
phi::ShuffleBatchKernel,
float,
double,
int32_t,
int64_t) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(2).SetDataType(phi::DataType::INT64);
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifndef _MSC_VER
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/shuffle.h>
#endif
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/gpu/shuffle_batch_utils.h"
#include "paddle/phi/kernels/shuffle_batch_grad_kernel.h"
namespace phi {
template <typename T, typename Context>
void ShuffleBatchGradKernel(const Context& dev_ctx,
const DenseTensor& shuffleidx,
const DenseTensor& out_grad,
int startup_seed,
DenseTensor* x_grad) {
#ifdef _MSC_VER
PADDLE_THROW(phi::errors::Unimplemented(
"GPU shuffle_batch_grad is not supported on Windows yet"));
#else
const auto* out_grad_data = out_grad.data<T>();
const auto* shuffleidx_data = shuffleidx.data<int64_t>();
auto* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
auto x_embed_size = x_grad->dims()[x_grad->dims().size() - 1];
ReorderFunctor<T, false> functor(
out_grad_data, shuffleidx_data, x_grad_data, x_embed_size);
// TODO(zengjinle): for small data, direct cudaMemcpy may be better
phi::funcs::ForRange<phi::GPUContext> for_range(dev_ctx, x_grad->numel());
for_range(functor);
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(shuffle_batch_grad,
GPU,
ALL_LAYOUT,
phi::ShuffleBatchGradKernel,
float,
double,
int32_t,
int64_t) {}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifndef _MSC_VER
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/shuffle.h>
#endif
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/gpu/shuffle_batch_utils.h"
#include "paddle/phi/kernels/shuffle_batch_kernel.h"
namespace phi {
template <typename T, typename Context>
void ShuffleBatchKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& seed,
int startup_seed,
DenseTensor* out,
DenseTensor* shuffleidx,
DenseTensor* seed_out) {
#ifdef _MSC_VER
PADDLE_THROW(phi::errors::Unimplemented(
"GPU shuffle_batch is not supported on Windows yet"));
#else
int64_t x_embed_size = x.dims()[x.dims().size() - 1];
int64_t elem_size = 1;
for (int i = 0; i < x.dims().size() - 1; i++) {
elem_size *= x.dims()[i];
}
shuffleidx->Resize(phi::make_ddim({elem_size}));
int64_t seed_int = 0;
if (seed.initialized()) {
const auto& seed_place = seed.place().GetType();
bool is_gpu_place = seed_place == phi::AllocationType::GPU;
if (is_gpu_place) {
      // NOTE: We have overridden GetKernelTypeForVar, so seed_place would
      // not be CUDAPlace in practice. This case would only happen in the
      // Python op_test framework.
phi::DenseTensor tmp_tensor;
phi::Copy(dev_ctx, seed, phi::CPUPlace(), false, &tmp_tensor);
seed_int = *(tmp_tensor.data<int64_t>());
} else {
seed_int = *(seed.data<int64_t>());
}
} else {
seed_int = startup_seed;
}
auto* shuffleidx_data = dev_ctx.template Alloc<int64_t>(shuffleidx);
#ifdef PADDLE_WITH_CUDA
CacheAllocator allocator(dev_ctx.GetPlace());
const auto& exec_policy = thrust::cuda::par(allocator).on(dev_ctx.stream());
#else
const auto& exec_policy = thrust::hip::par.on(dev_ctx.stream());
#endif
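  // Generate a random permutation of [0, elem_size) directly on the device:
  // shuffle_copy permutes the counting-iterator range with the seeded engine
  // and writes the result into shuffleidx; on CUDA the temporary storage it
  // needs comes from the CacheAllocator above.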
thrust::random::default_random_engine engine(seed_int);
thrust::counting_iterator<int64_t> cnt_iter(0);
thrust::shuffle_copy(exec_policy,
cnt_iter,
cnt_iter + elem_size,
thrust::device_pointer_cast(shuffleidx_data),
engine);
// TODO(zengjinle): for small data, direct cudaMemcpy may be better
auto* x_data = x.data<T>();
auto* out_data = dev_ctx.template Alloc<T>(out);
ReorderFunctor<T, true> functor(
x_data, shuffleidx_data, out_data, x_embed_size);
phi::funcs::ForRange<phi::GPUContext> for_range(dev_ctx,
elem_size * x_embed_size);
for_range(functor);
seed_out->Resize(phi::make_ddim({1}));
auto* seed_out_data = dev_ctx.template HostAlloc<int64_t>(seed_out);
*seed_out_data = engine();
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(shuffle_batch,
GPU,
ALL_LAYOUT,
phi::ShuffleBatchKernel,
float,
double,
int32_t,
int64_t) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(2).SetDataType(phi::DataType::INT64);
}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
namespace phi {
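// CacheAllocator adapts phi's pooled allocator (memory_utils::AllocShared) to
// the allocator interface used by thrust execution policies, so thrust's
// temporary buffers are served from the memory pool and released back to it
// in deallocate().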
struct CacheAllocator {
typedef char value_type;
explicit CacheAllocator(phi::Place place) {
VLOG(2) << "construct allocator";
place_ = place;
}
  ~CacheAllocator() { VLOG(2) << "destroy allocator"; }
char* allocate(std::ptrdiff_t num_bytes) {
VLOG(2) << "allocate " << num_bytes << " bytes";
auto storage = memory_utils::AllocShared(place_, num_bytes);
char* ptr = reinterpret_cast<char*>(storage->ptr());
busy_allocation_.emplace(std::make_pair(ptr, storage));
return ptr;
}
void deallocate(char* ptr, size_t) {
VLOG(2) << "deallocate ";
allocation_map_type::iterator iter = busy_allocation_.find(ptr);
CHECK(iter != busy_allocation_.end());
busy_allocation_.erase(iter);
}
private:
typedef std::unordered_map<char*, std::shared_ptr<phi::Allocation>>
allocation_map_type;
allocation_map_type busy_allocation_;
phi::Place place_;
};
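// Row-wise gather/scatter by shuffle_idx with row width `stride`: the forward
// pass computes y[i] = x[shuffle_idx[i]], the backward pass computes
// y[shuffle_idx[i]] = x[i] (the inverse permutation), which is what the
// gradient kernel needs.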
template <typename T, bool kIsForward>
struct ReorderFunctor {
ReorderFunctor(const T* x, const int64_t* shuffle_idx, T* y, int64_t stride)
: x_(x), shuffle_idx_(shuffle_idx), y_(y), stride_(stride) {}
HOSTDEVICE void operator()(int64_t idx) {
auto reorder_idx = shuffle_idx_[idx / stride_] * stride_ + idx % stride_;
if (kIsForward) {
y_[idx] = x_[reorder_idx];
} else {
y_[reorder_idx] = x_[idx];
}
}
private:
const T* x_;
const int64_t* shuffle_idx_;
T* y_;
int64_t stride_;
};
} // namespace phi
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ShuffleBatchGradKernel(const Context& dev_ctx,
const DenseTensor& shuffleidx,
const DenseTensor& out_grad,
int startup_seed,
DenseTensor* x_grad);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void ShuffleBatchKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& seed,
int startup_seed,
DenseTensor* out,
DenseTensor* shuffleidx,
DenseTensor* seed_out);
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
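// Argument mappings: translate the named inputs, attributes, and outputs of
// the fluid shuffle_batch / shuffle_batch_grad ops into the phi kernel
// signatures registered above.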
KernelSignature ShuffleBatchOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("shuffle_batch",
{"X", "Seed"},
{"startup_seed"},
{"Out", "ShuffleIdx", "SeedOut"});
}
KernelSignature ShuffleBatchGradOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("shuffle_batch_grad",
{"ShuffleIdx", "Out@GRAD"},
{"startup_seed"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(shuffle_batch, phi::ShuffleBatchOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(shuffle_batch_grad,
phi::ShuffleBatchGradOpArgumentMapping);
@@ -1306,6 +1306,7 @@ set(STATIC_BUILD_TESTS
test_segment_ops
test_sparse_momentum_op
test_sgd_op_bf16
test_shuffle_batch_op
test_softmax_mask_fuse_upper_triangle_op
test_sparse_conv_op
test_sparse_norm_op
@@ -76,7 +76,7 @@ class TestShuffleBatchOpBase(OpTest):
return np.reshape(np.array(arr_list), shape)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_dygraph=False)
class TestShuffleBatchOp2(TestShuffleBatchOpBase):