Commit 3a53be00 authored by HongyuJia, committed by GitHub

Revert "[Kernel Selection] Remove hard code of PADDLE_WITH_CUDA (#47325)"

This reverts commit f9134045.
Parent 5ed487bf
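The hunks below drop the centralized cuDNN selection that the reverted commit had added to the framework's data-transfer, interpreter, operator, and dygraph prepare paths, and restore the selection inside each operator's GetExpectedKernelType. As a minimal standalone sketch of the restored pattern, the code below uses simplified stand-in types (ExecutionContext, LibraryType, CanCUDNNBeUsed here are illustrations, not Paddle's real classes):

```cpp
// Standalone sketch of the per-op cuDNN selection pattern this revert restores.
#include <iostream>
#include <map>
#include <string>

enum class LibraryType { kPlain, kCUDNN };

struct ExecutionContext {
  std::map<std::string, bool> bool_attrs;  // e.g. {"use_cudnn", true}
  bool on_gpu_place;                       // stands in for is_gpu_place(ctx.GetPlace())

  bool HasAttr(const std::string& name) const { return bool_attrs.count(name) > 0; }
  bool Attr(const std::string& name) const { return bool_attrs.at(name); }
};

// Mirrors the restored helper: gate on the attribute first, then on the place.
// (The real helper additionally checks the cuDNN/MIOpen handle on the device context.)
bool CanCUDNNBeUsed(const ExecutionContext& ctx) {
  bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr("use_cudnn");
  use_cudnn &= ctx.on_gpu_place;
  return use_cudnn;
}

// Shape of the per-op GetExpectedKernelType logic: default to a plain kernel,
// switch the library type when cuDNN applies.
LibraryType SelectLibrary(const ExecutionContext& ctx) {
  LibraryType library = LibraryType::kPlain;
  if (CanCUDNNBeUsed(ctx)) {
    library = LibraryType::kCUDNN;
  }
  return library;
}

int main() {
  ExecutionContext ctx{{{"use_cudnn", true}}, /*on_gpu_place=*/true};
  std::cout << (SelectLibrary(ctx) == LibraryType::kCUDNN ? "kCUDNN" : "kPlain") << "\n";
}
```

The trailing hunks in cudnn_helper.h and miopen_helper.h show the start of the additional device-context check performed by the real helper.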
@@ -22,9 +22,6 @@
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/phi/backends/onednn/onednn_context.h"
 #endif
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
-#endif
 namespace paddle {
 namespace framework {
@@ -136,12 +133,6 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
   auto* dev_ctx = pool.Get(place_);
   auto exec_ctx = ExecutionContext(*op, Scope(), *dev_ctx, runtime_context);
   auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(exec_ctx);
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  if (!op_with_kernel->DnnFallback() &&
-      paddle::platform::CanCUDNNBeUsed(exec_ctx)) {
-    expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN;
-  }
-#endif
   VLOG(6) << "expected_kernel_key " << expected_kernel_key << "\n";
   VLOG(6) << "op_with_kernel Type() " << op_with_kernel->Type() << "\n";
......
@@ -32,9 +32,6 @@
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
-#endif
 PADDLE_DEFINE_EXPORTED_bool(
     new_executor_serial_run,
@@ -621,12 +618,6 @@ void BuildOpFuncList(const platform::Place& place,
           *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
       auto expected_kernel_key =
           op_with_kernel->GetExpectedKernelType(exec_ctx);
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-      if (!op_with_kernel->DnnFallback() &&
-          paddle::platform::CanCUDNNBeUsed(exec_ctx)) {
-        expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN;
-      }
-#endif
       VLOG(4) << "expected_kernel_key : " << expected_kernel_key;
       // change device by the device_guard()
       ApplyDeviceGuard(op, place, &expected_kernel_key);
......
@@ -58,10 +58,6 @@ class DenseTensor;
 #include "paddle/fluid/platform/device/mlu/mlu_info.h"
 #endif
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
-#endif
 DECLARE_bool(benchmark);
 DECLARE_bool(check_nan_inf);
 DECLARE_bool(enable_unused_var_check);
@@ -1413,14 +1409,6 @@ bool OperatorWithKernel::SupportsKernelType(
   }
 #endif
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  if (!this->DnnFallback() && paddle::platform::CanCUDNNBeUsed(exe_ctx)) {
-    auto tmp_kernel_type = kernel_type;
-    tmp_kernel_type.library_type_ = framework::LibraryType::kCUDNN;
-    return kernels.find(tmp_kernel_type) != kernels.end();
-  }
-#endif
   return kernel_iter != kernels.end();
 }
@@ -1601,12 +1589,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   }
 #endif
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  if (!this->DnnFallback() && paddle::platform::CanCUDNNBeUsed(exe_ctx)) {
-    kernel_type_->library_type_ = framework::LibraryType::kCUDNN;
-  }
-#endif
   // NOTE(Liu-xiandong):In my ctest, this branch do not be executed,
   // I can't understand it, it's really confusing.
   // But we still need to keep this to avoid errors.
@@ -1850,12 +1832,6 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
   }
 #endif
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  if (!this->DnnFallback() && paddle::platform::CanCUDNNBeUsed(ctx)) {
-    expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN;
-  }
-#endif
   if (HasAttr("op_device")) {
     if (Attr<std::string>("op_device") == "cpu") {
       expected_kernel_key.place_ = platform::CPUPlace();
......
@@ -103,8 +103,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
   bool HasAttr(const std::string& name) const override {
     if (attrs_.find(name) == attrs_.end()) {
-      return &default_attrs_ != nullptr &&
-             default_attrs_.find(name) != default_attrs_.end();
+      return default_attrs_.find(name) != default_attrs_.end();
     }
     return true;
   }
......
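The HasAttr hunk above restores the plain two-step lookup: the call-site attributes first, then the op's default attributes (default_attrs_ is a reference, so comparing its address against nullptr was always true and added nothing). A small standalone illustration with simplified stand-in types, not the real DygraphExecutionContext:

```cpp
#include <iostream>
#include <map>
#include <string>

// Simplified stand-in for the dygraph attribute lookup: attrs_ holds the
// attributes passed for this call, default_attrs_ the op's registered defaults.
class AttrView {
 public:
  AttrView(const std::map<std::string, int>& attrs,
           const std::map<std::string, int>& defaults)
      : attrs_(attrs), default_attrs_(defaults) {}

  bool HasAttr(const std::string& name) const {
    if (attrs_.find(name) == attrs_.end()) {
      // Fall back to the defaults; a reference can never be "null".
      return default_attrs_.find(name) != default_attrs_.end();
    }
    return true;
  }

 private:
  const std::map<std::string, int>& attrs_;
  const std::map<std::string, int>& default_attrs_;
};

int main() {
  std::map<std::string, int> attrs{{"axis", 1}};
  std::map<std::string, int> defaults{{"use_cudnn", 0}};
  AttrView view(attrs, defaults);
  std::cout << std::boolalpha << view.HasAttr("use_cudnn") << "\n";  // true, found in defaults
}
```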
@@ -28,9 +28,6 @@
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_op_list.h"
 #endif
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
-#endif
 #include "paddle/fluid/framework/library_type.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
 #include "paddle/fluid/platform/profiler/event_tracing.h"
@@ -249,12 +246,6 @@ PreparedOp PrepareImpl(
   }
 #endif
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  if (!op.DnnFallback() && paddle::platform::CanCUDNNBeUsed(dygraph_exe_ctx)) {
-    expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN;
-  }
-#endif
 #if defined(PADDLE_WITH_XPU)
   bool is_xpu_unsupport =
       paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
......
@@ -93,14 +93,6 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
  // library = framework::LibraryType::kCUDNN;
  // }
  // #endif
-  // NOTE(jiahongyu): Activation ops have attribute use_cudnn, but cudnn kernels
-  // are temporarily disabled. Therefore, cudnn kernel also needs to fallback to
-  // plain GPU kernel temporarily. When above codes are uncommented, below
-  // fallback codes can be deleted safely.
-  if (paddle::platform::is_gpu_place(ctx.GetPlace())) {
-    oper.SetDnnFallback(true);
-  }
   return framework::OpKernelType(data_type, ctx.GetPlace());
 }
......
@@ -134,8 +134,15 @@ class AffineGridOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
+    framework::LibraryType library{framework::LibraryType::kPlain};
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      library = framework::LibraryType::kCUDNN;
+    }
+#endif
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta");
-    return framework::OpKernelType(data_type, ctx.GetPlace());
+    return framework::OpKernelType(
+        data_type, ctx.GetPlace(), phi::DataLayout::kAnyLayout, library);
   }
 };
@@ -245,9 +252,17 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    auto data_type = OperatorWithKernel::IndicateVarDataType(
-        ctx, framework::GradVarName("Output"));
-    return framework::OpKernelType(data_type, ctx.GetPlace());
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kCUDNN;
+    }
+#endif
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Output")),
+                                   ctx.GetPlace(),
+                                   phi::DataLayout::kAnyLayout,
+                                   library_);
   }
 };
......
@@ -28,6 +28,9 @@ limitations under the License. */
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
+#endif
 namespace paddle {
 namespace operators {
@@ -37,6 +40,14 @@ using DataLayout = phi::DataLayout;
 framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  if (platform::CanCUDNNBeUsed(ctx)) {
+    return framework::OpKernelType(data_type,
+                                   ctx.GetPlace(),
+                                   phi::DataLayout::kAnyLayout,
+                                   framework::LibraryType::kCUDNN);
+  }
+#endif
   return framework::OpKernelType(data_type, ctx.GetPlace());
 }
@@ -257,6 +268,14 @@ Example:
 framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  if (platform::CanCUDNNBeUsed(ctx)) {
+    return framework::OpKernelType(data_type,
+                                   ctx.GetPlace(),
+                                   phi::DataLayout::kAnyLayout,
+                                   framework::LibraryType::kCUDNN);
+  }
+#endif
   return framework::OpKernelType(data_type, ctx.GetPlace());
 }
@@ -324,6 +343,14 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> {
 framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  if (platform::CanCUDNNBeUsed(ctx)) {
+    return framework::OpKernelType(data_type,
+                                   ctx.GetPlace(),
+                                   phi::DataLayout::kAnyLayout,
+                                   framework::LibraryType::kCUDNN);
+  }
+#endif
   return framework::OpKernelType(data_type, ctx.GetPlace());
 }
......
@@ -35,8 +35,17 @@ class GridSampleOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(data_type, ctx.GetPlace());
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kCUDNN;
+    }
+#endif
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.GetPlace(),
+        phi::DataLayout::kAnyLayout,
+        library_);
   }
 };
@@ -137,8 +146,17 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return framework::OpKernelType(data_type, ctx.GetPlace());
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kCUDNN;
+    }
+#endif
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.GetPlace(),
+        phi::DataLayout::kAnyLayout,
+        library_);
   }
 };
......
@@ -44,13 +44,21 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) {
 framework::OpKernelType PoolOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
+  framework::LibraryType library_{framework::LibraryType::kPlain};
+  phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  if (platform::CanCUDNNBeUsed(ctx)) {
+    library_ = framework::LibraryType::kCUDNN;
+  }
+#endif
   // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
   this->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
   // NOTE(jiahongyu) END: Above codes originally enclosed by PADDLE_WITH_MKLDNN
-  return framework::OpKernelType(data_type, ctx.GetPlace());
+  return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_);
 }

 framework::OpKernelType PoolOp::GetKernelTypeForVar(
@@ -78,13 +86,22 @@ framework::OpKernelType PoolOp::GetKernelTypeForVar(
 framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
+  framework::LibraryType library_{framework::LibraryType::kPlain};
+  phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
   auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  if (platform::CanCUDNNBeUsed(ctx)) {
+    library_ = framework::LibraryType::kCUDNN;
+  }
+#endif
   // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
   this->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
   // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN
-  return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  return framework::OpKernelType(
+      input_data_type, ctx.GetPlace(), layout_, library_);
 }

 framework::OpKernelType PoolOpGrad::GetKernelTypeForVar(
......
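In the two pool hunks above, the SetDnnFallback(!CanMKLDNNSupportPool(ctx)) call stays outside any #ifdef, while the restored cuDNN branch only switches the library type of the returned kernel key. A rough standalone sketch of how those two gates relate is shown below; the types and helpers are simplified stand-ins, and the actual oneDNN decision that consumes the fallback flag happens later in the framework, not in this function:

```cpp
#include <iostream>

enum class LibraryType { kPlain, kMKLDNN, kCUDNN };

// Simplified operator: dnn_fallback_ records "the oneDNN kernel cannot handle
// this call" (e.g. an unsupported pooling configuration), so later selection
// must skip oneDNN even if it is otherwise enabled.
struct Op {
  mutable bool dnn_fallback_ = false;
  void SetDnnFallback(bool fallback) const { dnn_fallback_ = fallback; }
  bool DnnFallback() const { return dnn_fallback_; }
};

// Mirrors the shape of the restored PoolOp::GetExpectedKernelType: pick cuDNN
// or plain for the expected kernel type, and record the oneDNN fallback.
LibraryType ExpectedLibrary(const Op& op, bool mkldnn_supports_pool, bool cudnn_usable) {
  LibraryType library = LibraryType::kPlain;
  if (cudnn_usable) {                          // stands in for platform::CanCUDNNBeUsed(ctx)
    library = LibraryType::kCUDNN;
  }
  op.SetDnnFallback(!mkldnn_supports_pool);    // consumed by the framework's oneDNN path
  return library;
}

int main() {
  Op op;
  LibraryType lib = ExpectedLibrary(op, /*mkldnn_supports_pool=*/false, /*cudnn_usable=*/true);
  std::cout << "library=" << static_cast<int>(lib)
            << " dnn_fallback=" << op.DnnFallback() << "\n";  // library=2 dnn_fallback=1
}
```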
@@ -43,6 +43,14 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
     if (ctx.HasAttr("data_format")) {
       layout_ = phi::StringToDataLayout(ctx.Attr<std::string>("data_format"));
     }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      return framework::OpKernelType(input_data_type,
+                                     ctx.GetPlace(),
+                                     layout_,
+                                     framework::LibraryType::kCUDNN);
+    }
+#endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
   }
 };
@@ -127,6 +135,14 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
     if (ctx.HasAttr("data_format")) {
       layout_ = phi::StringToDataLayout(ctx.Attr<std::string>("data_format"));
     }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      return framework::OpKernelType(input_data_type,
+                                     ctx.GetPlace(),
+                                     layout_,
+                                     framework::LibraryType::kCUDNN);
+    }
+#endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
   }
 };
......
@@ -48,6 +48,14 @@ class SoftmaxOp : public framework::OperatorWithKernel {
           platform::errors::InvalidArgument(
               "float16 can only be used on GPU/NPU/XPU/MLU and custom place"));
     }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      return framework::OpKernelType(input_data_type,
+                                     ctx.GetPlace(),
+                                     layout_,
+                                     framework::LibraryType::kCUDNN);
+    }
+#endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
   }
 };
@@ -132,6 +140,14 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "float16 can only be used on GPU/NPU/XPU/MLU and custom place"));
     }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      return framework::OpKernelType(input_data_type,
+                                     ctx.GetPlace(),
+                                     layout_,
+                                     framework::LibraryType::kCUDNN);
+    }
+#endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
   }
 };
......
@@ -617,8 +617,8 @@ class ScopedActivationDescriptor {
 };

 inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
-  bool use_cudnn = paddle::platform::is_gpu_place(ctx.GetPlace()) &&
-                   ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn");
+  bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
 #ifdef PADDLE_WITH_CUDA
   if (use_cudnn) {
     auto& dev_ctx = ctx.device_context<phi::GPUContext>();
......
@@ -554,8 +554,8 @@ class ScopedActivationDescriptor {
 };

 inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
-  bool use_cudnn = paddle::platform::is_gpu_place(ctx.GetPlace()) &&
-                   ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn");
+  bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn");
+  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
 #ifdef PADDLE_WITH_HIP
   if (use_cudnn) {
     auto& dev_ctx = ctx.device_context<phi::GPUContext>();
......
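The last two hunks restore the original check order in CanCUDNNBeUsed for the CUDA and HIP headers: the use_cudnn attribute is evaluated first and the place check is folded in afterwards; the observable result is the same, since both conditions must hold. A hedged standalone sketch of the full gate, including the device-capability step the real helper performs on the device context (the names below are simplified stand-ins, not Paddle's API):

```cpp
#include <iostream>

// Simplified stand-ins for the pieces CanCUDNNBeUsed looks at.
struct Context {
  bool has_use_cudnn_attr;  // ctx.HasAttr("use_cudnn")
  bool use_cudnn_attr;      // ctx.Attr<bool>("use_cudnn")
  bool gpu_place;           // is_gpu_place(ctx.GetPlace())
  bool device_has_dnn;      // the device context reports a usable cuDNN/MIOpen handle
};

// Restored ordering: attribute first, then the place, then (only when still
// true) the device capability, mirroring the #ifdef'd block in the header.
bool CanCUDNNBeUsed(const Context& ctx) {
  bool use_cudnn = ctx.has_use_cudnn_attr && ctx.use_cudnn_attr;
  use_cudnn &= ctx.gpu_place;
  if (use_cudnn) {
    use_cudnn &= ctx.device_has_dnn;
  }
  return use_cudnn;
}

int main() {
  Context ctx{true, true, true, true};
  std::cout << std::boolalpha << CanCUDNNBeUsed(ctx) << "\n";  // true
}
```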