未验证 提交 f9134045 编写于 作者: H HongyuJia 提交者: GitHub

[Kernel Selection] Remove hard code of PADDLE_WITH_CUDA (#47325)

* move cudnn hardcode outside GetExpectedKernelType

* add header file

* debug

* update interpreter_util with hardcode

* update interpreter_util headerfile

* solve activation hardcode

* debug with CI

* add mkldnn_op_list header file

* temporarily uncomment mkldnn

* temporarily uncomment mkldnn

* delete sequence_softmax cudnn hardcode

* add hardcode to data_transfer.cc

* update data_transfer headerfile

* try to fix segmentation fault

* update cudnn&miopen_helper

* reset HasAttr of DygraphExecutionContext

* debug, this commit should pass all CI

* debug should pass CI, temporarily disable activation

* debug should pass CI

* fix default_attr=nullptr bug

* clean debug code
上级 db323927
...@@ -22,6 +22,9 @@ ...@@ -22,6 +22,9 @@
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/phi/backends/onednn/onednn_context.h" #include "paddle/phi/backends/onednn/onednn_context.h"
#endif #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -133,6 +136,12 @@ void DataTranferHelper::RunAndConstructOpFuncNode( ...@@ -133,6 +136,12 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
auto* dev_ctx = pool.Get(place_); auto* dev_ctx = pool.Get(place_);
auto exec_ctx = ExecutionContext(*op, Scope(), *dev_ctx, runtime_context); auto exec_ctx = ExecutionContext(*op, Scope(), *dev_ctx, runtime_context);
auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(exec_ctx); auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(exec_ctx);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (!op_with_kernel->DnnFallback() &&
paddle::platform::CanCUDNNBeUsed(exec_ctx)) {
expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN;
}
#endif
VLOG(6) << "expected_kernel_key " << expected_kernel_key << "\n"; VLOG(6) << "expected_kernel_key " << expected_kernel_key << "\n";
VLOG(6) << "op_with_kernel Type() " << op_with_kernel->Type() << "\n"; VLOG(6) << "op_with_kernel Type() " << op_with_kernel->Type() << "\n";
......
...@@ -32,6 +32,9 @@ ...@@ -32,6 +32,9 @@
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#endif
PADDLE_DEFINE_EXPORTED_bool( PADDLE_DEFINE_EXPORTED_bool(
new_executor_serial_run, new_executor_serial_run,
...@@ -615,6 +618,12 @@ void BuildOpFuncList(const platform::Place& place, ...@@ -615,6 +618,12 @@ void BuildOpFuncList(const platform::Place& place,
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context); *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
auto expected_kernel_key = auto expected_kernel_key =
op_with_kernel->GetExpectedKernelType(exec_ctx); op_with_kernel->GetExpectedKernelType(exec_ctx);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (!op_with_kernel->DnnFallback() &&
paddle::platform::CanCUDNNBeUsed(exec_ctx)) {
expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN;
}
#endif
VLOG(4) << "expected_kernel_key : " << expected_kernel_key; VLOG(4) << "expected_kernel_key : " << expected_kernel_key;
// change device by the device_guard() // change device by the device_guard()
ApplyDeviceGuard(op, place, &expected_kernel_key); ApplyDeviceGuard(op, place, &expected_kernel_key);
......
...@@ -58,6 +58,10 @@ class DenseTensor; ...@@ -58,6 +58,10 @@ class DenseTensor;
#include "paddle/fluid/platform/device/mlu/mlu_info.h" #include "paddle/fluid/platform/device/mlu/mlu_info.h"
#endif #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#endif
DECLARE_bool(benchmark); DECLARE_bool(benchmark);
DECLARE_bool(check_nan_inf); DECLARE_bool(check_nan_inf);
DECLARE_bool(enable_unused_var_check); DECLARE_bool(enable_unused_var_check);
...@@ -1409,6 +1413,14 @@ bool OperatorWithKernel::SupportsKernelType( ...@@ -1409,6 +1413,14 @@ bool OperatorWithKernel::SupportsKernelType(
} }
#endif #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (!this->DnnFallback() && paddle::platform::CanCUDNNBeUsed(exe_ctx)) {
auto tmp_kernel_type = kernel_type;
tmp_kernel_type.library_type_ = framework::LibraryType::kCUDNN;
return kernels.find(tmp_kernel_type) != kernels.end();
}
#endif
return kernel_iter != kernels.end(); return kernel_iter != kernels.end();
} }
...@@ -1589,6 +1601,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1589,6 +1601,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
} }
#endif #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (!this->DnnFallback() && paddle::platform::CanCUDNNBeUsed(exe_ctx)) {
kernel_type_->library_type_ = framework::LibraryType::kCUDNN;
}
#endif
// NOTE(Liu-xiandong):In my ctest, this branch do not be executed, // NOTE(Liu-xiandong):In my ctest, this branch do not be executed,
// I can't understand it, it's really confusing. // I can't understand it, it's really confusing.
// But we still need to keep this to avoid errors. // But we still need to keep this to avoid errors.
...@@ -1832,6 +1850,12 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( ...@@ -1832,6 +1850,12 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
} }
#endif #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (!this->DnnFallback() && paddle::platform::CanCUDNNBeUsed(ctx)) {
expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN;
}
#endif
if (HasAttr("op_device")) { if (HasAttr("op_device")) {
if (Attr<std::string>("op_device") == "cpu") { if (Attr<std::string>("op_device") == "cpu") {
expected_kernel_key.place_ = platform::CPUPlace(); expected_kernel_key.place_ = platform::CPUPlace();
......
...@@ -103,7 +103,8 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -103,7 +103,8 @@ class DygraphExecutionContext : public framework::ExecutionContext {
bool HasAttr(const std::string& name) const override { bool HasAttr(const std::string& name) const override {
if (attrs_.find(name) == attrs_.end()) { if (attrs_.find(name) == attrs_.end()) {
return default_attrs_.find(name) != default_attrs_.end(); return &default_attrs_ != nullptr &&
default_attrs_.find(name) != default_attrs_.end();
} }
return true; return true;
} }
......
...@@ -28,6 +28,9 @@ ...@@ -28,6 +28,9 @@
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_op_list.h" #include "paddle/fluid/platform/mkldnn_op_list.h"
#endif #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#endif
#include "paddle/fluid/framework/library_type.h" #include "paddle/fluid/framework/library_type.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/fluid/platform/profiler/event_tracing.h"
...@@ -246,6 +249,12 @@ PreparedOp PrepareImpl( ...@@ -246,6 +249,12 @@ PreparedOp PrepareImpl(
} }
#endif #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (!op.DnnFallback() && paddle::platform::CanCUDNNBeUsed(dygraph_exe_ctx)) {
expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN;
}
#endif
#if defined(PADDLE_WITH_XPU) #if defined(PADDLE_WITH_XPU)
bool is_xpu_unsupport = bool is_xpu_unsupport =
paddle::platform::is_xpu_place(expected_kernel_key.place_) && paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
......
...@@ -93,6 +93,14 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, ...@@ -93,6 +93,14 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
// library = framework::LibraryType::kCUDNN; // library = framework::LibraryType::kCUDNN;
// } // }
// #endif // #endif
// NOTE(jiahongyu): Activation ops have the attribute use_cudnn, but their cudnn
// kernels are temporarily disabled. Therefore, these ops must temporarily fall
// back to the plain GPU kernel. Once the commented-out code above is restored,
// the fallback code below can be safely deleted.
if (paddle::platform::is_gpu_place(ctx.GetPlace())) {
oper.SetDnnFallback(true);
}
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
......
...@@ -134,15 +134,8 @@ class AffineGridOp : public framework::OperatorWithKernel { ...@@ -134,15 +134,8 @@ class AffineGridOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
library = framework::LibraryType::kCUDNN;
}
#endif
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta");
return framework::OpKernelType( return framework::OpKernelType(data_type, ctx.GetPlace());
data_type, ctx.GetPlace(), phi::DataLayout::kAnyLayout, library);
} }
}; };
...@@ -252,17 +245,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel { ...@@ -252,17 +245,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain}; auto data_type = OperatorWithKernel::IndicateVarDataType(
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) ctx, framework::GradVarName("Output"));
if (platform::CanCUDNNBeUsed(ctx)) { return framework::OpKernelType(data_type, ctx.GetPlace());
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Output")),
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
library_);
} }
}; };
......
...@@ -28,9 +28,6 @@ limitations under the License. */ ...@@ -28,9 +28,6 @@ limitations under the License. */
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -40,14 +37,6 @@ using DataLayout = phi::DataLayout; ...@@ -40,14 +37,6 @@ using DataLayout = phi::DataLayout;
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
...@@ -268,14 +257,6 @@ Example: ...@@ -268,14 +257,6 @@ Example:
framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType( framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
...@@ -343,14 +324,6 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -343,14 +324,6 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> {
framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType( framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace()); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
......
...@@ -35,17 +35,8 @@ class GridSampleOp : public framework::OperatorWithKernel { ...@@ -35,17 +35,8 @@ class GridSampleOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain}; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) return framework::OpKernelType(data_type, ctx.GetPlace());
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
library_);
} }
}; };
...@@ -146,17 +137,8 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { ...@@ -146,17 +137,8 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain}; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) return framework::OpKernelType(data_type, ctx.GetPlace());
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
library_);
} }
}; };
......
...@@ -44,21 +44,13 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) { ...@@ -44,21 +44,13 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) {
framework::OpKernelType PoolOp::GetExpectedKernelType( framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
this->SetDnnFallback(!CanMKLDNNSupportPool(ctx)); this->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
// NOTE(jiahongyu) END: Above codes originally enclosed by PADDLE_WITH_MKLDNN // NOTE(jiahongyu) END: Above codes originally enclosed by PADDLE_WITH_MKLDNN
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_); return framework::OpKernelType(data_type, ctx.GetPlace());
} }
framework::OpKernelType PoolOp::GetKernelTypeForVar( framework::OpKernelType PoolOp::GetKernelTypeForVar(
...@@ -86,22 +78,13 @@ framework::OpKernelType PoolOp::GetKernelTypeForVar( ...@@ -86,22 +78,13 @@ framework::OpKernelType PoolOp::GetKernelTypeForVar(
framework::OpKernelType PoolOpGrad::GetExpectedKernelType( framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
this->SetDnnFallback(!CanMKLDNNSupportPool(ctx)); this->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
// NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN
return framework::OpKernelType( return framework::OpKernelType(input_data_type, ctx.GetPlace());
input_data_type, ctx.GetPlace(), layout_, library_);
} }
framework::OpKernelType PoolOpGrad::GetKernelTypeForVar( framework::OpKernelType PoolOpGrad::GetKernelTypeForVar(
......
...@@ -43,14 +43,6 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { ...@@ -43,14 +43,6 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
if (ctx.HasAttr("data_format")) { if (ctx.HasAttr("data_format")) {
layout_ = phi::StringToDataLayout(ctx.Attr<std::string>("data_format")); layout_ = phi::StringToDataLayout(ctx.Attr<std::string>("data_format"));
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
layout_,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
} }
}; };
...@@ -135,14 +127,6 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { ...@@ -135,14 +127,6 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
if (ctx.HasAttr("data_format")) { if (ctx.HasAttr("data_format")) {
layout_ = phi::StringToDataLayout(ctx.Attr<std::string>("data_format")); layout_ = phi::StringToDataLayout(ctx.Attr<std::string>("data_format"));
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
layout_,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
} }
}; };
......
...@@ -48,14 +48,6 @@ class SoftmaxOp : public framework::OperatorWithKernel { ...@@ -48,14 +48,6 @@ class SoftmaxOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"float16 can only be used on GPU/NPU/XPU/MLU and custom place")); "float16 can only be used on GPU/NPU/XPU/MLU and custom place"));
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
layout_,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
} }
}; };
...@@ -140,14 +132,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { ...@@ -140,14 +132,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"float16 can only be used on GPU/NPU/XPU/MLU and custom place")); "float16 can only be used on GPU/NPU/XPU/MLU and custom place"));
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
layout_,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_); return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
} }
}; };
......
...@@ -617,8 +617,8 @@ class ScopedActivationDescriptor { ...@@ -617,8 +617,8 @@ class ScopedActivationDescriptor {
}; };
inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn"); bool use_cudnn = paddle::platform::is_gpu_place(ctx.GetPlace()) &&
use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace()); ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn");
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (use_cudnn) { if (use_cudnn) {
auto& dev_ctx = ctx.device_context<phi::GPUContext>(); auto& dev_ctx = ctx.device_context<phi::GPUContext>();
......
...@@ -554,8 +554,8 @@ class ScopedActivationDescriptor { ...@@ -554,8 +554,8 @@ class ScopedActivationDescriptor {
}; };
inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) { inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn"); bool use_cudnn = paddle::platform::is_gpu_place(ctx.GetPlace()) &&
use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace()); ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn");
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
if (use_cudnn) { if (use_cudnn) {
auto& dev_ctx = ctx.device_context<phi::GPUContext>(); auto& dev_ctx = ctx.device_context<phi::GPUContext>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册