未验证 提交 737992eb 编写于 作者: F Feiyu Chan 提交者: GitHub

Add LRUCache for fft plans (#36646)

* WIP: add cache

* delete move constructor and operator= for CuFFTHandle and FFTConfig

* remove log from CuFFTHandle and FFTConfig

* add lrucache for fft rocm backend

* disable LRUCache when CUFFT_VERSION >= 10200

* disbale copy and move for hipFFTHandle; format code

* clean debug code
Co-authored-by: NXiaoxu Chen <chenxx_id@163.com>
上级 facf6020
...@@ -27,12 +27,12 @@ ...@@ -27,12 +27,12 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using ScalarType = framework::proto::VarType::Type; using ScalarType = framework::proto::VarType::Type;
const int64_t kMaxCUFFTNdim = 3; const int64_t kMaxFFTNdim = 3;
const int64_t kMaxDataNdim = kMaxCUFFTNdim + 1; const int64_t kMaxDataNdim = kMaxFFTNdim + 1;
// This struct is used to easily compute hashes of the // This struct is used to easily compute hashes of the
// parameters. It will be the **key** to the plan cache. // parameters. It will be the **key** to the plan cache.
struct PlanKey { struct FFTConfigKey {
// between 1 and kMaxCUFFTNdim, i.e., 1 <= signal_ndim <= 3 // between 1 and kMaxFFTNdim, i.e., 1 <= signal_ndim <= 3
int64_t signal_ndim_; int64_t signal_ndim_;
// These include additional batch dimension as well. // These include additional batch dimension as well.
int64_t sizes_[kMaxDataNdim]; int64_t sizes_[kMaxDataNdim];
...@@ -41,12 +41,12 @@ struct PlanKey { ...@@ -41,12 +41,12 @@ struct PlanKey {
FFTTransformType fft_type_; FFTTransformType fft_type_;
ScalarType value_type_; ScalarType value_type_;
PlanKey() = default; FFTConfigKey() = default;
PlanKey(const std::vector<int64_t>& in_shape, FFTConfigKey(const std::vector<int64_t>& in_shape,
const std::vector<int64_t>& out_shape, const std::vector<int64_t>& out_shape,
const std::vector<int64_t>& signal_size, FFTTransformType fft_type, const std::vector<int64_t>& signal_size,
ScalarType value_type) { FFTTransformType fft_type, ScalarType value_type) {
// Padding bits must be zeroed for hashing // Padding bits must be zeroed for hashing
memset(this, 0, sizeof(*this)); memset(this, 0, sizeof(*this));
signal_ndim_ = signal_size.size() - 1; signal_ndim_ = signal_size.size() - 1;
...@@ -69,6 +69,12 @@ class CuFFTHandle { ...@@ -69,6 +69,12 @@ class CuFFTHandle {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftCreate(&handle_)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftCreate(&handle_));
} }
CuFFTHandle(const CuFFTHandle& other) = delete;
CuFFTHandle& operator=(const CuFFTHandle& other) = delete;
CuFFTHandle(CuFFTHandle&& other) = delete;
CuFFTHandle& operator=(CuFFTHandle&& other) = delete;
::cufftHandle& get() { return handle_; } ::cufftHandle& get() { return handle_; }
const ::cufftHandle& get() const { return handle_; } const ::cufftHandle& get() const { return handle_; }
...@@ -81,20 +87,20 @@ using plan_size_type = long long int; // NOLINT ...@@ -81,20 +87,20 @@ using plan_size_type = long long int; // NOLINT
// This class contains all the information needed to execute a cuFFT plan: // This class contains all the information needed to execute a cuFFT plan:
// 1. the plan // 1. the plan
// 2. the workspace size needed // 2. the workspace size needed
class CuFFTConfig { class FFTConfig {
public: public:
// Only move semantics is enought for this class. Although we already use // Only move semantics is enought for this class. Although we already use
// unique_ptr for the plan, still remove copy constructor and assignment op so // unique_ptr for the plan, still remove copy constructor and assignment op so
// we don't accidentally copy and take perf hit. // we don't accidentally copy and take perf hit.
explicit CuFFTConfig(const PlanKey& plan_key) explicit FFTConfig(const FFTConfigKey& plan_key)
: CuFFTConfig( : FFTConfig(
std::vector<int64_t>(plan_key.sizes_, std::vector<int64_t>(plan_key.sizes_,
plan_key.sizes_ + plan_key.signal_ndim_ + 1), plan_key.sizes_ + plan_key.signal_ndim_ + 1),
plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {} plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {}
// sizes are full signal, including batch size and always two-sided // sizes are full signal, including batch size and always two-sided
CuFFTConfig(const std::vector<int64_t>& sizes, const int64_t signal_ndim, FFTConfig(const std::vector<int64_t>& sizes, const int64_t signal_ndim,
FFTTransformType fft_type, ScalarType dtype) FFTTransformType fft_type, ScalarType dtype)
: fft_type_(fft_type), value_type_(dtype) { : fft_type_(fft_type), value_type_(dtype) {
// signal sizes (excluding batch dim) // signal sizes (excluding batch dim)
std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end()); std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end());
...@@ -144,6 +150,12 @@ class CuFFTConfig { ...@@ -144,6 +150,12 @@ class CuFFTConfig {
ws_size = ws_size_t; ws_size = ws_size_t;
} }
FFTConfig(const FFTConfig& other) = delete;
FFTConfig& operator=(const FFTConfig& other) = delete;
FFTConfig(FFTConfig&& other) = delete;
FFTConfig& operator=(FFTConfig&& other) = delete;
const cufftHandle& plan() const { return plan_ptr.get(); } const cufftHandle& plan() const { return plan_ptr.get(); }
FFTTransformType transform_type() const { return fft_type_; } FFTTransformType transform_type() const { return fft_type_; }
...@@ -167,6 +179,12 @@ class HIPFFTHandle { ...@@ -167,6 +179,12 @@ class HIPFFTHandle {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftCreate(&handle_)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftCreate(&handle_));
} }
HIPFFTHandle(const HIPFFTHandle& other) = delete;
HIPFFTHandle& operator=(const HIPFFTHandle& other) = delete;
HIPFFTHandle(HIPFFTHandle&& other) = delete;
HIPFFTHandle& operator=(HIPFFTHandle&& other) = delete;
::hipfftHandle& get() { return handle_; } ::hipfftHandle& get() { return handle_; }
const ::hipfftHandle& get() const { return handle_; } const ::hipfftHandle& get() const { return handle_; }
...@@ -178,20 +196,20 @@ using plan_size_type = int; ...@@ -178,20 +196,20 @@ using plan_size_type = int;
// This class contains all the information needed to execute a cuFFT plan: // This class contains all the information needed to execute a cuFFT plan:
// 1. the plan // 1. the plan
// 2. the workspace size needed // 2. the workspace size needed
class HIPFFTConfig { class FFTConfig {
public: public:
// Only move semantics is enought for this class. Although we already use // Only move semantics is enought for this class. Although we already use
// unique_ptr for the plan, still remove copy constructor and assignment op so // unique_ptr for the plan, still remove copy constructor and assignment op so
// we don't accidentally copy and take perf hit. // we don't accidentally copy and take perf hit.
explicit HIPFFTConfig(const PlanKey& plan_key) explicit FFTConfig(const FFTConfigKey& plan_key)
: HIPFFTConfig( : FFTConfig(
std::vector<int64_t>(plan_key.sizes_, std::vector<int64_t>(plan_key.sizes_,
plan_key.sizes_ + plan_key.signal_ndim_ + 1), plan_key.sizes_ + plan_key.signal_ndim_ + 1),
plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {} plan_key.signal_ndim_, plan_key.fft_type_, plan_key.value_type_) {}
// sizes are full signal, including batch size and always two-sided // sizes are full signal, including batch size and always two-sided
HIPFFTConfig(const std::vector<int64_t>& sizes, const int64_t signal_ndim, FFTConfig(const std::vector<int64_t>& sizes, const int64_t signal_ndim,
FFTTransformType fft_type, ScalarType dtype) FFTTransformType fft_type, ScalarType dtype)
: fft_type_(fft_type), value_type_(dtype) { : fft_type_(fft_type), value_type_(dtype) {
// signal sizes (excluding batch dim) // signal sizes (excluding batch dim)
std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end()); std::vector<plan_size_type> signal_sizes(sizes.begin() + 1, sizes.end());
...@@ -257,5 +275,192 @@ class HIPFFTConfig { ...@@ -257,5 +275,192 @@ class HIPFFTConfig {
ScalarType value_type_; ScalarType value_type_;
}; };
#endif #endif
// Hashing machinery for Key
// Fowler–Noll–Vo hash function
// see
// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
template <typename Key>
struct KeyHash {
// Key must be a POD because we read out its memory
// contenst as char* when hashing
static_assert(std::is_pod<Key>::value, "Key must be plain old data type");
size_t operator()(const Key& params) const {
auto ptr = reinterpret_cast<const uint8_t*>(&params);
uint32_t value = 0x811C9DC5;
for (int i = 0; i < static_cast<int>(sizeof(Key)); ++i) {
value ^= ptr[i];
value *= 0x01000193;
}
return static_cast<size_t>(value);
}
};
template <typename Key>
struct KeyEqual {
// Key must be a POD because we read out its memory
// contenst as char* when comparing
static_assert(std::is_pod<Key>::value, "Key must be plain old data type");
bool operator()(const Key& a, const Key& b) const {
auto ptr1 = reinterpret_cast<const uint8_t*>(&a);
auto ptr2 = reinterpret_cast<const uint8_t*>(&b);
return memcmp(ptr1, ptr2, sizeof(Key)) == 0;
}
};
#if CUDA_VERSION < 10000
// Note that the max plan number for CUDA version < 10 has to be 1023
// due to a bug that fails on the 1024th plan
constexpr size_t CUFFT_MAX_PLAN_NUM = 1023;
constexpr size_t CUFFT_DEFAULT_CACHE_SIZE = CUFFT_MAX_PLAN_NUM;
#else
constexpr size_t CUFFT_MAX_PLAN_NUM = std::numeric_limits<size_t>::max();
// The default max cache size chosen for CUDA version > 10 is arbitrary.
// This number puts a limit on how big of a plan cache should we maintain by
// default. Users can always configure it via cufft_set_plan_cache_max_size.
constexpr size_t CUFFT_DEFAULT_CACHE_SIZE = 4096;
#endif
static_assert(CUFFT_MAX_PLAN_NUM >= 0 &&
CUFFT_MAX_PLAN_NUM <= std::numeric_limits<size_t>::max(),
"CUFFT_MAX_PLAN_NUM not in size_t range");
static_assert(CUFFT_DEFAULT_CACHE_SIZE >= 0 &&
CUFFT_DEFAULT_CACHE_SIZE <= CUFFT_MAX_PLAN_NUM,
"CUFFT_DEFAULT_CACHE_SIZE not in [0, CUFFT_MAX_PLAN_NUM] range");
// This cache assumes that the mapping from key to value never changes.
// This is **NOT** thread-safe. Please use a mutex when using it **AND** the
// value returned from try_emplace_value.
// The contract of using this cache is that try_emplace_value should only be
// used when the max_size is positive.
class FFTConfigCache {
public:
using kv_t = typename std::pair<FFTConfigKey, FFTConfig>;
using map_t = typename std::unordered_map<
std::reference_wrapper<FFTConfigKey>, typename std::list<kv_t>::iterator,
KeyHash<FFTConfigKey>, KeyEqual<FFTConfigKey>>;
using map_kkv_iter_t = typename map_t::iterator;
FFTConfigCache() : FFTConfigCache(CUFFT_DEFAULT_CACHE_SIZE) {}
explicit FFTConfigCache(int64_t max_size) { _set_max_size(max_size); }
FFTConfigCache(const FFTConfigCache& other) = delete;
FFTConfigCache& operator=(const FFTConfigCache& other) = delete;
FFTConfigCache(FFTConfigCache&& other) noexcept
: _usage_list(std::move(other._usage_list)),
_cache_map(std::move(other._cache_map)),
_max_size(other._max_size) {}
FFTConfigCache& operator=(FFTConfigCache&& other) noexcept {
_usage_list = std::move(other._usage_list);
_cache_map = std::move(other._cache_map);
_max_size = other._max_size;
return *this;
}
// If key is in this cache, return the cached config. Otherwise, emplace the
// config in this cache and return it.
FFTConfig& lookup(FFTConfigKey params) {
PADDLE_ENFORCE_GT(_max_size, 0,
platform::errors::InvalidArgument(
"The max size of FFTConfigCache must be great than 0,"
"But received is [%d]",
_max_size));
map_kkv_iter_t map_it = _cache_map.find(params);
// Hit, put to list front
if (map_it != _cache_map.end()) {
_usage_list.splice(_usage_list.begin(), _usage_list, map_it->second);
return map_it->second->second;
}
// Miss
// remove if needed
if (_usage_list.size() >= _max_size) {
auto last = _usage_list.end();
last--;
_cache_map.erase(last->first);
_usage_list.pop_back();
}
// construct new plan at list front, then insert into _cache_map
_usage_list.emplace_front(std::piecewise_construct,
std::forward_as_tuple(params),
std::forward_as_tuple(params));
auto kv_it = _usage_list.begin();
_cache_map.emplace(std::piecewise_construct,
std::forward_as_tuple(kv_it->first),
std::forward_as_tuple(kv_it));
return kv_it->second;
}
void clear() {
_cache_map.clear();
_usage_list.clear();
}
void resize(int64_t new_size) {
_set_max_size(new_size);
auto cur_size = _usage_list.size();
if (cur_size > _max_size) {
auto delete_it = _usage_list.end();
for (size_t i = 0; i < cur_size - _max_size; i++) {
delete_it--;
_cache_map.erase(delete_it->first);
}
_usage_list.erase(delete_it, _usage_list.end());
}
}
size_t size() const { return _cache_map.size(); }
size_t max_size() const noexcept { return _max_size; }
std::mutex mutex;
private:
// Only sets size and does value check. Does not resize the data structures.
void _set_max_size(int64_t new_size) {
// We check that 0 <= new_size <= CUFFT_MAX_PLAN_NUM here. Since
// CUFFT_MAX_PLAN_NUM is of type size_t, we need to do non-negativity check
// first.
PADDLE_ENFORCE_GE(
new_size, 0,
platform::errors::InvalidArgument(
"cuFFT plan cache size must be non-negative, But received is [%d]",
new_size));
PADDLE_ENFORCE_LE(new_size, CUFFT_MAX_PLAN_NUM,
platform::errors::InvalidArgument(
"cuFFT plan cache size can not be larger than [%d], "
"But received is [%d]",
CUFFT_MAX_PLAN_NUM, new_size));
_max_size = static_cast<size_t>(new_size);
}
std::list<kv_t> _usage_list;
map_t _cache_map;
size_t _max_size;
};
static std::vector<std::unique_ptr<FFTConfigCache>> plan_caches;
static std::mutex plan_caches_mutex;
static inline FFTConfigCache& get_fft_plan_cache(int64_t device_index) {
std::lock_guard<std::mutex> guard(plan_caches_mutex);
if (device_index >= plan_caches.size()) {
plan_caches.resize(device_index + 1);
}
if (!plan_caches[device_index]) {
plan_caches[device_index] = std::make_unique<FFTConfigCache>();
}
return *plan_caches[device_index];
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -68,9 +68,9 @@ void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out, ...@@ -68,9 +68,9 @@ void exec_normalization(const DeviceContext& ctx, const Tensor* in, Tensor* out,
} }
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
CuFFTConfig create_cufft_config(const framework::Tensor& input, FFTConfigKey create_fft_configkey(const framework::Tensor& input,
const framework::Tensor& output, const framework::Tensor& output,
int signal_ndim) { int signal_ndim) {
// Create the transform plan (either from cache or locally) // Create the transform plan (either from cache or locally)
const auto value_type = framework::IsComplexType(input.type()) const auto value_type = framework::IsComplexType(input.type())
? framework::ToRealType(input.type()) ? framework::ToRealType(input.type())
...@@ -85,15 +85,14 @@ CuFFTConfig create_cufft_config(const framework::Tensor& input, ...@@ -85,15 +85,14 @@ CuFFTConfig create_cufft_config(const framework::Tensor& input,
auto out_size = output.dims()[i]; auto out_size = output.dims()[i];
signal_size[i] = std::max(in_size, out_size); signal_size[i] = std::max(in_size, out_size);
} }
PlanKey key(framework::vectorize(input.dims()), FFTConfigKey key(framework::vectorize(input.dims()),
framework::vectorize(output.dims()), signal_size, fft_type, framework::vectorize(output.dims()), signal_size, fft_type,
value_type); value_type);
return key;
return CuFFTConfig(key);
} }
// Execute a pre-planned transform // Execute a pre-planned transform
static void exec_cufft_plan_raw(const CuFFTConfig& config, void* in_data, static void exec_cufft_plan_raw(const FFTConfig& config, void* in_data,
void* out_data, bool forward) { void* out_data, bool forward) {
auto& plan = config.plan(); auto& plan = config.plan();
...@@ -102,7 +101,7 @@ static void exec_cufft_plan_raw(const CuFFTConfig& config, void* in_data, ...@@ -102,7 +101,7 @@ static void exec_cufft_plan_raw(const CuFFTConfig& config, void* in_data,
} }
template <typename DeviceContext, typename Ti, typename To> template <typename DeviceContext, typename Ti, typename To>
void exec_cufft_plan(const DeviceContext& ctx, const CuFFTConfig& config, void exec_cufft_plan(const DeviceContext& ctx, const FFTConfig& config,
framework::Tensor* input, framework::Tensor* output, framework::Tensor* input, framework::Tensor* output,
bool forward) { bool forward) {
// execute transform plan // execute transform plan
...@@ -136,7 +135,7 @@ void exec_cufft_plan(const DeviceContext& ctx, const CuFFTConfig& config, ...@@ -136,7 +135,7 @@ void exec_cufft_plan(const DeviceContext& ctx, const CuFFTConfig& config,
#elif defined(PADDLE_WITH_HIP) #elif defined(PADDLE_WITH_HIP)
HIPFFTConfig create_hipfft_config(const framework::Tensor& input, FFTConfigKey create_fft_configkey(const framework::Tensor& input,
const framework::Tensor& output, const framework::Tensor& output,
int signal_ndim) { int signal_ndim) {
// Create the transform plan (either from cache or locally) // Create the transform plan (either from cache or locally)
...@@ -153,15 +152,14 @@ HIPFFTConfig create_hipfft_config(const framework::Tensor& input, ...@@ -153,15 +152,14 @@ HIPFFTConfig create_hipfft_config(const framework::Tensor& input,
auto out_size = output.dims()[i]; auto out_size = output.dims()[i];
signal_size[i] = std::max(in_size, out_size); signal_size[i] = std::max(in_size, out_size);
} }
PlanKey key(framework::vectorize(input.dims()), FFTConfigKey key(framework::vectorize(input.dims()),
framework::vectorize(output.dims()), signal_size, fft_type, framework::vectorize(output.dims()), signal_size, fft_type,
value_type); value_type);
return key;
return HIPFFTConfig(key);
} }
// Execute a pre-planned transform // Execute a pre-planned transform
static void exec_hipfft_plan_raw(const HIPFFTConfig& config, void* in_data, static void exec_hipfft_plan_raw(const FFTConfig& config, void* in_data,
void* out_data, bool forward) { void* out_data, bool forward) {
auto& plan = config.plan(); auto& plan = config.plan();
...@@ -216,7 +214,7 @@ static void exec_hipfft_plan_raw(const HIPFFTConfig& config, void* in_data, ...@@ -216,7 +214,7 @@ static void exec_hipfft_plan_raw(const HIPFFTConfig& config, void* in_data,
} }
template <typename DeviceContext, typename Ti, typename To> template <typename DeviceContext, typename Ti, typename To>
void exec_hipfft_plan(const DeviceContext& ctx, const HIPFFTConfig& config, void exec_hipfft_plan(const DeviceContext& ctx, const FFTConfig& config,
framework::Tensor* input, framework::Tensor* output, framework::Tensor* input, framework::Tensor* output,
bool forward) { bool forward) {
auto fft_type = config.transform_type(); auto fft_type = config.transform_type();
...@@ -308,34 +306,58 @@ void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out, ...@@ -308,34 +306,58 @@ void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out,
collapsed_output.Resize(framework::make_ddim(collapsed_output_shape)); collapsed_output.Resize(framework::make_ddim(collapsed_output_shape));
collapsed_output.mutable_data<To>(tensor_place); collapsed_output.mutable_data<To>(tensor_place);
FFTConfig* config = nullptr;
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA)
std::unique_ptr<FFTConfig> config_ = nullptr;
// create plan // create plan
CuFFTConfig config = FFTConfigKey key =
create_cufft_config(collapsed_input, collapsed_output, signal_ndim); create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
if (CUFFT_VERSION < 10200) {
const int64_t device_id = static_cast<int64_t>(
reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
->GetDeviceId());
FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
guard.lock();
config = &(plan_cache.lookup(key));
} else {
config_ = std::make_unique<FFTConfig>(key);
config = config_.get();
}
// prepare cufft for execution // prepare cufft for execution
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cufftSetStream(config.plan(), ctx.stream())); platform::dynload::cufftSetStream(config->plan(), ctx.stream()));
framework::Tensor workspace_tensor; framework::Tensor workspace_tensor;
workspace_tensor.mutable_data<To>(tensor_place, config.workspace_size()); workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftSetWorkArea( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cufftSetWorkArea(
config.plan(), workspace_tensor.data<To>())); config->plan(), workspace_tensor.data<To>()));
// execute transform plan // execute transform plan
exec_cufft_plan<DeviceContext, Ti, To>(ctx, config, &collapsed_input, exec_cufft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
&collapsed_output, forward); &collapsed_output, forward);
#elif defined(PADDLE_WITH_HIP) #elif defined(PADDLE_WITH_HIP)
// create plan // create plan
HIPFFTConfig config = FFTConfigKey key =
create_hipfft_config(collapsed_input, collapsed_output, signal_ndim); create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
const int64_t device_id = static_cast<int64_t>(
reinterpret_cast<const platform::CUDAPlace*>(&collapsed_input.place())
->GetDeviceId());
FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
guard.lock();
config = &(plan_cache.lookup(key));
// prepare cufft for execution // prepare cufft for execution
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::hipfftSetStream(config.plan(), ctx.stream())); platform::dynload::hipfftSetStream(config->plan(), ctx.stream()));
framework::Tensor workspace_tensor; framework::Tensor workspace_tensor;
workspace_tensor.mutable_data<To>(tensor_place, config.workspace_size()); workspace_tensor.mutable_data<To>(tensor_place, config->workspace_size());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftSetWorkArea( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::hipfftSetWorkArea(
config.plan(), workspace_tensor.data<To>())); config->plan(), workspace_tensor.data<To>()));
// execute transform plan // execute transform plan
exec_hipfft_plan<DeviceContext, Ti, To>(ctx, config, &collapsed_input, exec_hipfft_plan<DeviceContext, Ti, To>(ctx, *config, &collapsed_input,
&collapsed_output, forward); &collapsed_output, forward);
#endif #endif
...@@ -358,10 +380,10 @@ void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out, ...@@ -358,10 +380,10 @@ void exec_fft(const DeviceContext& ctx, const Tensor* X, Tensor* out,
// Use the optimized path to perform single R2C or C2R if transformation dim is // Use the optimized path to perform single R2C or C2R if transformation dim is
// supported by cuFFT // supported by cuFFT
bool use_optimized_cufft_path(const std::vector<int64_t>& axes) { bool use_optimized_fft_path(const std::vector<int64_t>& axes) {
// For performance reason, when axes starts with (0, 1), do not use the // For performance reason, when axes starts with (0, 1), do not use the
// optimized path. // optimized path.
if (axes.size() > kMaxCUFFTNdim || if (axes.size() > kMaxFFTNdim ||
(axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) { (axes.size() >= 2 && axes[0] == 0 && axes[1] == 1)) {
return false; return false;
} else { } else {
...@@ -391,7 +413,7 @@ struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> { ...@@ -391,7 +413,7 @@ struct FFTC2CFunctor<platform::CUDADeviceContext, Ti, To> {
while (true) { while (true) {
max_dims = max_dims =
std::min(static_cast<size_t>(kMaxCUFFTNdim), working_axes.size()); std::min(static_cast<size_t>(kMaxFFTNdim), working_axes.size());
first_dims.assign(working_axes.end() - max_dims, working_axes.end()); first_dims.assign(working_axes.end() - max_dims, working_axes.end());
exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, p_working_tensor, exec_fft<platform::CUDADeviceContext, Ti, To>(ctx, p_working_tensor,
...@@ -418,7 +440,7 @@ struct FFTC2RFunctor<platform::CUDADeviceContext, Ti, To> { ...@@ -418,7 +440,7 @@ struct FFTC2RFunctor<platform::CUDADeviceContext, Ti, To> {
std::vector<int64_t> in_dims = framework::vectorize(X->dims()); std::vector<int64_t> in_dims = framework::vectorize(X->dims());
std::vector<int64_t> out_dims = framework::vectorize(out->dims()); std::vector<int64_t> out_dims = framework::vectorize(out->dims());
if (use_optimized_cufft_path(axes)) { if (use_optimized_fft_path(axes)) {
framework::Tensor x_copy(X->type()); framework::Tensor x_copy(X->type());
x_copy.mutable_data<Ti>(X->dims(), ctx.GetPlace()); x_copy.mutable_data<Ti>(X->dims(), ctx.GetPlace());
framework::TensorCopy(*X, ctx.GetPlace(), &x_copy); framework::TensorCopy(*X, ctx.GetPlace(), &x_copy);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册