Unverified commit 64c5c8f8, authored by Kexin Zhao, committed by GitHub

Merge pull request #9269 from kexinzhao/softmax_cudnn_fp16

Add float16 support to cudnn softmax kernel
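Before the diff itself, a brief numeric aside (not part of the commit): the sketch below compares a max-shifted softmax computed in float16 against a float32 reference on the same [10, 10] uniform input the tests use, which is why the new fp16 test can compare against a float32-style reference with atol=1e-3. The helper mirrors the stable_softmax used by the Python tests; its exact body here is an assumption.

# Hedged sketch, not part of the commit: why atol=1e-3 is a reasonable
# tolerance for a half-precision softmax against a float32 reference.
import numpy as np

def stable_softmax(x):
    # Max-shifted softmax; assumed to match the test helper of the same name.
    shiftx = x - np.max(x)
    exps = np.exp(shiftx)
    return exps / np.sum(exps)

x32 = np.random.uniform(0.1, 1, [10, 10]).astype(np.float32)
x16 = x32.astype(np.float16)

out32 = np.apply_along_axis(stable_softmax, 1, x32)
out16 = np.apply_along_axis(stable_softmax, 1, x16)

# The half-precision result typically stays within about 1e-3 of the fp32 one.
print(np.abs(out32 - out16.astype(np.float32)).max())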
@@ -89,6 +89,7 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
       XGrad->mutable_data<T>(context.GetPlace())));
 }
 
+template class SoftmaxCUDNNFunctor<platform::float16>;
 template class SoftmaxCUDNNFunctor<float>;
 template class SoftmaxCUDNNFunctor<double>;
 template class SoftmaxGradCUDNNFunctor<float>;
...
@@ -56,7 +56,9 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_KERNEL(softmax, CUDNN, ::paddle::platform::CUDAPlace,
-                   ops::SoftmaxCUDNNKernel<float>);
-REGISTER_OP_KERNEL(softmax_grad, CUDNN, ::paddle::platform::CUDAPlace,
+namespace plat = paddle::platform;
+REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
+                   ops::SoftmaxCUDNNKernel<float>,
+                   ops::SoftmaxCUDNNKernel<plat::float16>);
+REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
                    ops::SoftmaxGradCUDNNKernel<float>);
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/softmax_op.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cudnn_helper.h"
+#endif
 
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
@@ -41,29 +44,30 @@ class SoftmaxOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     // choose cudnn kernel if the runtime supported.
-    bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-    bool runtime_cudnn_support = false;
+    framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (platform::is_gpu_place(ctx.GetPlace())) {
-      auto& dev_ctx =
-          ctx.template device_context<platform::CUDADeviceContext>();
-      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
-    }
-#endif
-    framework::LibraryType library_ = framework::LibraryType::kPlain;
-    if (use_cudnn && runtime_cudnn_support) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
+#endif
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
         platform::CanMKLDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kMKLDNN;
     }
 #endif
+
+    auto input_data_type =
+        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    if (input_data_type == framework::proto::VarType::FP16) {
+      PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
+                        "float16 can only be used when CUDNN is used");
+    }
+
     std::string data_format = ctx.Attr<std::string>("data_format");
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
-        framework::StringToDataLayout(data_format), library_);
+    return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                   framework::StringToDataLayout(data_format),
+                                   library_);
   }
 };
 
 class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -130,19 +134,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     // choose cudnn kernel if the runtime supported.
-    bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-    bool runtime_cudnn_support = false;
+    framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (platform::is_gpu_place(ctx.GetPlace())) {
-      auto& dev_ctx =
-          ctx.template device_context<platform::CUDADeviceContext>();
-      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
-    }
-#endif
-    framework::LibraryType library_ = framework::LibraryType::kPlain;
-    if (use_cudnn && runtime_cudnn_support) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
+#endif
     std::string data_format = ctx.Attr<std::string>("data_format");
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
...
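As a reading aid only (not part of the change), the forward op's kernel-selection rule above can be paraphrased in plain Python. The function and its names below are illustrative, not a PaddlePaddle API; the compile-time #ifdef guards are folded into the boolean arguments.

# Illustrative paraphrase of SoftmaxOp::GetExpectedKernelType after this change.
def expected_softmax_library(dtype, cudnn_usable, mkldnn_usable):
    library = "plain"
    if cudnn_usable:        # corresponds to platform::CanCUDNNBeUsed(ctx)
        library = "cudnn"
    elif mkldnn_usable:     # corresponds to platform::CanMKLDNNBeUsed(ctx)
        library = "mkldnn"
    if dtype == "float16" and library != "cudnn":
        # mirrors the PADDLE_ENFORCE_EQ added in the diff
        raise ValueError("float16 can only be used when CUDNN is used")
    return library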
@@ -27,22 +27,22 @@ def stable_softmax(x):
 
 class TestSoftmaxOp(OpTest):
     def setUp(self):
-        self.use_mkldnn = False
         self.op_type = "softmax"
         self.use_cudnn = False
-        self.init_op_type()
-        self.inputs = {
-            'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32")
-        }
-        self.outputs = {
-            'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
-        }
+        self.use_mkldnn = False
+        self.dtype = np.float32
+        self.init_kernel_type()
+
+        x = np.random.uniform(0.1, 1, [10, 10]).astype(self.dtype)
+        out = np.apply_along_axis(stable_softmax, 1, x)
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
         self.attrs = {
             'use_cudnn': self.use_cudnn,
             'use_mkldnn': self.use_mkldnn
         }
 
-    def init_op_type(self):
+    def init_kernel_type(self):
         pass
 
     def test_check_output(self):
@@ -53,6 +53,8 @@ class TestSoftmaxOp(OpTest):
             self.check_output()
 
     def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
         if self.use_cudnn:
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
@@ -62,12 +64,24 @@ class TestSoftmaxOp(OpTest):
 
 class TestSoftmaxCUDNNOp(TestSoftmaxOp):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_cudnn = True
 
 
-class TestMKLDNN(TestSoftmaxOp):
-    def init_op_type(self):
+class TestSoftmaxFP16CUDNNOp(TestSoftmaxOp):
+    def init_kernel_type(self):
+        self.use_cudnn = True
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=1e-3)
+
+
+class TestSoftmaxMKLDNNOp(TestSoftmaxOp):
+    def init_kernel_type(self):
         self.use_mkldnn = True
...
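A final usage note, also not part of the diff: the snippet below shows one way to run only the new half-precision case with the standard unittest runner. The module name test_softmax_op is an assumption, since the diff above does not show the test file's name.

# Hedged example: run just TestSoftmaxFP16CUDNNOp from a Python session.
import unittest
from test_softmax_op import TestSoftmaxFP16CUDNNOp  # module name assumed

suite = unittest.TestLoader().loadTestsFromTestCase(TestSoftmaxFP16CUDNNOp)
unittest.TextTestRunner(verbosity=2).run(suite)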