提交 70e71227 编写于 作者: K Kexin Zhao

initial commit

上级 c1e9b1e3
......@@ -89,6 +89,7 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
XGrad->mutable_data<T>(context.GetPlace())));
}
template class SoftmaxCUDNNFunctor<platform::float16>;
template class SoftmaxCUDNNFunctor<float>;
template class SoftmaxCUDNNFunctor<double>;
template class SoftmaxGradCUDNNFunctor<float>;
......
......@@ -56,7 +56,9 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(softmax, CUDNN, ::paddle::platform::CUDAPlace,
ops::SoftmaxCUDNNKernel<float>);
REGISTER_OP_KERNEL(softmax_grad, CUDNN, ::paddle::platform::CUDAPlace,
namespace plat = paddle::platform;
REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
ops::SoftmaxCUDNNKernel<float>,
ops::SoftmaxCUDNNKernel<plat::float16>);
REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
ops::SoftmaxGradCUDNNKernel<float>);
......@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/softmax_op.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle {
namespace operators {
......@@ -38,19 +41,12 @@ class SoftmaxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// choose cudnn kernel if the runtime supported.
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
bool runtime_cudnn_support = false;
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
if (platform::CanCUDNNBeUsed(ctx)) {
library = framework::LibraryType::kCUDNN;
}
#endif
framework::LibraryType library_ = framework::LibraryType::kPlain;
if (use_cudnn && runtime_cudnn_support) {
library_ = framework::LibraryType::kCUDNN;
}
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
......@@ -119,19 +115,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
// choose cudnn kernel if the runtime supported.
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
bool runtime_cudnn_support = false;
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
if (platform::CanCUDNNBeUsed(ctx)) {
library = framework::LibraryType::kCUDNN;
}
#endif
framework::LibraryType library_ = framework::LibraryType::kPlain;
if (use_cudnn && runtime_cudnn_support) {
library_ = framework::LibraryType::kCUDNN;
}
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
......
......@@ -29,15 +29,16 @@ class TestSoftmaxOp(OpTest):
def setUp(self):
self.op_type = "softmax"
self.use_cudnn = False
self.inputs = {
'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32")
}
self.outputs = {
'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
}
self.dtype = np.float32
self.init_kernel_type()
x = np.random.uniform(0.1, 1, [10, 10]).astype(self.dtype)
out = np.apply_along_axis(stable_softmax, 1, x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {'use_cudnn': self.use_cudnn, }
def init_op_type(self):
def init_kernel_type(self):
pass
def test_check_output(self):
......@@ -48,6 +49,8 @@ class TestSoftmaxOp(OpTest):
self.check_output()
def test_check_grad(self):
if self.dtype == np.float16:
return
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
......@@ -57,8 +60,20 @@ class TestSoftmaxOp(OpTest):
class TestSoftmaxCUDNNOp(TestSoftmaxOp):
def init_op_type(self):
def init_kernel_type(self):
self.use_cudnn = True
class TestSoftmaxFP16CUDNNOp(TestSoftmaxOp):
def init_kernel_type(self):
self.use_cudnn = True
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-3)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册