未验证 提交 afd5a96b 编写于 作者: H HongyuJia 提交者: GitHub

opt conv_transpose cudnn (#47294)

上级 3e7abca5
...@@ -28,6 +28,9 @@ limitations under the License. */ ...@@ -28,6 +28,9 @@ limitations under the License. */
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -38,15 +41,11 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( ...@@ -38,15 +41,11 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::CanCUDNNBeUsed(ctx)) {
auto& dev_ctx = ctx.template device_context<phi::GPUContext>(); return framework::OpKernelType(data_type,
if (ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn") && ctx.GetPlace(),
dev_ctx.cudnn_handle() != nullptr) { phi::DataLayout::kAnyLayout,
return framework::OpKernelType(data_type, framework::LibraryType::kCUDNN);
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
} }
#endif #endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
...@@ -268,28 +267,16 @@ Example: ...@@ -268,28 +267,16 @@ Example:
framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
bool use_cudnn = auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
ctx.HasAttr("use_cudnn") ? ctx.Attr<bool>("use_cudnn") : false;
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::CanCUDNNBeUsed(ctx)) {
auto& dev_ctx = ctx.template device_context<phi::GPUContext>(); return framework::OpKernelType(data_type,
use_cudnn &= dev_ctx.cudnn_handle() != nullptr; ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
} }
#endif #endif
framework::LibraryType library_; return framework::OpKernelType(data_type, ctx.GetPlace());
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
} else {
library_ = framework::LibraryType::kPlain;
}
phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.GetPlace(),
layout_,
library_);
} }
template <typename T> template <typename T>
...@@ -355,28 +342,16 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -355,28 +342,16 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> {
framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType( framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
bool use_cudnn = auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
ctx.HasAttr("use_cudnn") ? ctx.Attr<bool>("use_cudnn") : false;
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::CanCUDNNBeUsed(ctx)) {
auto& dev_ctx = ctx.template device_context<phi::GPUContext>(); return framework::OpKernelType(data_type,
use_cudnn &= dev_ctx.cudnn_handle() != nullptr; ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
} }
#endif #endif
framework::LibraryType library_; return framework::OpKernelType(data_type, ctx.GetPlace());
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
} else {
library_ = framework::LibraryType::kPlain;
}
phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.GetPlace(),
layout_,
library_);
} }
} // namespace operators } // namespace operators
......
...@@ -617,7 +617,7 @@ class ScopedActivationDescriptor { ...@@ -617,7 +617,7 @@ class ScopedActivationDescriptor {
}; };
inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn");
use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace()); use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (use_cudnn) { if (use_cudnn) {
......
...@@ -554,7 +554,7 @@ class ScopedActivationDescriptor { ...@@ -554,7 +554,7 @@ class ScopedActivationDescriptor {
}; };
inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn");
use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace()); use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
if (use_cudnn) { if (use_cudnn) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册