From 30708028936fb04cbc15dc4ddcd2e13f8958d95b Mon Sep 17 00:00:00 2001 From: Sonder <55493212+AndSonder@users.noreply.github.com> Date: Mon, 28 Aug 2023 09:40:44 +0800 Subject: [PATCH] [Phi] move shuffle_batch to phi (#56547) * move shuffle_batch to phi * remove useless codes * add test_shuffle_batch_op to STATIC_BUILD_TESTS * move shuffle_batch_kernel.cc to cpu folder * move shuffle_batch_grad to phi * rm shuffle_batch_op.h * change year at file head --- paddle/fluid/operators/shuffle_batch_op.cc | 35 ++- paddle/fluid/operators/shuffle_batch_op.cu | 200 ------------------ paddle/fluid/operators/shuffle_batch_op.h | 157 -------------- .../kernels/cpu/shuffle_batch_grad_kernel.cc | 60 ++++++ .../phi/kernels/cpu/shuffle_batch_kernel.cc | 117 ++++++++++ .../kernels/gpu/shuffle_batch_grad_kernel.cu | 67 ++++++ .../phi/kernels/gpu/shuffle_batch_kernel.cu | 114 ++++++++++ paddle/phi/kernels/gpu/shuffle_batch_utils.h | 76 +++++++ .../phi/kernels/shuffle_batch_grad_kernel.h | 28 +++ paddle/phi/kernels/shuffle_batch_kernel.h | 30 +++ paddle/phi/ops/compat/shuffle_batch_sig.cc | 40 ++++ test/legacy_test/CMakeLists.txt | 1 + test/legacy_test/test_shuffle_batch_op.py | 2 +- 13 files changed, 549 insertions(+), 378 deletions(-) delete mode 100644 paddle/fluid/operators/shuffle_batch_op.cu delete mode 100644 paddle/fluid/operators/shuffle_batch_op.h create mode 100644 paddle/phi/kernels/cpu/shuffle_batch_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/shuffle_batch_kernel.cc create mode 100644 paddle/phi/kernels/gpu/shuffle_batch_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/shuffle_batch_kernel.cu create mode 100644 paddle/phi/kernels/gpu/shuffle_batch_utils.h create mode 100644 paddle/phi/kernels/shuffle_batch_grad_kernel.h create mode 100644 paddle/phi/kernels/shuffle_batch_kernel.h create mode 100644 paddle/phi/ops/compat/shuffle_batch_sig.cc diff --git a/paddle/fluid/operators/shuffle_batch_op.cc b/paddle/fluid/operators/shuffle_batch_op.cc index 
143f94477b3..61b3f30b390 100644 --- a/paddle/fluid/operators/shuffle_batch_op.cc +++ b/paddle/fluid/operators/shuffle_batch_op.cc @@ -12,12 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/shuffle_batch_op.h" - +#include +#include +#include #include - +#include +#include +#include +#include + +#include "glog/logging.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/no_need_buffer_vars_inference.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/var_type_inference.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/timer.h" +#include "paddle/phi/core/mixed_vector.h" namespace paddle { namespace operators { @@ -158,20 +170,3 @@ REGISTER_OPERATOR(shuffle_batch, ops::ShuffleBatchGradOpMaker, ops::ShuffleBatchGradOpMaker); REGISTER_OPERATOR(shuffle_batch_grad, ops::ShuffleBatchOpGrad); - -PD_REGISTER_STRUCT_KERNEL(shuffle_batch, - CPU, - ALL_LAYOUT, - ops::ShuffleBatchKernel, - float, - double, - int32_t, - int64_t) {} -PD_REGISTER_STRUCT_KERNEL(shuffle_batch_grad, - CPU, - ALL_LAYOUT, - ops::ShuffleBatchGradKernel, - float, - double, - int32_t, - int64_t) {} diff --git a/paddle/fluid/operators/shuffle_batch_op.cu b/paddle/fluid/operators/shuffle_batch_op.cu deleted file mode 100644 index 5069cf1e512..00000000000 --- a/paddle/fluid/operators/shuffle_batch_op.cu +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - -#ifndef _MSC_VER -#include -#include -#include -#include -#endif - -#include "paddle/fluid/operators/shuffle_batch_op.h" -#include "paddle/fluid/platform/for_range.h" - -namespace paddle { -namespace operators { - -struct CacheAllocator { - typedef char value_type; - explicit CacheAllocator(platform::Place place) { - VLOG(2) << "construct allocator"; - place_ = place; - } - - ~CacheAllocator() { VLOG(2) << "destory allocator"; } - - char *allocate(std::ptrdiff_t num_bytes) { - VLOG(2) << "allocate " << num_bytes << " bytes"; - auto storage = memory::AllocShared(place_, num_bytes); - char *ptr = reinterpret_cast(storage->ptr()); - busy_allocation_.emplace(std::make_pair(ptr, storage)); - return ptr; - } - - void deallocate(char *ptr, size_t) { - VLOG(2) << "deallocate "; - allocation_map_type::iterator iter = busy_allocation_.find(ptr); - CHECK(iter != busy_allocation_.end()); - busy_allocation_.erase(iter); - } - - private: - typedef std::unordered_map> - allocation_map_type; - allocation_map_type busy_allocation_; - platform::Place place_; -}; - -template -struct ReorderFunctor { - ReorderFunctor(const T *x, const int64_t *shuffle_idx, T *y, int64_t stride) - : x_(x), shuffle_idx_(shuffle_idx), y_(y), stride_(stride) {} - - HOSTDEVICE void operator()(int64_t idx) { - auto reorder_idx = shuffle_idx_[idx / stride_] * stride_ + idx % stride_; - if (kIsForward) { - y_[idx] = x_[reorder_idx]; - } else { - y_[reorder_idx] = x_[idx]; - } - } - - private: - const T *x_; - const int64_t 
*shuffle_idx_; - T *y_; - int64_t stride_; -}; - -template -class ShuffleBatchCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { -#ifdef _MSC_VER - PADDLE_THROW(platform::errors::Unimplemented( - "GPU shuffle_batch is not supported on Windows yet")); -#else - auto *x = ctx.Input("X"); - auto *seed = ctx.Input("Seed"); - auto *out = ctx.Output("Out"); - auto *shuffleidx = ctx.Output("ShuffleIdx"); - auto *seed_out = ctx.Output("SeedOut"); - - int64_t x_embed_size = x->dims()[x->dims().size() - 1]; - int64_t elem_size = 1; - for (int i = 0; i < x->dims().size() - 1; i++) { - elem_size *= x->dims()[i]; - } - shuffleidx->Resize(phi::make_ddim({elem_size})); - - int64_t seed_int = 0; - if (seed->IsInitialized()) { - const auto &seed_place = seed->place(); - if (platform::is_gpu_place(seed_place)) { - // NOTE: We have overwritten GetKernelTypeForVar, so seed_place would - // not be CUDAPlace in practice. This case would only happen in Python - // op_test framework. 
- phi::DenseTensor tmp_tensor; - framework::TensorCopySync(*seed, platform::CPUPlace(), &tmp_tensor); - seed_int = *(tmp_tensor.data()); - } else { - seed_int = *(seed->data()); - } - } else { - seed_int = ctx.Attr("startup_seed"); - } - - auto *shuffleidx_data = shuffleidx->mutable_data(ctx.GetPlace()); - - auto &dev_ctx = ctx.template device_context(); -#ifdef PADDLE_WITH_CUDA - CacheAllocator allocator(ctx.GetPlace()); - const auto &exec_policy = thrust::cuda::par(allocator).on(dev_ctx.stream()); -#else - const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream()); -#endif - thrust::random::default_random_engine engine(seed_int); - thrust::counting_iterator cnt_iter(0); - thrust::shuffle_copy(exec_policy, - cnt_iter, - cnt_iter + elem_size, - thrust::device_pointer_cast(shuffleidx_data), - engine); - // TODO(zengjinle): for small data, direct cudaMemcpy may be better - auto *x_data = x->data(); - auto *out_data = out->mutable_data(ctx.GetPlace()); - ReorderFunctor functor( - x_data, shuffleidx_data, out_data, x_embed_size); - platform::ForRange for_range(dev_ctx, - elem_size * x_embed_size); - for_range(functor); - - auto *seed_out_data = seed_out->mutable_data(phi::make_ddim({1}), - platform::CPUPlace()); - *seed_out_data = engine(); -#endif - } -}; - -template -class ShuffleBatchGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { -#ifdef _MSC_VER - PADDLE_THROW(platform::errors::Unimplemented( - "GPU shuffle_batch_grad is not supported on Windows yet")); -#else - const auto *out_grad = - ctx.Input(framework::GradVarName("Out")); - const auto *shuffleidx = ctx.Input("ShuffleIdx"); - auto *x_grad = ctx.Output(framework::GradVarName("X")); - - const auto *out_grad_data = out_grad->data(); - const auto *shuffleidx_data = shuffleidx->data(); - auto *x_grad_data = x_grad->mutable_data(ctx.GetPlace()); - auto x_embed_size = x_grad->dims()[x_grad->dims().size() - 1]; - ReorderFunctor 
functor( - out_grad_data, shuffleidx_data, x_grad_data, x_embed_size); - auto &dev_ctx = ctx.template device_context(); - // TODO(zengjinle): for small data, direct cudaMemcpy may be better - platform::ForRange for_range(dev_ctx, x_grad->numel()); - for_range(functor); -#endif - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -PD_REGISTER_STRUCT_KERNEL(shuffle_batch, - GPU, - ALL_LAYOUT, - ops::ShuffleBatchCUDAKernel, - float, - double, - int32_t, - int64_t) {} -PD_REGISTER_STRUCT_KERNEL(shuffle_batch_grad, - GPU, - ALL_LAYOUT, - ops::ShuffleBatchGradCUDAKernel, - float, - double, - int32_t, - int64_t) {} -#endif diff --git a/paddle/fluid/operators/shuffle_batch_op.h b/paddle/fluid/operators/shuffle_batch_op.h deleted file mode 100644 index 49eeac5cc7b..00000000000 --- a/paddle/fluid/operators/shuffle_batch_op.h +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "glog/logging.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/platform/timer.h" -#include "paddle/phi/core/mixed_vector.h" - -namespace paddle { -namespace operators { - -template -using Vector = phi::Vector; - -template -class ShuffleBatchKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto *x = context.Input("X"); - auto *seed = context.Input("Seed"); - auto *out = context.Output("Out"); - auto *shuffleidx = context.Output("ShuffleIdx"); - auto *seed_out = context.Output("SeedOut"); - - auto x_embed_size = x->dims()[x->dims().size() - 1]; - auto elem_size = 1; - for (auto i = 0; i < x->dims().size() - 1; i++) elem_size *= x->dims()[i]; - - std::vector idx_vec; // record shuffled order - idx_vec.reserve(elem_size); - for (auto i = 0; i < elem_size; i++) { - idx_vec.push_back(i); - } - int64_t seed_int = 0; - if (seed->IsInitialized()) { - seed_int = *seed->data(); - } else { - seed_int = context.Attr("startup_seed"); - } - std::default_random_engine engine; - engine.seed(seed_int); - - auto custom_random_shuffle = [&idx_vec]() { - std::random_device rnd; - int64_t seed_tmp = rnd(); - std::default_random_engine rng(seed_tmp); - const int n = idx_vec.size(); - std::vector v(n); - std::iota(v.begin(), v.end(), 0); - std::vector visit(n, false); - while (!v.empty()) { - std::shuffle(v.begin(), v.end(), rng); - int tmp = v.back(); - v.pop_back(); - if (v.empty()) { - std::uniform_int_distribution distr(0, n - 2); - idx_vec[tmp] = tmp; - std::swap(idx_vec[tmp], idx_vec[(distr(rng) + tmp + 1) % n]); - return; - } - visit[tmp] = true; - std::shuffle(v.begin(), v.end(), rng); - int curr = v.back(); - v.pop_back(); - v.push_back(tmp); - 
idx_vec[tmp] = curr; - while (!visit[curr]) { - visit[curr] = true; - std::shuffle(v.begin(), v.end(), rng); - idx_vec[curr] = v.back(); - v.pop_back(); - curr = idx_vec[curr]; - } - } - }; - custom_random_shuffle(); - // change shuffle to custom_random_shuffle - // std::shuffle(idx_vec.begin(), idx_vec.end(), engine); - - // ShuffleIdx record shuffle order - shuffleidx->Resize(phi::make_ddim({(int64_t)idx_vec.size()})); - auto *shuffleidx_data = - shuffleidx->mutable_data(context.GetPlace()); - for (size_t i = 0; i < idx_vec.size(); i++) { - shuffleidx_data[i] = idx_vec[i]; - } - // copy data according to idx_vec - auto *x_data = x->data(); - auto *out_data = out->mutable_data(context.GetPlace()); - for (auto i = 0; i < elem_size; i++) { - memcpy(out_data + idx_vec[i] * x_embed_size, - x_data + i * x_embed_size, - x_embed_size * sizeof(T)); - } - // set new seed - *seed_out->mutable_data(phi::make_ddim({1}), context.GetPlace()) = - engine(); - } -}; - -template -class ShuffleBatchGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto *out_grad = - context.Input(framework::GradVarName("Out")); - auto *shuffleidx = context.Input("ShuffleIdx"); - auto *x_grad = - context.Output(framework::GradVarName("X")); - - auto embed_size = out_grad->dims()[out_grad->dims().size() - 1]; - auto elem_size = 1; - for (auto i = 0; i < out_grad->dims().size() - 1; i++) - elem_size *= out_grad->dims()[i]; - - std::vector idx_vec_grad(elem_size); - auto *shuffleidx_data = shuffleidx->data(); - for (size_t i = 0; i < idx_vec_grad.size(); i++) { - idx_vec_grad[shuffleidx_data[i]] = i; - } - - // copy data according to idx_vec_grad - auto *out_grad_data = out_grad->data(); - auto *x_grad_data = x_grad->mutable_data(context.GetPlace()); - for (auto i = 0; i < elem_size; i++) { - memcpy(x_grad_data + idx_vec_grad[i] * embed_size, - out_grad_data + i * embed_size, - embed_size * sizeof(T)); - } - } -}; -} // 
namespace operators -} // namespace paddle diff --git a/paddle/phi/kernels/cpu/shuffle_batch_grad_kernel.cc b/paddle/phi/kernels/cpu/shuffle_batch_grad_kernel.cc new file mode 100644 index 00000000000..ccc3af4e2c1 --- /dev/null +++ b/paddle/phi/kernels/cpu/shuffle_batch_grad_kernel.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/shuffle_batch_grad_kernel.h" + +namespace phi { + +template +void ShuffleBatchGradKernel(const Context& dev_ctx, + const DenseTensor& shuffleidx, + const DenseTensor& out_grad, + int startup_seed, + DenseTensor* x_grad) { + auto embed_size = out_grad.dims()[out_grad.dims().size() - 1]; + auto elem_size = 1; + for (auto i = 0; i < out_grad.dims().size() - 1; i++) + elem_size *= out_grad.dims()[i]; + + std::vector idx_vec_grad(elem_size); + auto* shuffleidx_data = shuffleidx.data(); + for (size_t i = 0; i < idx_vec_grad.size(); i++) { + idx_vec_grad[shuffleidx_data[i]] = i; + } + + // idx_vec_grad inverts ShuffleIdx: x_grad row idx_vec_grad[i] <- out_grad row i + auto* out_grad_data = out_grad.data(); + auto* x_grad_data = dev_ctx.template Alloc(x_grad); + + for (auto i = 0; i < elem_size; i++) { + memcpy(x_grad_data + idx_vec_grad[i] * embed_size, + out_grad_data + i * embed_size, + embed_size * sizeof(T)); + } +} + +} // namespace phi +
+PD_REGISTER_KERNEL(shuffle_batch_grad, + CPU, + ALL_LAYOUT, + phi::ShuffleBatchGradKernel, + float, + double, + int32_t, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/shuffle_batch_kernel.cc b/paddle/phi/kernels/cpu/shuffle_batch_kernel.cc new file mode 100644 index 00000000000..a509f407f41 --- /dev/null +++ b/paddle/phi/kernels/cpu/shuffle_batch_kernel.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/shuffle_batch_kernel.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" + +namespace phi { + +template +void ShuffleBatchKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& seed, + int startup_seed, + DenseTensor* out, + DenseTensor* shuffleidx, + DenseTensor* seed_out) { + auto x_embed_size = x.dims()[x.dims().size() - 1]; + auto elem_size = 1; + for (auto i = 0; i < x.dims().size() - 1; i++) elem_size *= x.dims()[i]; + + std::vector idx_vec; // record shuffled order + idx_vec.reserve(elem_size); + for (auto i = 0; i < elem_size; i++) { + idx_vec.push_back(i); + } + int64_t seed_int = 0; + if (seed.initialized()) { + seed_int = *seed.data(); + } else { + seed_int = startup_seed; + } + std::default_random_engine engine; + engine.seed(seed_int); + + auto custom_random_shuffle = [&idx_vec]() { + std::random_device rnd; + int64_t seed_tmp = rnd(); + std::default_random_engine rng(seed_tmp); + const int n = idx_vec.size(); + std::vector v(n); + std::iota(v.begin(), v.end(), 0); + std::vector visit(n, false); + while (!v.empty()) { + std::shuffle(v.begin(), v.end(), rng); + int tmp = v.back(); + v.pop_back(); + if (v.empty()) { + std::uniform_int_distribution distr(0, n - 2); + idx_vec[tmp] = tmp; + std::swap(idx_vec[tmp], idx_vec[(distr(rng) + tmp + 1) % n]); + return; + } + visit[tmp] = true; + std::shuffle(v.begin(), v.end(), rng); + int curr = v.back(); + v.pop_back(); + v.push_back(tmp); + idx_vec[tmp] = curr; + while (!visit[curr]) { + visit[curr] = true; + std::shuffle(v.begin(), v.end(), rng); + idx_vec[curr] = v.back(); + v.pop_back(); + curr = idx_vec[curr]; + } + } + }; + if (idx_vec.size() > 1) custom_random_shuffle(); + // Replaces std::shuffle; needs >= 2 rows since it draws from [0, n - 2]. + // std::shuffle(idx_vec.begin(), idx_vec.end(), engine); + + // ShuffleIdx record shuffle order + shuffleidx->Resize(phi::make_ddim({(int64_t)idx_vec.size()})); + auto* shuffleidx_data = dev_ctx.template
HostAlloc(shuffleidx); + + for (size_t i = 0; i < idx_vec.size(); i++) { + shuffleidx_data[i] = idx_vec[i]; + } + // copy data according to idx_vec + auto* x_data = x.data(); + auto* out_data = dev_ctx.template HostAlloc(out); + + for (auto i = 0; i < elem_size; i++) { + memcpy(out_data + idx_vec[i] * x_embed_size, + x_data + i * x_embed_size, + x_embed_size * sizeof(T)); + } + // set new seed + seed_out->Resize(phi::make_ddim({1})); + auto* seed_out_data = dev_ctx.template HostAlloc(seed_out); + *seed_out_data = engine(); +} +} // namespace phi + +PD_REGISTER_KERNEL(shuffle_batch, + CPU, + ALL_LAYOUT, + phi::ShuffleBatchKernel, + float, + double, + int32_t, + int64_t) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); + kernel->OutputAt(2).SetDataType(phi::DataType::INT64); +} diff --git a/paddle/phi/kernels/gpu/shuffle_batch_grad_kernel.cu b/paddle/phi/kernels/gpu/shuffle_batch_grad_kernel.cu new file mode 100644 index 00000000000..33b39666edf --- /dev/null +++ b/paddle/phi/kernels/gpu/shuffle_batch_grad_kernel.cu @@ -0,0 +1,67 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + +#ifndef _MSC_VER +#include +#include +#include +#include +#endif + +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/gpu/shuffle_batch_utils.h" +#include "paddle/phi/kernels/shuffle_batch_grad_kernel.h" + +namespace phi { + +template +void ShuffleBatchGradKernel(const Context& dev_ctx, + const DenseTensor& shuffleidx, + const DenseTensor& out_grad, + int startup_seed, + DenseTensor* x_grad) { +#ifdef _MSC_VER + PADDLE_THROW(phi::errors::Unimplemented( + "GPU shuffle_batch_grad is not supported on Windows yet")); +#else + const auto* out_grad_data = out_grad.data(); + const auto* shuffleidx_data = shuffleidx.data(); + auto* x_grad_data = dev_ctx.template Alloc(x_grad); + auto x_embed_size = x_grad->dims()[x_grad->dims().size() - 1]; + ReorderFunctor functor( + out_grad_data, shuffleidx_data, x_grad_data, x_embed_size); + // TODO(zengjinle): for small data, direct cudaMemcpy may be better + phi::funcs::ForRange for_range(dev_ctx, x_grad->numel()); + for_range(functor); +#endif +} + +} // namespace phi + +PD_REGISTER_KERNEL(shuffle_batch_grad, + GPU, + ALL_LAYOUT, + phi::ShuffleBatchGradKernel, + float, + double, + int32_t, + int64_t) {} +#endif diff --git a/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu b/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu new file mode 100644 index 00000000000..e145e7e1c8a --- /dev/null +++ b/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu @@ -0,0 +1,114 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + +#ifndef _MSC_VER +#include +#include +#include +#include +#endif + +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/gpu/shuffle_batch_utils.h" +#include "paddle/phi/kernels/shuffle_batch_kernel.h" + +namespace phi { + +template +void ShuffleBatchKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& seed, + int startup_seed, + DenseTensor* out, + DenseTensor* shuffleidx, + DenseTensor* seed_out) { +#ifdef _MSC_VER + PADDLE_THROW(phi::errors::Unimplemented( + "GPU shuffle_batch is not supported on Windows yet")); +#else + int64_t x_embed_size = x.dims()[x.dims().size() - 1]; + int64_t elem_size = 1; + for (int i = 0; i < x.dims().size() - 1; i++) { + elem_size *= x.dims()[i]; + } + shuffleidx->Resize(phi::make_ddim({elem_size})); + + int64_t seed_int = 0; + if (seed.initialized()) { + const auto& seed_place = seed.place().GetType(); + bool is_gpu_place = seed_place == phi::AllocationType::GPU; + if (is_gpu_place) { + // NOTE: We have overwritten GetKernelTypeForVar, so seed_place would + // not be CUDAPlace in practice. This case would only happen in Python + // op_test framework. The copy must block: seed is read on the host next.
+ phi::DenseTensor tmp_tensor; + phi::Copy(dev_ctx, seed, phi::CPUPlace(), true, &tmp_tensor); + seed_int = *(tmp_tensor.data()); + } else { + seed_int = *(seed.data()); + } + } else { + seed_int = startup_seed; + } + + auto* shuffleidx_data = dev_ctx.template Alloc(shuffleidx); + +#ifdef PADDLE_WITH_CUDA + CacheAllocator allocator(dev_ctx.GetPlace()); + const auto& exec_policy = thrust::cuda::par(allocator).on(dev_ctx.stream()); +#else + const auto& exec_policy = thrust::hip::par.on(dev_ctx.stream()); +#endif + thrust::random::default_random_engine engine(seed_int); + thrust::counting_iterator cnt_iter(0); + thrust::shuffle_copy(exec_policy, + cnt_iter, + cnt_iter + elem_size, + thrust::device_pointer_cast(shuffleidx_data), + engine); + // TODO(zengjinle): for small data, direct cudaMemcpy may be better + auto* x_data = x.data(); + auto* out_data = dev_ctx.template Alloc(out); + ReorderFunctor functor( + x_data, shuffleidx_data, out_data, x_embed_size); + phi::funcs::ForRange for_range(dev_ctx, + elem_size * x_embed_size); + for_range(functor); + seed_out->Resize(phi::make_ddim({1})); + auto* seed_out_data = dev_ctx.template HostAlloc(seed_out); + *seed_out_data = engine(); +#endif +} + +} // namespace phi + +PD_REGISTER_KERNEL(shuffle_batch, + GPU, + ALL_LAYOUT, + phi::ShuffleBatchKernel, + float, + double, + int32_t, + int64_t) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); + kernel->OutputAt(2).SetDataType(phi::DataType::INT64); +} +#endif diff --git a/paddle/phi/kernels/gpu/shuffle_batch_utils.h b/paddle/phi/kernels/gpu/shuffle_batch_utils.h new file mode 100644 index 00000000000..3a7c2230d32 --- /dev/null +++ b/paddle/phi/kernels/gpu/shuffle_batch_utils.h @@ -0,0 +1,76 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/tensor_utils.h" + +namespace phi { + +struct CacheAllocator { + typedef char value_type; + explicit CacheAllocator(phi::Place place) { + VLOG(2) << "construct allocator"; + place_ = place; + } + + ~CacheAllocator() { VLOG(2) << "destory allocator"; } + + char* allocate(std::ptrdiff_t num_bytes) { + VLOG(2) << "allocate " << num_bytes << " bytes"; + auto storage = memory_utils::AllocShared(place_, num_bytes); + char* ptr = reinterpret_cast(storage->ptr()); + busy_allocation_.emplace(std::make_pair(ptr, storage)); + return ptr; + } + + void deallocate(char* ptr, size_t) { + VLOG(2) << "deallocate "; + allocation_map_type::iterator iter = busy_allocation_.find(ptr); + CHECK(iter != busy_allocation_.end()); + busy_allocation_.erase(iter); + } + + private: + typedef std::unordered_map> + allocation_map_type; + allocation_map_type busy_allocation_; + phi::Place place_; +}; + +template +struct ReorderFunctor { + ReorderFunctor(const T* x, const int64_t* shuffle_idx, T* y, int64_t stride) + : x_(x), shuffle_idx_(shuffle_idx), y_(y), stride_(stride) {} + + HOSTDEVICE void operator()(int64_t idx) { + auto reorder_idx = shuffle_idx_[idx / stride_] * stride_ + idx % stride_; + if (kIsForward) { + y_[idx] = x_[reorder_idx]; + } else { + y_[reorder_idx] = x_[idx]; + } + } + + private: + const T* x_; + const int64_t* shuffle_idx_; + T* y_; + int64_t stride_; +}; + +} // 
namespace phi +#endif diff --git a/paddle/phi/kernels/shuffle_batch_grad_kernel.h b/paddle/phi/kernels/shuffle_batch_grad_kernel.h new file mode 100644 index 00000000000..c47f476498e --- /dev/null +++ b/paddle/phi/kernels/shuffle_batch_grad_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void ShuffleBatchGradKernel(const Context& dev_ctx, + const DenseTensor& shuffleidx, + const DenseTensor& out_grad, + int startup_seed, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/shuffle_batch_kernel.h b/paddle/phi/kernels/shuffle_batch_kernel.h new file mode 100644 index 00000000000..504dfe97fa2 --- /dev/null +++ b/paddle/phi/kernels/shuffle_batch_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void ShuffleBatchKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& seed, + int startup_seed, + DenseTensor* out, + DenseTensor* shuffleidx, + DenseTensor* seed_out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/shuffle_batch_sig.cc b/paddle/phi/ops/compat/shuffle_batch_sig.cc new file mode 100644 index 00000000000..22a9f76d95d --- /dev/null +++ b/paddle/phi/ops/compat/shuffle_batch_sig.cc @@ -0,0 +1,40 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature ShuffleBatchOpArgumentMapping( + const ArgumentMappingContext& ctx UNUSED) { + return KernelSignature("shuffle_batch", + {"X", "Seed"}, + {"startup_seed"}, + {"Out", "ShuffleIdx", "SeedOut"}); +} + +KernelSignature ShuffleBatchGradOpArgumentMapping( + const ArgumentMappingContext& ctx UNUSED) { + return KernelSignature("shuffle_batch_grad", + {"ShuffleIdx", "Out@GRAD"}, + {"startup_seed"}, + {"X@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(shuffle_batch, phi::ShuffleBatchOpArgumentMapping); + +PD_REGISTER_ARG_MAPPING_FN(shuffle_batch_grad, + phi::ShuffleBatchGradOpArgumentMapping); diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index e91a94d5747..079233a9c16 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -1306,6 +1306,7 @@ set(STATIC_BUILD_TESTS test_segment_ops test_sparse_momentum_op test_sgd_op_bf16 + test_shuffle_batch_op test_softmax_mask_fuse_upper_triangle_op test_sparse_conv_op test_sparse_norm_op diff --git a/test/legacy_test/test_shuffle_batch_op.py b/test/legacy_test/test_shuffle_batch_op.py index 629164e1066..c2089026129 100644 --- a/test/legacy_test/test_shuffle_batch_op.py +++ b/test/legacy_test/test_shuffle_batch_op.py @@ -76,7 +76,7 @@ class TestShuffleBatchOpBase(OpTest): return np.reshape(np.array(arr_list), shape) def test_check_grad(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_dygraph=False) class TestShuffleBatchOp2(TestShuffleBatchOpBase): -- GitLab