Unverified commit afd5a96b authored by HongyuJia, committed by GitHub

opt conv_transpose cudnn (#47294)

Parent 3e7abca5
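This change replaces the hand-rolled cuDNN eligibility checks in the three GetExpectedKernelType overrides of conv_transpose (forward, grad, double grad) with a single call to platform::CanCUDNNBeUsed, and hardens that helper with a HasAttr("use_cudnn") guard so operators without the attribute fall back to the plain kernel. A minimal standalone sketch of the resulting dispatch pattern follows; MockContext, can_cudnn_be_used, and get_expected_kernel_type are illustrative stand-ins, not Paddle's real ExecutionContext/OpKernelType API.

// Minimal standalone sketch of the refactor pattern; all names here are
// mock stand-ins for the Paddle framework types used in the diff below.
#include <iostream>
#include <map>
#include <string>

struct MockContext {
  std::map<std::string, bool> attrs;
  bool on_gpu = false;
  bool has_cudnn_handle = false;

  bool HasAttr(const std::string& name) const { return attrs.count(name) > 0; }
  bool Attr(const std::string& name) const { return attrs.at(name); }
};

// Centralized eligibility check: a missing "use_cudnn" attribute, a CPU
// placement, or an unavailable cuDNN handle all disqualify the cuDNN kernel.
bool can_cudnn_be_used(const MockContext& ctx) {
  bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr("use_cudnn");
  use_cudnn &= ctx.on_gpu;
  use_cudnn &= ctx.has_cudnn_handle;
  return use_cudnn;
}

// Each kernel-selection routine now reduces to a single branch instead of
// repeating the attribute/place/handle checks inline.
std::string get_expected_kernel_type(const MockContext& ctx) {
  if (can_cudnn_be_used(ctx)) {
    return "kCUDNN";
  }
  return "kPlain";
}

int main() {
  MockContext gpu_ctx;
  gpu_ctx.attrs["use_cudnn"] = true;
  gpu_ctx.on_gpu = true;
  gpu_ctx.has_cudnn_handle = true;
  std::cout << get_expected_kernel_type(gpu_ctx) << "\n";  // kCUDNN

  MockContext cpu_ctx;  // no use_cudnn attribute, no GPU: falls back
  std::cout << get_expected_kernel_type(cpu_ctx) << "\n";  // kPlain
  return 0;
}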
@@ -28,6 +28,9 @@ limitations under the License. */
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
+#endif
 namespace paddle {
 namespace operators {
@@ -38,15 +41,11 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  if (platform::is_gpu_place(ctx.GetPlace())) {
-    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
-    if (ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn") &&
-        dev_ctx.cudnn_handle() != nullptr) {
-      return framework::OpKernelType(data_type,
-                                     ctx.GetPlace(),
-                                     phi::DataLayout::kAnyLayout,
-                                     framework::LibraryType::kCUDNN);
-    }
+  if (platform::CanCUDNNBeUsed(ctx)) {
+    return framework::OpKernelType(data_type,
+                                   ctx.GetPlace(),
+                                   phi::DataLayout::kAnyLayout,
+                                   framework::LibraryType::kCUDNN);
   }
 #endif
   return framework::OpKernelType(data_type, ctx.GetPlace());
@@ -268,28 +267,16 @@ Example:
 framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  bool use_cudnn =
-      ctx.HasAttr("use_cudnn") ? ctx.Attr<bool>("use_cudnn") : false;
-  use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
+  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  if (platform::is_gpu_place(ctx.GetPlace())) {
-    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
-    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
+  if (platform::CanCUDNNBeUsed(ctx)) {
+    return framework::OpKernelType(data_type,
+                                   ctx.GetPlace(),
+                                   phi::DataLayout::kAnyLayout,
+                                   framework::LibraryType::kCUDNN);
   }
 #endif
-  framework::LibraryType library_;
-  if (use_cudnn) {
-    library_ = framework::LibraryType::kCUDNN;
-  } else {
-    library_ = framework::LibraryType::kPlain;
-  }
-  phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
-      ctx.GetPlace(),
-      layout_,
-      library_);
+  return framework::OpKernelType(data_type, ctx.GetPlace());
 }
 template <typename T>
@@ -355,28 +342,16 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> {
 framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  bool use_cudnn =
-      ctx.HasAttr("use_cudnn") ? ctx.Attr<bool>("use_cudnn") : false;
-  use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
+  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  if (platform::is_gpu_place(ctx.GetPlace())) {
-    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
-    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
+  if (platform::CanCUDNNBeUsed(ctx)) {
+    return framework::OpKernelType(data_type,
+                                   ctx.GetPlace(),
+                                   phi::DataLayout::kAnyLayout,
+                                   framework::LibraryType::kCUDNN);
   }
 #endif
-  framework::LibraryType library_;
-  if (use_cudnn) {
-    library_ = framework::LibraryType::kCUDNN;
-  } else {
-    library_ = framework::LibraryType::kPlain;
-  }
-  phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
-      ctx.GetPlace(),
-      layout_,
-      library_);
+  return framework::OpKernelType(data_type, ctx.GetPlace());
 }
 }  // namespace operators
......
@@ -617,7 +617,7 @@ class ScopedActivationDescriptor {
 };
 inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
-  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn");
   use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
 #ifdef PADDLE_WITH_CUDA
   if (use_cudnn) {
......
@@ -554,7 +554,7 @@ class ScopedActivationDescriptor {
 };
 inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
-  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+  bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn");
   use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
 #ifdef PADDLE_WITH_HIP
   if (use_cudnn) {
......