未验证 提交 908a381d 编写于 作者: H HongyuJia 提交者: GitHub

[Restore PR] Remove hard code of PADDLE_WITH_CUDA (#47630)

* move cudnn hardcode outside GetExpectedKernelType

* add header file

* debug

* update interpreter_util with hardcode

* update interpreter_util headerfile

* solve activation hardcode

* debug with CI

* add mkldnn_op_list header file

* temporarily uncomment mkldnn

* temporarily uncomment mkldnn

* delete sequence_softmax cudnn hardcode

* add hardcode to data_transfer.cc

* update data_transfer headerfile

* try fix segment fault

* update cudnn&miopen_helper

* reset HasAttr of DygraphExctnCtx

* debug, this commit should pass all CI

* debug should pass CI, temporarily disable activation

* debug should pass CI

* fix default_attr=nullptr bug

* clean debug code

* Call SetDnnFallback function in the base class

* activation fallback to plain kernel

* fix default GetExpectedKernelType find wrong kernel

* search cudnn kernel instead of fallback

* fix cudnn_handle bug

* remove tanh use_cudnn

* restore tanh use_cudnn

* debug tanh

* fix tanh bug

* delete activation cudnn kernel

* polish code
上级 c9a7cadf
......@@ -133,6 +133,12 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
auto* dev_ctx = pool.Get(place_);
auto exec_ctx = ExecutionContext(*op, Scope(), *dev_ctx, runtime_context);
auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(exec_ctx);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (op_with_kernel->CanCUDNNBeUsed(exec_ctx,
expected_kernel_key.data_type_)) {
expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN;
}
#endif
VLOG(6) << "expected_kernel_key " << expected_kernel_key << "\n";
VLOG(6) << "op_with_kernel Type() " << op_with_kernel->Type() << "\n";
......
......@@ -635,6 +635,12 @@ void BuildOpFuncList(const platform::Place& place,
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
auto expected_kernel_key =
op_with_kernel->GetExpectedKernelType(exec_ctx);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (op_with_kernel->CanCUDNNBeUsed(exec_ctx,
expected_kernel_key.data_type_)) {
expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN;
}
#endif
VLOG(4) << "expected_kernel_key : " << expected_kernel_key;
// change device by the device_guard()
ApplyDeviceGuard(op, place, &expected_kernel_key);
......
......@@ -58,6 +58,10 @@ class DenseTensor;
#include "paddle/fluid/platform/device/mlu/mlu_info.h"
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#endif
DECLARE_bool(benchmark);
DECLARE_bool(check_nan_inf);
DECLARE_bool(enable_unused_var_check);
......@@ -1358,6 +1362,39 @@ bool OperatorWithKernel::SupportsMKLDNN(
}
}
bool OperatorWithKernel::SupportsCUDNN(
const proto::VarType::Type data_type) const {
auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
phi::TransToPhiKernelName(type_));
paddle::experimental::DataType phi_data_type =
framework::TransToPhiDataType(data_type);
auto has_phi_kernel = std::any_of(
phi_kernels.begin(),
phi_kernels.end(),
[phi_data_type](phi::KernelKeyMap::const_reference kern_pair) {
return kern_pair.first.backend() == phi::Backend::GPUDNN &&
kern_pair.first.dtype() == phi_data_type;
});
if (has_phi_kernel) {
return true;
} else {
auto op_kernel_iter = OperatorWithKernel::AllOpKernels().find(type_);
if (op_kernel_iter == OperatorWithKernel::AllOpKernels().end()) {
return false;
} else {
auto& op_kernels = op_kernel_iter->second;
return std::any_of(
op_kernels.begin(),
op_kernels.end(),
[data_type](OpKernelMap::const_reference kern_pair) {
return platform::is_gpu_place(kern_pair.first.place_) &&
kern_pair.first.library_type_ == LibraryType::kCUDNN &&
kern_pair.first.data_type_ == data_type;
});
}
}
}
bool OperatorWithKernel::SupportsKernelType(
const OpKernelType& kernel_type, const ExecutionContext& exe_ctx) const {
auto& all_op_kernels = AllOpKernels();
......@@ -1409,17 +1446,49 @@ bool OperatorWithKernel::SupportsKernelType(
}
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (this->CanCUDNNBeUsed(exe_ctx, kernel_type.data_type_)) {
auto tmp_kernel_type = kernel_type;
tmp_kernel_type.library_type_ = framework::LibraryType::kCUDNN;
return kernels.find(tmp_kernel_type) != kernels.end();
}
#endif
return kernel_iter != kernels.end();
}
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const {
const std::string use_mkldnn_attr = "use_mkldnn";
return ctx.HasAttr(use_mkldnn_attr) && ctx.Attr<bool>(use_mkldnn_attr) &&
return ctx.HasAttr("use_mkldnn") && ctx.Attr<bool>("use_mkldnn") &&
platform::is_cpu_place(ctx.GetPlace()) &&
this->SupportsMKLDNN(data_type);
}
bool OperatorWithKernel::CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const {
bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn") &&
paddle::platform::is_gpu_place(ctx.GetPlace());
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (use_cudnn) {
auto& dev_ctx = ctx.device_context<phi::GPUContext>();
use_cudnn &= (dev_ctx.cudnn_handle() != nullptr);
}
#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
#if defined(PADDLE_WITH_CUDA)
if (use_cudnn && data_type == framework::proto::VarType::BF16) {
PADDLE_ENFORCE_GE(
platform::DnnVersion(),
8100,
platform::errors::InvalidArgument(
"bfloat16 can only be used when CUDNN_VERSION >= 8100"));
}
#endif // PADDLE_WITH_CUDA
return use_cudnn && this->SupportsCUDNN(data_type);
}
void OperatorWithKernel::InferShape(InferShapeContext* ctx) const {
PADDLE_THROW(platform::errors::PermissionDenied(
"The default InferShape function of OperatorWithKernel is not allowed to "
......@@ -1583,6 +1652,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (this->CanCUDNNBeUsed(exe_ctx, kernel_type_->data_type_)) {
kernel_type_->library_type_ = framework::LibraryType::kCUDNN;
}
#endif
// NOTE(Liu-xiandong):In my ctest, this branch do not be executed,
// I can't understand it, it's really confusing.
// But we still need to keep this to avoid errors.
......@@ -1826,6 +1901,12 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
}
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (this->CanCUDNNBeUsed(ctx, expected_kernel_key.data_type_)) {
expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN;
}
#endif
if (HasAttr("op_device")) {
if (Attr<std::string>("op_device") == "cpu") {
expected_kernel_key.place_ = platform::CPUPlace();
......
......@@ -632,12 +632,17 @@ class OperatorWithKernel : public OperatorBase {
bool SupportsMKLDNN(proto::VarType::Type data_type) const;
bool SupportsCUDNN(proto::VarType::Type data_type) const;
bool SupportsKernelType(const OpKernelType& kernel_type,
const ExecutionContext& exe_ctx) const;
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const;
bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const;
virtual void InferShape(InferShapeContext* ctx) const;
void RuntimeInferShape(const Scope& scope,
......
......@@ -246,6 +246,12 @@ PreparedOp PrepareImpl(
}
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (op.CanCUDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.data_type_)) {
expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN;
}
#endif
#if defined(PADDLE_WITH_XPU)
bool is_xpu_unsupport =
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
namespace paddle {
namespace operators {
using platform::ActivationDescriptor;
using platform::TensorDescriptor;
template <typename Functor>
class CudnnActivationKernel
: public framework::OpKernel<Functor::ElEWISE_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
phi::DenseTensor *X, *Out;
ExtractActivationTensor(context, X, Out);
ActivationDescriptor act_desc;
TensorDescriptor x_desc, out_desc;
x_desc.set(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivation"));
out_desc.set(GET_DATA_SAFELY(Out, "Output", "Out", "CudnnActivation");
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
namespace paddle {
namespace operators {
using phi::GPUContext;
using platform::ActivationDescriptor;
using platform::TensorDescriptor;
#ifdef PADDLE_WITH_HIP
#define GPUDNN_ACTIVATION_RELU miopenActivationRELU
#define GPUDNN_ACTIVATION_CLIPPED_RELU miopenActivationCLIPPEDRELU
#define GPUDNN_ACTIVATION_SIGMOID miopenActivationLOGISTIC
#define GPUDNN_ACTIVATION_TANH miopenActivationTANH
#else
#define GPUDNN_ACTIVATION_RELU CUDNN_ACTIVATION_RELU
#define GPUDNN_ACTIVATION_CLIPPED_RELU CUDNN_ACTIVATION_CLIPPED_RELU
#define GPUDNN_ACTIVATION_SIGMOID CUDNN_ACTIVATION_SIGMOID
#define GPUDNN_ACTIVATION_TANH CUDNN_ACTIVATION_TANH
#endif
template <typename T>
struct CudnnActivationFunctor {
using ELEMENT_TYPE = T;
#ifdef PADDLE_WITH_HIP
CudnnActivationFunctor(const phi::GPUContext& ctx,
const T& c,
const miopenActivationMode_t& m)
: ctx_(ctx), coef_(c), mode_(m) {}
#else
CudnnActivationFunctor(const phi::GPUContext& ctx,
const T& c,
const cudnnActivationMode_t& m)
: ctx_(ctx), coef_(c), mode_(m) {}
#endif
void operator()(const phi::DenseTensor& x, phi::DenseTensor* out) {
ActivationDescriptor act_desc;
act_desc.set(mode_, coef_);
TensorDescriptor x_desc, out_desc;
x_desc.set(x);
out_desc.set(GET_DATA_SAFELY(out, "Output", "Out", "CudnnActivation"));
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenActivationForward(
ctx_.cudnn_handle(),
act_desc.desc(),
platform::CudnnDataType<T>::kOne(),
x_desc.desc(),
x.data<T>(),
platform::CudnnDataType<T>::kZero(),
out_desc.desc(),
out->mutable_data<T>(ctx_.GetPlace())));
#else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnActivationForward(
ctx_.cudnn_handle(),
act_desc.desc(),
platform::CudnnDataType<T>::kOne(),
x_desc.desc(),
x.data<T>(),
platform::CudnnDataType<T>::kZero(),
out_desc.desc(),
out->mutable_data<T>(ctx_.GetPlace())));
#endif
}
const phi::GPUContext& ctx_;
const T coef_;
#ifdef PADDLE_WITH_HIP
const miopenActivationMode_t mode_;
#else
const cudnnActivationMode_t mode_;
#endif
};
template <typename T>
struct CudnnActivationGradFunctor {
using ELEMENT_TYPE = T;
#ifdef PADDLE_WITH_HIP
CudnnActivationGradFunctor(const phi::GPUContext& ctx,
const T& c,
const miopenActivationMode_t& m)
: ctx_(ctx), coef_(c), mode_(m) {}
#else
CudnnActivationGradFunctor(const phi::GPUContext& ctx,
const T& c,
const cudnnActivationMode_t& m)
: ctx_(ctx), coef_(c), mode_(m) {}
#endif
void operator()(const phi::DenseTensor& x,
const phi::DenseTensor& out,
const phi::DenseTensor dout,
phi::DenseTensor* dx) {
ActivationDescriptor act_desc;
act_desc.set(mode_, coef_);
TensorDescriptor x_desc, out_desc, dout_desc, dx_desc;
x_desc.set(x);
out_desc.set(out);
dout_desc.set(dout);
dx_desc.set(GET_DATA_SAFELY(dx, "Output", "X@GRAD", "CudnnActivationGrad"));
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenActivationBackward(
ctx_.cudnn_handle(),
act_desc.desc(),
platform::CudnnDataType<T>::kOne(),
out_desc.desc(),
out.data<T>(),
dout_desc.desc(),
dout.data<T>(),
x_desc.desc(),
x.data<T>(),
platform::CudnnDataType<T>::kZero(),
dx_desc.desc(),
dx->mutable_data<T>(ctx_.GetPlace())));
#else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnActivationBackward(
ctx_.cudnn_handle(),
act_desc.desc(),
platform::CudnnDataType<T>::kOne(),
out_desc.desc(),
out.data<T>(),
dout_desc.desc(),
dout.data<T>(),
x_desc.desc(),
x.data<T>(),
platform::CudnnDataType<T>::kZero(),
dx_desc.desc(),
dx->mutable_data<T>(ctx_.GetPlace())));
#endif
}
const phi::GPUContext& ctx_;
const T coef_;
#ifdef PADDLE_WITH_HIP
const miopenActivationMode_t mode_;
#else
const cudnnActivationMode_t mode_;
#endif
};
template <typename T>
struct CudnnReluFunctor : public CudnnActivationFunctor<T> {
explicit CudnnReluFunctor(const phi::GPUContext& ctx)
: CudnnActivationFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_RELU) {}
};
template <typename T>
struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnReluGradFunctor(const phi::GPUContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_RELU) {}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudnnRelu6Functor : public CudnnActivationFunctor<T> {
explicit CudnnRelu6Functor(const phi::GPUContext& ctx)
: CudnnActivationFunctor<T>(ctx, 6.0, GPUDNN_ACTIVATION_CLIPPED_RELU) {}
};
template <typename T>
struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnRelu6GradFunctor(const phi::GPUContext& ctx)
: CudnnActivationGradFunctor<T>(
ctx, 6.0, GPUDNN_ACTIVATION_CLIPPED_RELU) {}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudnnSigmoidFunctor : public CudnnActivationFunctor<T> {
explicit CudnnSigmoidFunctor(const phi::GPUContext& ctx)
: CudnnActivationFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_SIGMOID) {}
};
template <typename T>
struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnSigmoidGradFunctor(const phi::GPUContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_SIGMOID) {}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudnnTanhFunctor : public CudnnActivationFunctor<T> {
explicit CudnnTanhFunctor(const phi::GPUContext& ctx)
: CudnnActivationFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_TANH) {}
};
template <typename T>
struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnTanhGradFunctor(const phi::GPUContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_TANH) {}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename Functor>
class CudnnActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
const phi::DenseTensor* X = nullptr;
phi::DenseTensor* Out = nullptr;
ExtractActivationTensor(context, &X, &Out);
Out->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<phi::GPUContext>();
Functor functor(dev_ctx);
functor(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivation"), Out);
}
};
template <typename Functor>
class CudnnActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
static_assert(Functor::FwdDeps() == ActBwdOpFwdDeps::kDepOut,
"Forward deps must be Out.");
const phi::DenseTensor *X, *Out, *dOut;
X = Out = dOut = nullptr;
phi::DenseTensor* dX = nullptr;
ExtractActivationGradTensor<Functor::FwdDeps()>(
context, &X, &Out, &dOut, &dX);
dX->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<phi::GPUContext>();
Functor functor(dev_ctx);
functor(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivationGrad"),
GET_DATA_SAFELY(Out, "Input", "Out", "CudnnActivationGrad"),
GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "CudnnActivationGrad"),
dX);
}
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
namespace ops = paddle::operators;
#define FOR_EACH_CUDNN_OP_FUNCTOR(__macro) \
__macro(relu, CudnnReluFunctor, CudnnReluGradFunctor); \
__macro(relu6, CudnnRelu6Functor, CudnnRelu6GradFunctor); \
__macro(sigmoid, CudnnSigmoidFunctor, CudnnSigmoidGradFunctor); \
__macro(tanh, CudnnTanhFunctor, CudnnTanhGradFunctor)
#ifdef PADDLE_WITH_HIP
#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_KERNEL(act_type, \
CUDNN, \
plat::CUDAPlace, \
ops::CudnnActivationKernel<ops::functor<float>>); \
REGISTER_OP_KERNEL( \
act_type##_grad, \
CUDNN, \
plat::CUDAPlace, \
ops::CudnnActivationGradKernel<ops::grad_functor<float>>);
#else
#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_KERNEL(act_type, \
CUDNN, \
plat::CUDAPlace, \
ops::CudnnActivationKernel<ops::functor<float>>, \
ops::CudnnActivationKernel<ops::functor<double>>); \
REGISTER_OP_KERNEL( \
act_type##_grad, \
CUDNN, \
plat::CUDAPlace, \
ops::CudnnActivationGradKernel<ops::grad_functor<float>>, \
ops::CudnnActivationGradKernel<ops::grad_functor<double>>);
#endif
FOR_EACH_CUDNN_OP_FUNCTOR(REGISTER_ACTIVATION_CUDNN_KERNEL);
......@@ -134,15 +134,8 @@ class AffineGridOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
library = framework::LibraryType::kCUDNN;
}
#endif
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta");
return framework::OpKernelType(
data_type, ctx.GetPlace(), phi::DataLayout::kAnyLayout, library);
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
......@@ -252,17 +245,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Output")),
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
library_);
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Output"));
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
......
......@@ -209,24 +209,6 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
paddle::framework::DataTypeToString(filter_data_type)));
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
#if PADDLE_WITH_CUDA
if (input_data_type == framework::proto::VarType::BF16) {
PADDLE_ENFORCE_GE(
platform::DnnVersion(),
8100,
platform::errors::InvalidArgument(
"bfloat16 can only be used when CUDNN_VERSION >= 8100"));
}
#endif // PADDLE_WITH_CUDA
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -476,16 +458,6 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
......@@ -657,14 +629,6 @@ void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
......
......@@ -28,9 +28,6 @@ limitations under the License. */
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#endif
namespace paddle {
namespace operators {
......@@ -40,14 +37,6 @@ using DataLayout = phi::DataLayout;
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
......@@ -268,14 +257,6 @@ Example:
framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
......@@ -343,14 +324,6 @@ class ConvTransposeDoubleGradMaker : public framework::SingleGradOpMaker<T> {
framework::OpKernelType ConvTransposeOpDoubleGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(data_type,
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(data_type, ctx.GetPlace());
}
......
......@@ -35,17 +35,8 @@ class GridSampleOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
library_);
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
......@@ -146,17 +137,8 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace(),
phi::DataLayout::kAnyLayout,
library_);
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
......
......@@ -44,21 +44,13 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) {
framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
this->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
// NOTE(jiahongyu) END: Above codes originally enclosed by PADDLE_WITH_MKLDNN
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_);
return framework::OpKernelType(data_type, ctx.GetPlace());
}
framework::OpKernelType PoolOp::GetKernelTypeForVar(
......@@ -86,22 +78,13 @@ framework::OpKernelType PoolOp::GetKernelTypeForVar(
framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
phi::DataLayout layout_ = phi::DataLayout::kAnyLayout;
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
this->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
// NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN
return framework::OpKernelType(
input_data_type, ctx.GetPlace(), layout_, library_);
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
framework::OpKernelType PoolOpGrad::GetKernelTypeForVar(
......
......@@ -43,14 +43,6 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
if (ctx.HasAttr("data_format")) {
layout_ = phi::StringToDataLayout(ctx.Attr<std::string>("data_format"));
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
layout_,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
}
};
......@@ -135,14 +127,6 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
if (ctx.HasAttr("data_format")) {
layout_ = phi::StringToDataLayout(ctx.Attr<std::string>("data_format"));
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
layout_,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
}
};
......
......@@ -48,14 +48,6 @@ class SoftmaxOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"float16 can only be used on GPU/NPU/XPU/MLU and custom place"));
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
layout_,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
}
};
......@@ -140,14 +132,6 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
PADDLE_THROW(platform::errors::InvalidArgument(
"float16 can only be used on GPU/NPU/XPU/MLU and custom place"));
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::CanCUDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
layout_,
framework::LibraryType::kCUDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_);
}
};
......
......@@ -362,7 +362,6 @@ register_unity_group(
lstm_op.cu.cc
rnn_op.cu.cc
split_op.cu.cc
activation_cudnn_op.cu.cc
assign_value_op.cu.cc
run_program_op.cu.cc
warpctc_op.cu.cc)
......
......@@ -616,18 +616,6 @@ class ScopedActivationDescriptor {
DISABLE_COPY_AND_ASSIGN(ScopedActivationDescriptor);
};
inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn");
use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA
if (use_cudnn) {
auto& dev_ctx = ctx.device_context<phi::GPUContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
}
#endif
return use_cudnn;
}
#if CUDNN_VERSION >= 7001
class ScopedCTCLossDescriptor {
public:
......
......@@ -553,18 +553,6 @@ class ScopedActivationDescriptor {
DISABLE_COPY_AND_ASSIGN(ScopedActivationDescriptor);
};
inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
bool use_cudnn = ctx.HasAttr("use_cudnn") && ctx.Attr<bool>("use_cudnn");
use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_HIP
if (use_cudnn) {
auto& dev_ctx = ctx.device_context<phi::GPUContext>();
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
}
#endif
return use_cudnn;
}
class ScopedCTCLossDescriptor {
public:
ScopedCTCLossDescriptor() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册