未验证 提交 a4d9851b 编写于 作者: H huangjiyi 提交者: GitHub

[PHI decoupling] move cuda_graph from fluid to phi (#48686)

* move cuda_graph from fluid to phi

* move device_memory_aligment from fluid to phi

* Revert "move device_memory_aligment from fluid to phi"

This reverts commit b92fcd39a0a50fdac13278f49be0237a85f3a13f.

* update xpu cmake
上级 91ff2071
......@@ -32,7 +32,7 @@ if(WITH_GPU OR WITH_ROCM)
endif()
if(WITH_GPU)
list(APPEND ALLOCATOR_DEPS cuda_graph)
list(APPEND ALLOCATOR_DEPS phi_backends)
endif()
if(CUDA_VERSION VERSION_GREATER_EQUAL 10.2)
......
......@@ -85,7 +85,7 @@ if(WITH_GPU)
nv_library(
cuda_graph_with_memory_pool
SRCS cuda_graph_with_memory_pool.cc
DEPS device_context allocator cuda_graph)
DEPS device_context allocator phi_backends)
else()
cc_library(
cuda_graph_with_memory_pool
......
nv_library(
cuda_graph
SRCS cuda_graph.cc
DEPS enforce)
nv_library(
cuda_profiler
SRCS cuda_profiler.cc
......
......@@ -14,45 +14,23 @@
#pragma once
#include <atomic>
#include <functional>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>
#include "cuda.h" // NOLINT
#include "cuda_runtime.h" // NOLINT
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/utils/optional.h"
#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
namespace paddle {
namespace platform {
using CUDAKernelParams = phi::backends::gpu::CUDAKernelParams;
#if CUDA_VERSION < 10010
using cudaStreamCaptureMode = phi::backends::gpu::cudaStreamCaptureMode;
#endif
using CUDAGraph = phi::backends::gpu::CUDAGraph;
using CUDAGraphCaptureModeGuard = phi::backends::gpu::CUDAGraphCaptureModeGuard;
template <typename T>
static bool IsBitwiseEqual(const T &x, const T &y) {
return std::memcmp(&x, &y, sizeof(T)) == 0;
}
class CUDAKernelParams {
public:
explicit CUDAKernelParams(const cudaKernelNodeParams *params)
: params_(params) {}
const void *func() const { return params_->func; }
template <typename T>
T &As(size_t idx) const {
return *reinterpret_cast<T *>(params_->kernelParams[idx]);
}
private:
const cudaKernelNodeParams *params_;
};
template <typename F, F f>
struct IsSameKernelHelper;
......@@ -96,191 +74,5 @@ struct IsSameKernelHelper<Return (*)(FuncArgs...), kernel_fn> {
}
};
#if CUDA_VERSION >= 10010
static void ThrowErrorIfNotSupportCUDAGraph() {}
#else
enum cudaStreamCaptureMode {
cudaStreamCaptureModeGlobal = 0,
cudaStreamCaptureModeThreadLocal = 1,
cudaStreamCaptureModeRelaxed = 2
};
static void ThrowErrorIfNotSupportCUDAGraph() {
PADDLE_THROW(platform::errors::Unimplemented(
"CUDA Graph is only supported when CUDA version >= 10.1"));
}
#endif
// NOTE: Currently, we do not support to capture CUDA graph in parallel
// NOTE: Do not use this class directly because it should be used with
// the memory pool.
class CUDAGraph {
DISABLE_COPY_AND_ASSIGN(CUDAGraph);
// Since the constructor would throw error is CUDA_VERSION < 10010.
// The non-static method of CUDAGraph need not check CUDA_VERSION
// again.
CUDAGraph() {
ThrowErrorIfNotSupportCUDAGraph();
id_ = UniqueID();
}
public:
static constexpr int64_t kDefaultPoolID = 0;
static constexpr int64_t kInvalidPoolID = -1;
~CUDAGraph() { Reset(); }
CUDAGraphID ID() const { return id_; }
static int64_t SetMemoryPoolID(int64_t pool_id) {
auto &pool_id_ = capturing_graph_->pool_id_;
PADDLE_ENFORCE_EQ(
pool_id_,
kInvalidPoolID,
phi::errors::InvalidArgument("Cannot reset memory pool id twice, the "
"former memory pool id is %d.",
pool_id_));
if (pool_id <= kInvalidPoolID) {
pool_id_ = UniqueMemoryPoolID();
} else {
PADDLE_ENFORCE_GE(
pool_id,
kDefaultPoolID,
phi::errors::InvalidArgument("Invalid memory pool id %d.", pool_id));
pool_id_ = pool_id;
}
return pool_id_;
}
int64_t PoolID() const { return pool_id_; }
static int64_t CapturingPoolID() { return capturing_graph_->pool_id_; }
void Replay();
void Reset();
void AddResetCallback(std::function<void()> callback) {
std::lock_guard<std::mutex> guard(mtx_);
callbacks_.push_back(std::move(callback));
}
void PrintToDotFiles(const std::string &dirname, unsigned int flags);
static void BeginCapture(platform::CUDAPlace place,
cudaStream_t stream,
cudaStreamCaptureMode mode);
static std::unique_ptr<CUDAGraph> EndCapture();
static void BeginSegmentCapture();
static void EndSegmentCapture();
static void AddResetCallbackDuringCapturing(std::function<void()> callback) {
capturing_graph_->AddResetCallback(std::move(callback));
}
// No need to add CUDA_VERSION macro because capturing_graph_ would
// always be nullptr (constructor throws error)
static bool IsCapturing() { return capturing_graph_ != nullptr; }
static CUDAGraphID CapturingID() { return capturing_graph_->id_; }
static platform::CUDAPlace CapturingPlace() {
return capturing_graph_->place_;
}
// This API can be used to debug which GPU operation is not
// supported during capturing CUDA Graph.
static bool IsValidCapturing();
static bool IsThreadLocalCapturing() {
#if CUDA_VERSION >= 10010
return IsCapturing() &&
capturing_graph_->capture_mode_ == cudaStreamCaptureModeThreadLocal;
#else
return false;
#endif
}
static bool IsThisThreadCapturing() {
if (UNLIKELY(IsCapturing())) {
return IsThreadLocalCapturing()
? capturing_thread_id_.get() == std::this_thread::get_id()
: true;
} else {
return false;
}
}
using SetSeedFunc = std::function<bool(CUDAKernelParams *, bool)>;
static void RecordRandomKernelInfo(SetSeedFunc set_seed_func) {
std::lock_guard<std::mutex> guard(capturing_graph_->func_mtx_);
capturing_graph_->set_seed_funcs_.emplace_back(std::move(set_seed_func));
}
static int64_t UniqueMemoryPoolID();
private:
static CUDAGraphID UniqueID();
private:
#if CUDA_VERSION >= 10010
std::vector<cudaGraph_t> graphs_;
std::vector<cudaGraphExec_t> exec_graphs_;
cudaStreamCaptureMode capture_mode_;
#endif
cudaStream_t stream_{nullptr};
platform::CUDAPlace place_;
CUDAGraphID id_;
int64_t pool_id_{kInvalidPoolID};
std::vector<std::function<void()>> callbacks_;
bool is_reset_{false};
std::mutex mtx_;
std::vector<SetSeedFunc> set_seed_funcs_;
std::vector<std::vector<std::function<void(cudaGraphExec_t)>>> pre_hooks_;
std::mutex func_mtx_;
bool is_first_run_{true};
static paddle::optional<std::thread::id> capturing_thread_id_;
static std::unique_ptr<CUDAGraph> capturing_graph_;
};
#if CUDA_VERSION >= 10010
class CUDAGraphCaptureModeGuard {
DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard);
public:
explicit CUDAGraphCaptureModeGuard(
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {
if (UNLIKELY(CUDAGraph::IsCapturing())) {
PADDLE_ENFORCE_GPU_SUCCESS(cudaThreadExchangeStreamCaptureMode(&mode));
// After cudaThreadExchangeStreamCaptureMode is called,
// the variable "mode" would be set to the old capturing mode.
old_mode_ = mode;
}
}
~CUDAGraphCaptureModeGuard() PADDLE_MAY_THROW {
if (UNLIKELY(CUDAGraph::IsCapturing())) {
PADDLE_ENFORCE_GPU_SUCCESS(
cudaThreadExchangeStreamCaptureMode(&old_mode_));
}
}
private:
cudaStreamCaptureMode old_mode_;
};
#else
class CUDAGraphCaptureModeGuard {
DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard);
public:
explicit CUDAGraphCaptureModeGuard(
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {}
};
#endif
} // namespace platform
} // namespace paddle
......@@ -30,7 +30,7 @@ cc_library(
xpulib
device_context
op_kernel_type
phi_xpu_op_list)
phi_backends)
cc_library(
xpu_resource_pool
SRCS xpu_resource_pool.cc
......
......@@ -7,7 +7,7 @@ if(WITH_GPU OR WITH_ROCM)
list(APPEND BACKENDS_SRCS gpu/gpu_context.cc gpu/gpu_info.cc
gpu/gpu_resources.cc)
if(WITH_GPU)
list(APPEND BACKENDS_SRCS gpu/cuda/cuda_info.cc)
list(APPEND BACKENDS_SRCS gpu/cuda/cuda_info.cc gpu/cuda/cuda_graph.cc)
endif()
if(WITH_ROCM)
list(APPEND BACKENDS_SRCS gpu/rocm/rocm_info.cc)
......@@ -16,8 +16,9 @@ if(WITH_GPU OR WITH_ROCM)
endif()
if(WITH_XPU)
add_subdirectory(xpu)
list(APPEND BACKENDS_SRCS xpu/xpu_context.cc xpu/xpu_info.cc)
list(APPEND BACKENDS_SRCS xpu/xpu_op_list.cc xpu/xpu1_op_list.cc
xpu/xpu2_op_list.cc)
endif()
if(WITH_MKLDNN)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,14 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
#include <queue>
#include <unordered_map>
#include <unordered_set>
namespace paddle {
namespace platform {
namespace phi {
namespace backends {
namespace gpu {
std::unique_ptr<CUDAGraph> CUDAGraph::capturing_graph_{nullptr};
paddle::optional<std::thread::id> CUDAGraph::capturing_thread_id_{paddle::none};
......@@ -113,7 +114,7 @@ void CUDAGraph::Replay() {
#if CUDA_VERSION >= 10010
PADDLE_ENFORCE_EQ(is_reset_,
false,
errors::PermissionDenied(
phi::errors::PermissionDenied(
"Cannot replay the CUDA Graph after reset is called."));
size_t n = exec_graphs_.size();
for (size_t i = 0; i < n; ++i) {
......@@ -131,43 +132,43 @@ void CUDAGraph::Replay() {
void CUDAGraph::BeginSegmentCapture() {
ThrowErrorIfNotSupportCUDAGraph();
#if CUDA_VERSION >= 10010
PADDLE_ENFORCE_EQ(
IsCapturing(),
true,
errors::PermissionDenied("BeginSegmentCapture should be called when CUDA "
"Graph is capturing."));
PADDLE_ENFORCE_EQ(IsCapturing(),
true,
phi::errors::PermissionDenied(
"BeginSegmentCapture should be called when CUDA "
"Graph is capturing."));
if (IsThreadLocalCapturing()) {
PADDLE_ENFORCE_EQ(IsThisThreadCapturing(),
true,
platform::errors::PermissionDenied(
phi::errors::PermissionDenied(
"When capturing CUDA Graph in the thread local mode, "
"you cannot begin segmented capturing in the thread "
"which is not the one that starts the capturing."));
}
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamBeginCapture(
capturing_graph_->stream_, capturing_graph_->capture_mode_));
PADDLE_ENFORCE_EQ(IsValidCapturing(),
true,
platform::errors::PermissionDenied(
"CUDA Graph should not be invalidated."));
PADDLE_ENFORCE_EQ(
IsValidCapturing(),
true,
phi::errors::PermissionDenied("CUDA Graph should not be invalidated."));
VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_
<< ", segment id " << capturing_graph_->graphs_.size()
<< ", memory pool id " << capturing_graph_->pool_id_;
#endif
}
void CUDAGraph::BeginCapture(platform::CUDAPlace place,
void CUDAGraph::BeginCapture(phi::GPUPlace place,
cudaStream_t stream,
cudaStreamCaptureMode mode) {
ThrowErrorIfNotSupportCUDAGraph();
#if CUDA_VERSION >= 10010
PADDLE_ENFORCE_EQ(
IsCapturing(),
false,
errors::PermissionDenied("CUDA Graph can only captured one by one."));
PADDLE_ENFORCE_EQ(IsCapturing(),
false,
phi::errors::PermissionDenied(
"CUDA Graph can only captured one by one."));
PADDLE_ENFORCE_NOT_NULL(
stream,
errors::PermissionDenied(
phi::errors::PermissionDenied(
"CUDA Graph cannot be captured in default CUDA stream 0."));
capturing_graph_.reset(new CUDAGraph());
capturing_graph_->place_ = place;
......@@ -185,9 +186,10 @@ void CUDAGraph::BeginCapture(platform::CUDAPlace place,
void CUDAGraph::EndSegmentCapture() {
ThrowErrorIfNotSupportCUDAGraph();
#if CUDA_VERSION >= 10010
PADDLE_ENFORCE_EQ(IsCapturing(),
true,
errors::PermissionDenied("No CUDA Graph is capturing."));
PADDLE_ENFORCE_EQ(
IsCapturing(),
true,
phi::errors::PermissionDenied("No CUDA Graph is capturing."));
cudaGraph_t graph;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaStreamEndCapture(capturing_graph_->stream_, &graph));
......@@ -299,11 +301,12 @@ void CUDAGraph::PrintToDotFiles(const std::string &dirname,
cudaGraphDebugDotPrint(graphs_[i], filename.c_str(), flags));
}
#else
PADDLE_THROW(platform::errors::Unimplemented(
PADDLE_THROW(phi::errors::Unimplemented(
"The print_to_dot_files() method is only supported when CUDA version >= "
"11.3."));
#endif
}
} // namespace platform
} // namespace paddle
} // namespace gpu
} // namespace backends
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <functional>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>
#include "cuda.h" // NOLINT
#include "cuda_runtime.h" // NOLINT
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/macros.h"
#include "paddle/utils/optional.h"
namespace phi {
namespace backends {
namespace gpu {
class CUDAKernelParams {
public:
explicit CUDAKernelParams(const cudaKernelNodeParams *params)
: params_(params) {}
const void *func() const { return params_->func; }
template <typename T>
T &As(size_t idx) const {
return *reinterpret_cast<T *>(params_->kernelParams[idx]);
}
private:
const cudaKernelNodeParams *params_;
};
#if CUDA_VERSION >= 10010
static void ThrowErrorIfNotSupportCUDAGraph() {}
#else
enum cudaStreamCaptureMode {
cudaStreamCaptureModeGlobal = 0,
cudaStreamCaptureModeThreadLocal = 1,
cudaStreamCaptureModeRelaxed = 2
};
static void ThrowErrorIfNotSupportCUDAGraph() {
PADDLE_THROW(phi::errors::Unimplemented(
"CUDA Graph is only supported when CUDA version >= 10.1"));
}
#endif
using CUDAGraphID = unsigned long long; // NOLINT
// NOTE: Currently, we do not support to capture CUDA graph in parallel
// NOTE: Do not use this class directly because it should be used with
// the memory pool.
class CUDAGraph {
DISABLE_COPY_AND_ASSIGN(CUDAGraph);
// Since the constructor would throw error is CUDA_VERSION < 10010.
// The non-static method of CUDAGraph need not check CUDA_VERSION
// again.
CUDAGraph() {
ThrowErrorIfNotSupportCUDAGraph();
id_ = UniqueID();
}
public:
static constexpr int64_t kDefaultPoolID = 0;
static constexpr int64_t kInvalidPoolID = -1;
~CUDAGraph() { Reset(); }
CUDAGraphID ID() const { return id_; }
static int64_t SetMemoryPoolID(int64_t pool_id) {
auto &pool_id_ = capturing_graph_->pool_id_;
PADDLE_ENFORCE_EQ(
pool_id_,
kInvalidPoolID,
phi::errors::InvalidArgument("Cannot reset memory pool id twice, the "
"former memory pool id is %d.",
pool_id_));
if (pool_id <= kInvalidPoolID) {
pool_id_ = UniqueMemoryPoolID();
} else {
PADDLE_ENFORCE_GE(
pool_id,
kDefaultPoolID,
phi::errors::InvalidArgument("Invalid memory pool id %d.", pool_id));
pool_id_ = pool_id;
}
return pool_id_;
}
int64_t PoolID() const { return pool_id_; }
static int64_t CapturingPoolID() { return capturing_graph_->pool_id_; }
void Replay();
void Reset();
void AddResetCallback(std::function<void()> callback) {
std::lock_guard<std::mutex> guard(mtx_);
callbacks_.push_back(std::move(callback));
}
void PrintToDotFiles(const std::string &dirname, unsigned int flags);
static void BeginCapture(phi::GPUPlace place,
cudaStream_t stream,
cudaStreamCaptureMode mode);
static std::unique_ptr<CUDAGraph> EndCapture();
static void BeginSegmentCapture();
static void EndSegmentCapture();
static void AddResetCallbackDuringCapturing(std::function<void()> callback) {
capturing_graph_->AddResetCallback(std::move(callback));
}
// No need to add CUDA_VERSION macro because capturing_graph_ would
// always be nullptr (constructor throws error)
static bool IsCapturing() { return capturing_graph_ != nullptr; }
static CUDAGraphID CapturingID() { return capturing_graph_->id_; }
static phi::GPUPlace CapturingPlace() { return capturing_graph_->place_; }
// This API can be used to debug which GPU operation is not
// supported during capturing CUDA Graph.
static bool IsValidCapturing();
static bool IsThreadLocalCapturing() {
#if CUDA_VERSION >= 10010
return IsCapturing() &&
capturing_graph_->capture_mode_ == cudaStreamCaptureModeThreadLocal;
#else
return false;
#endif
}
static bool IsThisThreadCapturing() {
if (UNLIKELY(IsCapturing())) {
return IsThreadLocalCapturing()
? capturing_thread_id_.get() == std::this_thread::get_id()
: true;
} else {
return false;
}
}
using SetSeedFunc = std::function<bool(CUDAKernelParams *, bool)>;
static void RecordRandomKernelInfo(SetSeedFunc set_seed_func) {
std::lock_guard<std::mutex> guard(capturing_graph_->func_mtx_);
capturing_graph_->set_seed_funcs_.emplace_back(std::move(set_seed_func));
}
static int64_t UniqueMemoryPoolID();
private:
static CUDAGraphID UniqueID();
private:
#if CUDA_VERSION >= 10010
std::vector<cudaGraph_t> graphs_;
std::vector<cudaGraphExec_t> exec_graphs_;
cudaStreamCaptureMode capture_mode_;
#endif
cudaStream_t stream_{nullptr};
phi::GPUPlace place_;
CUDAGraphID id_;
int64_t pool_id_{kInvalidPoolID};
std::vector<std::function<void()>> callbacks_;
bool is_reset_{false};
std::mutex mtx_;
std::vector<SetSeedFunc> set_seed_funcs_;
std::vector<std::vector<std::function<void(cudaGraphExec_t)>>> pre_hooks_;
std::mutex func_mtx_;
bool is_first_run_{true};
static paddle::optional<std::thread::id> capturing_thread_id_;
static std::unique_ptr<CUDAGraph> capturing_graph_;
};
#if CUDA_VERSION >= 10010
class CUDAGraphCaptureModeGuard {
DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard);
public:
explicit CUDAGraphCaptureModeGuard(
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {
if (UNLIKELY(CUDAGraph::IsCapturing())) {
PADDLE_ENFORCE_GPU_SUCCESS(cudaThreadExchangeStreamCaptureMode(&mode));
// After cudaThreadExchangeStreamCaptureMode is called,
// the variable "mode" would be set to the old capturing mode.
old_mode_ = mode;
}
}
~CUDAGraphCaptureModeGuard() PADDLE_MAY_THROW {
if (UNLIKELY(CUDAGraph::IsCapturing())) {
PADDLE_ENFORCE_GPU_SUCCESS(
cudaThreadExchangeStreamCaptureMode(&old_mode_));
}
}
private:
cudaStreamCaptureMode old_mode_;
};
#else
class CUDAGraphCaptureModeGuard {
DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard);
public:
explicit CUDAGraphCaptureModeGuard(
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {}
};
#endif
} // namespace gpu
} // namespace backends
} // namespace phi
cc_library(
phi_xpu_op_list
SRCS xpu_op_list.cc xpu1_op_list.cc xpu2_op_list.cc
DEPS glog)
......@@ -19,7 +19,7 @@ if(WITH_XPU)
cc_library(
kernel_factory
SRCS kernel_factory.cc
DEPS phi_enforce convert_utils phi_xpu_op_list)
DEPS phi_enforce convert_utils phi_backends)
else()
cc_library(
kernel_factory
......
......@@ -15,7 +15,7 @@
#include "paddle/phi/core/device_context.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
#endif
#include "paddle/phi/core/dense_tensor.h"
......@@ -153,8 +153,9 @@ struct DeviceContext::Impl {
: (pinned ? pinned_allocator_ : device_allocator_);
#ifdef PADDLE_WITH_CUDA
bool must_cuda_graph_allocator = (tensor->numel() != 0) && !pinned;
if (must_cuda_graph_allocator && paddle::platform::is_gpu_place(place) &&
paddle::platform::CUDAGraph::IsThisThreadCapturing()) {
if (must_cuda_graph_allocator &&
place.GetType() == phi::AllocationType::GPU &&
phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) {
PADDLE_ENFORCE_NOT_NULL(cuda_graph_allocator_,
phi::errors::InvalidArgument(
"Required cuda_graph_allocator_ shall not be "
......
......@@ -465,7 +465,7 @@ struct SearchAlgorithmBase<ConvKind::kBackwardFilter> {
static size_t GetWorkspaceSize(const ConvArgs& args,
cudnnConvolutionBwdFilterAlgo_t algo) {
paddle::platform::CUDAGraphCaptureModeGuard guard;
phi::backends::gpu::CUDAGraphCaptureModeGuard guard;
size_t workspace_size = 0;
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册