Commit 70e71227 authored by Kexin Zhao

initial commit

Parent c1e9b1e3
@@ -89,6 +89,7 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
       XGrad->mutable_data<T>(context.GetPlace())));
 }
 
+template class SoftmaxCUDNNFunctor<platform::float16>;
 template class SoftmaxCUDNNFunctor<float>;
 template class SoftmaxCUDNNFunctor<double>;
 template class SoftmaxGradCUDNNFunctor<float>;
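The added line explicitly instantiates the CUDNN softmax functor for platform::float16, alongside the existing float and double instantiations. Explicit instantiation is needed here because the template's definition lives in a .cu file rather than a header; a minimal generic sketch of the pattern, with illustrative names that are not Paddle's:

// scale.h -- declaration only; the definition lives in scale.cu
template <typename T>
struct ScaleFunctor {
  void operator()(T* data, int n, T factor);
};

// scale.cu -- definition plus explicit instantiations. Without the
// "template struct" lines below, other translation units that use
// ScaleFunctor<float> or ScaleFunctor<double> would compile but fail
// to link, because no object file would contain the instantiated code.
template <typename T>
void ScaleFunctor<T>::operator()(T* data, int n, T factor) {
  for (int i = 0; i < n; ++i) data[i] *= factor;
}

template struct ScaleFunctor<float>;
template struct ScaleFunctor<double>;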
@@ -56,7 +56,9 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
-REGISTER_OP_KERNEL(softmax, CUDNN, ::paddle::platform::CUDAPlace,
-                   ops::SoftmaxCUDNNKernel<float>);
-REGISTER_OP_KERNEL(softmax_grad, CUDNN, ::paddle::platform::CUDAPlace,
+REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
+                   ops::SoftmaxCUDNNKernel<float>,
+                   ops::SoftmaxCUDNNKernel<plat::float16>);
+REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
                    ops::SoftmaxGradCUDNNKernel<float>);
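REGISTER_OP_KERNEL now lists two kernel classes under the single (softmax, CUDNN, CUDAPlace) key, so the framework can pick the float or float16 kernel from the input tensor's dtype at run time. A toy sketch of that dispatch idea, assuming nothing about Paddle's real registry:

#include <functional>
#include <map>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <utility>

// One entry per (op name, element type); lookup keys on the dtype.
using KernelFn = std::function<void()>;
static std::map<std::pair<std::string, std::type_index>, KernelFn> g_kernels;

template <typename T>
void RegisterKernel(const std::string& op, KernelFn fn) {
  g_kernels[{op, std::type_index(typeid(T))}] = std::move(fn);
}

template <typename T>
void Run(const std::string& op) {
  g_kernels.at({op, std::type_index(typeid(T))})();  // throws if unregistered
}

int main() {
  RegisterKernel<float>("softmax", [] { /* float path */ });
  RegisterKernel<short>("softmax", [] { /* fp16 path; fp16 is often carried in a 16-bit type */ });
  Run<float>("softmax");  // the element type selects the kernel
}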
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/softmax_op.h"
+
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cudnn_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -38,19 +41,12 @@ class SoftmaxOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     // choose cudnn kernel if the runtime supported.
-    bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-    bool runtime_cudnn_support = false;
+    framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (platform::is_gpu_place(ctx.GetPlace())) {
-      auto& dev_ctx =
-          ctx.template device_context<platform::CUDADeviceContext>();
-      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kCUDNN;
     }
 #endif
-    framework::LibraryType library_ = framework::LibraryType::kPlain;
-    if (use_cudnn && runtime_cudnn_support) {
-      library_ = framework::LibraryType::kCUDNN;
-    }
     std::string data_format = ctx.Attr<std::string>("data_format");
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
@@ -119,19 +115,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     // choose cudnn kernel if the runtime supported.
-    bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-    bool runtime_cudnn_support = false;
+    framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (platform::is_gpu_place(ctx.GetPlace())) {
-      auto& dev_ctx =
-          ctx.template device_context<platform::CUDADeviceContext>();
-      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kCUDNN;
     }
 #endif
-    framework::LibraryType library_ = framework::LibraryType::kPlain;
-    if (use_cudnn && runtime_cudnn_support) {
-      library_ = framework::LibraryType::kCUDNN;
-    }
     std::string data_format = ctx.Attr<std::string>("data_format");
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
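Both GetExpectedKernelType bodies (forward and grad) now delegate to platform::CanCUDNNBeUsed(ctx). This is not the actual helper's source, but judging only from the lines this commit deletes, it presumably folds the use_cudnn attribute check and the runtime handle check into one predicate, along these lines:

// Hedged reconstruction of the logic the deleted lines performed,
// folded into a single predicate as the new call sites suggest.
inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
  use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA
  if (use_cudnn) {
    auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
  }
#endif
  return use_cudnn;
}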
@@ -29,15 +29,16 @@ class TestSoftmaxOp(OpTest):
     def setUp(self):
         self.op_type = "softmax"
         self.use_cudnn = False
-        self.inputs = {
-            'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32")
-        }
-        self.outputs = {
-            'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
-        }
+        self.dtype = np.float32
+        self.init_kernel_type()
+
+        x = np.random.uniform(0.1, 1, [10, 10]).astype(self.dtype)
+        out = np.apply_along_axis(stable_softmax, 1, x)
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.outputs = {'Out': out}
         self.attrs = {'use_cudnn': self.use_cudnn, }
 
-    def init_op_type(self):
+    def init_kernel_type(self):
         pass
 
     def test_check_output(self):
@@ -48,6 +49,8 @@ class TestSoftmaxOp(OpTest):
         self.check_output()
 
     def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
         if self.use_cudnn:
             place = core.CUDAPlace(0)
             self.check_grad_with_place(
@@ -57,8 +60,20 @@ class TestSoftmaxOp(OpTest):
 
 class TestSoftmaxCUDNNOp(TestSoftmaxOp):
-    def init_op_type(self):
+    def init_kernel_type(self):
         self.use_cudnn = True
 
 
+class TestSoftmaxFP16CUDNNOp(TestSoftmaxOp):
+    def init_kernel_type(self):
+        self.use_cudnn = True
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_output_with_place(place, atol=1e-3)
+
 if __name__ == "__main__":
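The new fp16 test compares against the float32 reference softmax with atol=1e-3 and skips the gradient check entirely. That tolerance matches the precision fp16 can deliver: with 10 stored mantissa bits, its machine epsilon is 2^-10 ≈ 9.8e-4. A small standalone demonstration that simulates fp16 rounding (exponent range and subnormals ignored):

#include <cmath>
#include <cstdio>

// Round a float to the nearest value representable with a 10-bit mantissa,
// a rough stand-in for IEEE fp16.
float RoundToFp16Mantissa(float x) {
  int exp;
  float frac = std::frexp(x, &exp);             // x = frac * 2^exp, frac in [0.5, 1)
  frac = std::round(frac * 2048.0f) / 2048.0f;  // keep 11 significant bits (1 implicit + 10 stored)
  return std::ldexp(frac, exp);
}

int main() {
  float x = 0.3333333f;
  float q = RoundToFp16Mantissa(x);
  // |err| lands on the order of 1e-4 to 1e-3, which is why the fp16 softmax
  // test checks outputs against the float32 reference with atol=1e-3.
  std::printf("x=%.7f  fp16-ish=%.7f  |err|=%g\n", x, q, std::fabs(x - q));
  return 0;
}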