未验证 提交 c6b6ba1f 编写于 作者: Z Zeng Jinle 提交者: GitHub

Add gpu implementation of shuffle_batch_op (#33938)

* add gpu implementation of shuffle batch
test=develop

* add thrust cuda patches
test=develop

* fix macro guard

* fix shuffle batch compile on windows/hip

* fix hip compilation error

* refine CMakeLists.txt

* fix windows compile error

* try to fix windows CI compilation error

* fix windows compilation again

* fix shuffle_batch op test on Windows
上级 5085c44b
......@@ -233,3 +233,4 @@ endif()
mark_as_advanced(CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD)
mark_as_advanced(CUDA_SDK_ROOT_DIR CUDA_SEPARABLE_COMPILATION)
include(thrust)
......@@ -85,3 +85,5 @@ message(STATUS "HIP library name: ${hip_library_name}")
# set HIP link libs
find_library(ROCM_HIPRTC_LIB ${hip_library_name} HINTS ${HIP_PATH}/lib)
message(STATUS "ROCM_HIPRTC_LIB: ${ROCM_HIPRTC_LIB}")
include(thrust)
function(add_thrust_patches_if_necessary)
set(thrust_detect_file ${PROJECT_BINARY_DIR}/detect_thrust.cu)
file(WRITE ${thrust_detect_file} ""
"#include \"thrust/version.h\"\n"
"#include \"thrust/shuffle.h\"\n"
"#include \"stdio.h\"\n"
"int main() {\n"
" int version = THRUST_VERSION;\n"
" printf(\"%d\", version);\n"
" return 0;\n"
"}\n")
execute_process(COMMAND "${CUDA_NVCC_EXECUTABLE}"
"--run" "${thrust_detect_file}"
WORKING_DIRECTORY "${PROJECT_BINARY_DIR}/CMakeFiles/"
RESULT_VARIABLE nvcc_res ERROR_QUIET)
if(NOT nvcc_res EQUAL 0)
set(thrust_patches "${PADDLE_SOURCE_DIR}/patches/thrust")
message(STATUS "Add thrust patches: ${thrust_patches}")
include_directories(${thrust_patches})
endif()
endfunction()
add_thrust_patches_if_necessary()
......@@ -53,6 +53,16 @@ class ShuffleBatchOp : public framework::OperatorWithKernel {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
if (var_name == "Seed") {
return expected_kernel_type;
}
return framework::OperatorWithKernel::GetKernelTypeForVar(
var_name, tensor, expected_kernel_type);
}
};
class ShuffleBatchOpMaker : public framework::OpProtoAndCheckerMaker {
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#ifndef _MSC_VER
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/shuffle.h>
#endif
#include "paddle/fluid/operators/shuffle_batch_op.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
template <typename T, bool kIsForward>
struct ReorderFunctor {
ReorderFunctor(const T *x, const int64_t *shuffle_idx, T *y, int64_t stride)
: x_(x), shuffle_idx_(shuffle_idx), y_(y), stride_(stride) {}
HOSTDEVICE void operator()(int64_t idx) {
auto reorder_idx = shuffle_idx_[idx / stride_] * stride_ + idx % stride_;
if (kIsForward) {
y_[idx] = x_[reorder_idx];
} else {
y_[reorder_idx] = x_[idx];
}
}
private:
const T *x_;
const int64_t *shuffle_idx_;
T *y_;
int64_t stride_;
};
template <typename T>
class ShuffleBatchCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
#ifdef _MSC_VER
PADDLE_THROW(platform::errors::Unimplemented(
"GPU shuffle_batch is not supported on Windows yet"));
#else
auto *x = ctx.Input<framework::Tensor>("X");
auto *seed = ctx.Input<framework::Tensor>("Seed");
auto *out = ctx.Output<framework::Tensor>("Out");
auto *shuffleidx = ctx.Output<framework::Tensor>("ShuffleIdx");
auto *seed_out = ctx.Output<framework::Tensor>("SeedOut");
int64_t x_embed_size = x->dims()[x->dims().size() - 1];
int64_t elem_size = 1;
for (int i = 0; i < x->dims().size() - 1; i++) {
elem_size *= x->dims()[i];
}
shuffleidx->Resize(framework::make_ddim({elem_size}));
int64_t seed_int = 0;
if (seed->IsInitialized()) {
const auto &seed_place = seed->place();
if (platform::is_gpu_place(seed_place)) {
// NOTE: We have overwritten GetKernelTypeForVar, so seed_place would
// not be CUDAPlace in practice. This case would only happen in Python
// op_test framework.
framework::Tensor tmp_tensor;
framework::TensorCopySync(*seed, platform::CPUPlace(), &tmp_tensor);
seed_int = *(tmp_tensor.data<int64_t>());
} else {
seed_int = *(seed->data<int64_t>());
}
} else {
seed_int = ctx.Attr<int>("startup_seed");
}
auto *shuffleidx_data = shuffleidx->mutable_data<int64_t>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
#ifdef PADDLE_WITH_CUDA
const auto &exec_policy = thrust::cuda::par.on(dev_ctx.stream());
#else
const auto &exec_policy = thrust::hip::par.on(dev_ctx.stream());
#endif
thrust::random::default_random_engine engine(seed_int);
thrust::counting_iterator<int64_t> cnt_iter(0);
thrust::shuffle_copy(exec_policy, cnt_iter, cnt_iter + elem_size,
thrust::device_pointer_cast(shuffleidx_data), engine);
// TODO(zengjinle): for small data, direct cudaMemcpy may be better
auto *x_data = x->data<T>();
auto *out_data = out->mutable_data<T>(ctx.GetPlace());
ReorderFunctor<T, true> functor(x_data, shuffleidx_data, out_data,
x_embed_size);
platform::ForRange<platform::CUDADeviceContext> for_range(
dev_ctx, elem_size * x_embed_size);
for_range(functor);
auto *seed_out_data = seed_out->mutable_data<int64_t>(
framework::make_ddim({1}), platform::CPUPlace());
*seed_out_data = engine();
#endif
}
};
template <typename T>
class ShuffleBatchGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
#ifdef _MSC_VER
PADDLE_THROW(platform::errors::Unimplemented(
"GPU shuffle_batch_grad is not supported on Windows yet"));
#else
const auto *out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
const auto *shuffleidx = ctx.Input<framework::Tensor>("ShuffleIdx");
auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
const auto *out_grad_data = out_grad->data<T>();
const auto *shuffleidx_data = shuffleidx->data<int64_t>();
auto *x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
auto x_embed_size = x_grad->dims()[x_grad->dims().size() - 1];
ReorderFunctor<T, false> functor(out_grad_data, shuffleidx_data,
x_grad_data, x_embed_size);
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
// TODO(zengjinle): for small data, direct cudaMemcpy may be better
platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
x_grad->numel());
for_range(functor);
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(shuffle_batch, ops::ShuffleBatchCUDAKernel<float>,
ops::ShuffleBatchCUDAKernel<double>,
ops::ShuffleBatchCUDAKernel<int32_t>,
ops::ShuffleBatchCUDAKernel<int64_t>);
REGISTER_OP_CUDA_KERNEL(shuffle_batch_grad,
ops::ShuffleBatchGradCUDAKernel<float>,
ops::ShuffleBatchGradCUDAKernel<double>,
ops::ShuffleBatchGradCUDAKernel<int32_t>,
ops::ShuffleBatchGradCUDAKernel<int64_t>);
#endif
/*
* Copyright 2008-2020 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file shuffle.inl
* \brief Inline file for shuffle.h.
*/
#include <thrust/detail/config.h>
#include <thrust/detail/cpp11_required.h>
#if THRUST_CPP_DIALECT >= 2011
#include <thrust/iterator/iterator_traits.h>
#include <thrust/shuffle.h>
#include <thrust/system/detail/generic/select_system.h>
#include <thrust/system/detail/generic/shuffle.h>
namespace thrust {
__thrust_exec_check_disable__
template <typename DerivedPolicy, typename RandomIterator, typename URBG>
__host__ __device__ void shuffle(
const thrust::detail::execution_policy_base<DerivedPolicy>& exec,
RandomIterator first, RandomIterator last, URBG&& g) {
using thrust::system::detail::generic::shuffle;
return shuffle(
thrust::detail::derived_cast(thrust::detail::strip_const(exec)),
first, last, g);
}
template <typename RandomIterator, typename URBG>
__host__ __device__ void shuffle(RandomIterator first, RandomIterator last,
URBG&& g) {
using thrust::system::detail::generic::select_system;
typedef typename thrust::iterator_system<RandomIterator>::type System;
System system;
return thrust::shuffle(select_system(system), first, last, g);
}
__thrust_exec_check_disable__
template <typename DerivedPolicy, typename RandomIterator,
typename OutputIterator, typename URBG>
__host__ __device__ void shuffle_copy(
const thrust::detail::execution_policy_base<DerivedPolicy>& exec,
RandomIterator first, RandomIterator last, OutputIterator result,
URBG&& g) {
using thrust::system::detail::generic::shuffle_copy;
return shuffle_copy(
thrust::detail::derived_cast(thrust::detail::strip_const(exec)),
first, last, result, g);
}
template <typename RandomIterator, typename OutputIterator, typename URBG>
__host__ __device__ void shuffle_copy(RandomIterator first, RandomIterator last,
OutputIterator result, URBG&& g) {
using thrust::system::detail::generic::select_system;
typedef typename thrust::iterator_system<RandomIterator>::type System1;
typedef typename thrust::iterator_system<OutputIterator>::type System2;
System1 system1;
System2 system2;
return thrust::shuffle_copy(select_system(system1, system2), first, last,
result, g);
}
} // namespace thrust
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright 2008-2020 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file shuffle.h
* \brief Reorders range by a uniform random permutation
*/
#pragma once
#include <thrust/detail/config.h>
#include <thrust/detail/cpp11_required.h>
#if THRUST_CPP_DIALECT >= 2011
#include <thrust/detail/config.h>
#include <thrust/detail/execution_policy.h>
namespace thrust {
/*! \addtogroup reordering
* \ingroup algorithms
*
* \addtogroup shuffling
* \ingroup reordering
* \{
*/
/*! \p shuffle reorders the elements <tt>[first, last)</tt> by a uniform
* pseudorandom permutation, defined by
* random engine \p g.
*
* The algorithm's execution is parallelized as determined by \p exec.
*
* \param exec The execution policy to use for parallelization.
* \param first The beginning of the sequence to shuffle.
* \param last The end of the sequence to shuffle.
* \param g A UniformRandomBitGenerator
*
* \tparam DerivedPolicy The name of the derived execution policy.
* \tparam RandomIterator is a random access iterator
* \tparam URBG is a uniform random bit generator
*
* The following code snippet demonstrates how to use \p shuffle to create a
* random permutation
* using the \p thrust::host execution policy for parallelization:
*
* \code
* #include <thrust/shuffle.h>
* #include <thrust/random.h>
* #include <thrust/execution_policy.h>
* int A[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
* const int N = sizeof(A)/sizeof(int);
* thrust::default_random_engine g;
* thrust::shuffle(thrust::host, A, A + N, g);
* // A is now {6, 5, 8, 7, 2, 1, 4, 3, 10, 9}
* \endcode
*
* \see \p shuffle_copy
*/
template <typename DerivedPolicy, typename RandomIterator, typename URBG>
__host__ __device__ void shuffle(
const thrust::detail::execution_policy_base<DerivedPolicy>& exec,
RandomIterator first,
RandomIterator last,
URBG&& g);
/*! \p shuffle reorders the elements <tt>[first, last)</tt> by a uniform
* pseudorandom permutation, defined by
* random engine \p g.
*
* \param first The beginning of the sequence to shuffle.
* \param last The end of the sequence to shuffle.
* \param g A UniformRandomBitGenerator
*
* \tparam RandomIterator is a random access iterator
* \tparam URBG is a uniform random bit generator
*
* The following code snippet demonstrates how to use \p shuffle to create a
* random permutation.
*
* \code
* #include <thrust/shuffle.h>
* #include <thrust/random.h>
* int A[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
* const int N = sizeof(A)/sizeof(int);
* thrust::default_random_engine g;
* thrust::shuffle(A, A + N, g);
* // A is now {6, 5, 8, 7, 2, 1, 4, 3, 10, 9}
* \endcode
*
* \see \p shuffle_copy
*/
template <typename RandomIterator, typename URBG>
__host__ __device__ void shuffle(RandomIterator first,
RandomIterator last,
URBG&& g);
/*! shuffle_copy differs from shuffle only in that the reordered sequence is
written to different output sequences, rather than in place.
* \p shuffle_copy reorders the elements <tt>[first, last)</tt> by a uniform
pseudorandom permutation, defined by
* random engine \p g.
*
* The algorithm's execution is parallelized as determined by \p exec.
* \param exec The execution policy to use for parallelization.
* \param first The beginning of the sequence to shuffle.
* \param last The end of the sequence to shuffle.
* \param result Destination of shuffled sequence
* \param g A UniformRandomBitGenerator
*
* \tparam DerivedPolicy The name of the derived execution policy.
* \tparam RandomIterator is a random access iterator
* \tparam OutputIterator is a model of <a
href="https://en.cppreference.com/w/cpp/iterator/output_iterator">Output
Iterator</a>.
* \tparam URBG is a uniform random bit generator
*
* The following code snippet demonstrates how to use \p shuffle_copy to create
a random permutation.
*
* \code
* #include <thrust/shuffle.h>
* #include <thrust/random.h>
* #include <thrust/execution_policy.h>
* int A[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
* int result[10];
* const int N = sizeof(A)/sizeof(int);
* thrust::default_random_engine g;
* thrust::shuffle_copy(thrust::host, A, A + N, result, g);
* // result is now {6, 5, 8, 7, 2, 1, 4, 3, 10, 9}
* \endcode
*
* \see \p shuffle
*/
template <typename DerivedPolicy,
typename RandomIterator,
typename OutputIterator,
typename URBG>
__host__ __device__ void shuffle_copy(
const thrust::detail::execution_policy_base<DerivedPolicy>& exec,
RandomIterator first,
RandomIterator last,
OutputIterator result,
URBG&& g);
/*! shuffle_copy differs from shuffle only in that the reordered sequence is
*written to different output sequences, rather than in place.
*\p shuffle_copy reorders the elements <tt>[first, last)</tt> by a uniform
*pseudorandom permutation, defined by
* random engine \p g.
*
* \param first The beginning of the sequence to shuffle.
* \param last The end of the sequence to shuffle.
* \param result Destination of shuffled sequence
* \param g A UniformRandomBitGenerator
*
* \tparam RandomIterator is a random access iterator
* \tparam OutputIterator is a model of <a
*href="https://en.cppreference.com/w/cpp/iterator/output_iterator">Output
*Iterator</a>.
* \tparam URBG is a uniform random bit generator
*
* The following code snippet demonstrates how to use \p shuffle_copy to create
*a random permutation.
*
* \code
* #include <thrust/shuffle.h>
* #include <thrust/random.h>
* int A[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
* int result[10];
* const int N = sizeof(A)/sizeof(int);
* thrust::default_random_engine g;
* thrust::shuffle_copy(A, A + N, result, g);
* // result is now {6, 5, 8, 7, 2, 1, 4, 3, 10, 9}
* \endcode
*
* \see \p shuffle
*/
template <typename RandomIterator, typename OutputIterator, typename URBG>
__host__ __device__ void shuffle_copy(RandomIterator first,
RandomIterator last,
OutputIterator result,
URBG&& g);
} // namespace thrust
#include <thrust/detail/shuffle.inl>
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* Copyright 2008-2020 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*! \file shuffle.h
* \brief Generic implementations of shuffle functions.
*/
#pragma once
#include <thrust/detail/config.h>
#include <thrust/detail/cpp11_required.h>
#if THRUST_CPP_DIALECT >= 2011
#include <thrust/system/detail/generic/tag.h>
namespace thrust {
namespace system {
namespace detail {
namespace generic {
template <typename ExecutionPolicy, typename RandomIterator, typename URBG>
__host__ __device__ void shuffle(
thrust::execution_policy<ExecutionPolicy>& exec,
RandomIterator first,
RandomIterator last,
URBG&& g);
template <typename ExecutionPolicy,
typename RandomIterator,
typename OutputIterator,
typename URBG>
__host__ __device__ void shuffle_copy(
thrust::execution_policy<ExecutionPolicy>& exec,
RandomIterator first,
RandomIterator last,
OutputIterator result,
URBG&& g);
} // end namespace generic
} // end namespace detail
} // end namespace system
} // end namespace thrust
#include <thrust/system/detail/generic/shuffle.inl>
#endif
/*
* Copyright 2008-20120 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/detail/config.h>
#include <thrust/detail/temporary_array.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/random.h>
#include <thrust/scan.h>
#include <thrust/system/detail/generic/shuffle.h>
#include <cstdint>
namespace thrust {
template <typename Iterator>
using iterator_value_t = typename iterator_value<Iterator>::type;
namespace system {
namespace detail {
namespace generic {
// An implementation of a Feistel cipher for operating on 64 bit keys
class feistel_bijection {
struct round_state {
std::uint32_t left;
std::uint32_t right;
};
public:
template <class URBG>
__host__ __device__ feistel_bijection(std::uint64_t m, URBG&& g) {
std::uint64_t total_bits = get_cipher_bits(m);
// Half bits rounded down
left_side_bits = total_bits / 2;
left_side_mask = (1ull << left_side_bits) - 1;
// Half the bits rounded up
right_side_bits = total_bits - left_side_bits;
right_side_mask = (1ull << right_side_bits) - 1;
for (std::uint64_t i = 0; i < num_rounds; i++) {
key[i] = g();
}
}
__host__ __device__ std::uint64_t nearest_power_of_two() const {
return 1ull << (left_side_bits + right_side_bits);
}
__host__ __device__ std::uint64_t operator()(const std::uint64_t val) const {
// Extract the right and left sides of the input
auto left = static_cast<std::uint32_t>(val >> right_side_bits);
auto right = static_cast<std::uint32_t>(val & right_side_mask);
round_state state = {left, right};
for (std::uint64_t i = 0; i < num_rounds; i++) {
state = do_round(state, i);
}
// Check we have the correct number of bits on each side
assert((state.left >> left_side_bits) == 0);
assert((state.right >> right_side_bits) == 0);
// Combine the left and right sides together to get result
return state.left << right_side_bits | state.right;
}
private:
// Find the nearest power of two
__host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) {
if (m == 0) return 0;
std::uint64_t i = 0;
m--;
while (m != 0) {
i++;
m >>= 1;
}
return i;
}
// Equivalent to boost::hash_combine
__host__ __device__
std::size_t hash_combine(std::uint64_t lhs, std::uint64_t rhs) const {
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
// Round function, a 'pseudorandom function' who's output is indistinguishable
// from random for each key value input. This is not cryptographically secure
// but sufficient for generating permutations.
__host__ __device__ std::uint32_t round_function(std::uint64_t value,
const std::uint64_t key_) const {
std::uint64_t hash0 = thrust::random::taus88(static_cast<std::uint32_t>(value))();
std::uint64_t hash1 = thrust::random::ranlux48(value)();
return static_cast<std::uint32_t>(
hash_combine(hash_combine(hash0, key_), hash1) & left_side_mask);
}
__host__ __device__ round_state do_round(const round_state state,
const std::uint64_t round) const {
const std::uint32_t new_left = state.right & left_side_mask;
const std::uint32_t round_function_res =
state.left ^ round_function(state.right, key[round]);
if (right_side_bits != left_side_bits) {
// Upper bit of the old right becomes lower bit of new right if we have
// odd length feistel
const std::uint32_t new_right =
(round_function_res << 1ull) | state.right >> left_side_bits;
return {new_left, new_right};
}
return {new_left, round_function_res};
}
static constexpr std::uint64_t num_rounds = 16;
std::uint64_t right_side_bits;
std::uint64_t left_side_bits;
std::uint64_t right_side_mask;
std::uint64_t left_side_mask;
std::uint64_t key[num_rounds];
};
struct key_flag_tuple {
std::uint64_t key;
std::uint64_t flag;
};
// scan only flags
struct key_flag_scan_op {
__host__ __device__ key_flag_tuple operator()(const key_flag_tuple& a,
const key_flag_tuple& b) {
return {b.key, a.flag + b.flag};
}
};
struct construct_key_flag_op {
std::uint64_t m;
feistel_bijection bijection;
__host__ __device__ construct_key_flag_op(std::uint64_t m,
feistel_bijection bijection)
: m(m), bijection(bijection) {}
__host__ __device__ key_flag_tuple operator()(std::uint64_t idx) {
auto gather_key = bijection(idx);
return key_flag_tuple{gather_key, (gather_key < m) ? 1ull : 0ull};
}
};
template <typename InputIterT, typename OutputIterT>
struct write_output_op {
std::uint64_t m;
InputIterT in;
OutputIterT out;
// flag contains inclusive scan of valid keys
// perform gather using valid keys
__thrust_exec_check_disable__
__host__ __device__ std::size_t operator()(key_flag_tuple x) {
if (x.key < m) {
// -1 because inclusive scan
out[x.flag - 1] = in[x.key];
}
return 0; // Discarded
}
};
template <typename ExecutionPolicy, typename RandomIterator, typename URBG>
__host__ __device__ void shuffle(
thrust::execution_policy<ExecutionPolicy>& exec, RandomIterator first,
RandomIterator last, URBG&& g) {
using InputType = typename thrust::iterator_value_t<RandomIterator>;
// copy input to temp buffer
thrust::detail::temporary_array<InputType, ExecutionPolicy> temp(exec, first,
last);
thrust::shuffle_copy(exec, temp.begin(), temp.end(), first, g);
}
template <typename ExecutionPolicy, typename RandomIterator,
typename OutputIterator, typename URBG>
__host__ __device__ void shuffle_copy(
thrust::execution_policy<ExecutionPolicy>& exec, RandomIterator first,
RandomIterator last, OutputIterator result, URBG&& g) {
// m is the length of the input
// we have an available bijection of length n via a feistel cipher
std::size_t m = last - first;
feistel_bijection bijection(m, g);
std::uint64_t n = bijection.nearest_power_of_two();
// perform stream compaction over length n bijection to get length m
// pseudorandom bijection over the original input
thrust::counting_iterator<std::uint64_t> indices(0);
thrust::transform_iterator<construct_key_flag_op, decltype(indices),
key_flag_tuple>
key_flag_it(indices, construct_key_flag_op(m, bijection));
write_output_op<RandomIterator, decltype(result)> write_functor{m, first,
result};
auto gather_output_it = thrust::make_transform_output_iterator(
thrust::discard_iterator<std::size_t>(), write_functor);
// the feistel_bijection outputs a stream of permuted indices in range [0,n)
// flag each value < m and compact it, so we have a set of permuted indices in
// range [0,m) each thread gathers an input element according to its
// pseudorandom permuted index
thrust::inclusive_scan(exec, key_flag_it, key_flag_it + n, gather_output_it,
key_flag_scan_op());
}
} // end namespace generic
} // end namespace detail
} // end namespace system
} // end namespace thrust
......@@ -20,27 +20,36 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from op_test import OpTest
import os
import random
class TestShuffleBatchOp(OpTest):
class TestShuffleBatchOpBase(OpTest):
def gen_random_array(self, shape, low=0, high=1):
rnd = (high - low) * np.random.random(shape) + low
return rnd.astype(self.dtype)
def get_shape(self):
return (10, 10, 5)
def _get_places(self):
# NOTE: shuffle_batch is not supported on Windows
if os.name == 'nt':
return [fluid.CPUPlace()]
return super(TestShuffleBatchOpBase, self)._get_places()
def setUp(self):
self.op_type = 'shuffle_batch'
self.dtype = np.float64
x = np.array(
[np.arange(100), np.arange(100)]).astype(self.dtype).reshape(
[2, 100])
out = np.array(
[np.arange(100), np.arange(100)]).astype(self.dtype).reshape(
[2, 100])
self.possible_res = [
np.array([np.arange(100), np.arange(100)]).astype(self.dtype),
]
self.inputs = {'X': x, 'Seed': np.array([1]).astype('int64')}
self.shape = self.get_shape()
x = self.gen_random_array(self.shape)
seed = np.random.random_integers(
low=10, high=100, size=(1, )).astype('int64')
self.inputs = {'X': x, 'Seed': seed}
self.outputs = {
'Out': out,
'ShuffleIdx': np.array([1, 0]).astype('int64'),
'SeedOut': np.array([1]).astype('int64')
'Out': np.array([]).astype(x.dtype),
'ShuffleIdx': np.array([]).astype('int64'),
'SeedOut': np.array([]).astype(seed.dtype),
}
self.attrs = {'startup_seed': 1}
......@@ -48,16 +57,33 @@ class TestShuffleBatchOp(OpTest):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
for elem in outs:
if elem.shape == self.outputs['Out'].shape:
out = elem
x = np.copy(self.inputs['X'])
y = None
for out in outs:
if out.shape == x.shape:
y = np.copy(out)
break
is_equal = [np.all(out == res) for res in self.possible_res]
self.assertIn(True, is_equal)
assert y is not None
sort_x = self.sort_array(x)
sort_y = self.sort_array(y)
self.assertTrue(np.array_equal(sort_x, sort_y))
def sort_array(self, array):
shape = array.shape
new_shape = [-1, shape[-1]]
arr_list = np.reshape(array, new_shape).tolist()
arr_list.sort(key=lambda x: x[0])
return np.reshape(np.array(arr_list), shape)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestShuffleBatchOp2(TestShuffleBatchOpBase):
def get_shape(self):
return (4, 30)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册