Unverified commit 8d8527fb, authored by Yibing Liu, committed by GitHub

register fp16 kernel for some ops (#22650)

test=release/1.7
Parent 5d96b6e0
@@ -261,7 +261,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
     int output_offset =
         transformed_output.numel() / transformed_output.dims()[0] / groups;
     int filter_offset = filter->numel() / groups;
-    T alpha = 1.0f, beta = 0.0f;
+    T alpha = static_cast<T>(1.0), beta = static_cast<T>(0.0);
     auto workspace_handle = dev_ctx.cudnn_workspace_handle();
     for (int g = 0; g < groups; g++) {
       auto cudnn_func = [&](void* cudnn_workspace) {
@@ -507,7 +507,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
     int output_grad_offset = transformed_output_grad.numel() /
                              transformed_output_grad.dims()[0] / groups;
     int filter_offset = filter->numel() / groups;
-    T alpha = 1.0f, beta = 0.0f;
+    T alpha = static_cast<T>(1.0), beta = static_cast<T>(0.0);
     auto workspace_handle = dev_ctx.cudnn_workspace_handle();
     if (input_grad) {
       T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
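Note on the two hunks above: they swap the float literals 1.0f / 0.0f for static_cast<T>(...) so the cuDNN scaling parameters still compile once the kernel is instantiated at plat::float16. A minimal, self-contained sketch of the failure mode (the float16 below is a stand-in, assuming, as Paddle's plat::float16 appears to, an explicit converting constructor from float):

// Sketch only: float16 here is a stand-in for plat::float16, assumed to
// declare an explicit converting constructor from float.
#include <cstdint>

struct float16 {
  uint16_t x = 0;
  float16() = default;
  explicit float16(float) {}  // explicit: no implicit float -> float16
};

template <typename T>
void InitScaling() {
  // T alpha = 1.0f;              // breaks for T = float16: copy-initialization
  //                              // cannot call an explicit constructor
  T alpha = static_cast<T>(1.0);  // direct-initialization works for all types
  T beta = static_cast<T>(0.0);
  (void)alpha;
  (void)beta;
}

int main() {
  InitScaling<float>();
  InitScaling<double>();
  InitScaling<float16>();  // compiles only with the static_cast form
}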
@@ -569,17 +569,22 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_KERNEL(conv2d_transpose, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::CUDNNConvTransposeOpKernel<plat::float16>,
                    ops::CUDNNConvTransposeOpKernel<float>,
                    ops::CUDNNConvTransposeOpKernel<double>);
 REGISTER_OP_KERNEL(conv2d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::CUDNNConvTransposeGradOpKernel<plat::float16>,
                    ops::CUDNNConvTransposeGradOpKernel<float>,
                    ops::CUDNNConvTransposeGradOpKernel<double>);
 REGISTER_OP_KERNEL(conv3d_transpose, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::CUDNNConvTransposeOpKernel<plat::float16>,
                    ops::CUDNNConvTransposeOpKernel<float>,
                    ops::CUDNNConvTransposeOpKernel<double>);
 REGISTER_OP_KERNEL(conv3d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::CUDNNConvTransposeGradOpKernel<plat::float16>,
                    ops::CUDNNConvTransposeGradOpKernel<float>,
                    ops::CUDNNConvTransposeGradOpKernel<double>);
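The registration hunks in this commit all follow one pattern: each REGISTER_OP_KERNEL / REGISTER_OP_CUDA_KERNEL call lists one kernel instantiation per supported element type, and a dtype that is never listed simply cannot be dispatched on that backend. A rough sketch of the idea (an illustration only, not Paddle's actual registry machinery; every name in it is invented for the example):

// Toy dtype-keyed kernel registry; not Paddle's implementation.
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

enum class DType { kFP16, kFP32, kFP64 };
using Kernel = std::function<void()>;

std::map<std::pair<std::string, DType>, Kernel>& Registry() {
  static std::map<std::pair<std::string, DType>, Kernel> r;
  return r;
}

void Register(const std::string& op, DType t, Kernel k) {
  Registry()[{op, t}] = std::move(k);
}

void Run(const std::string& op, DType t) {
  auto it = Registry().find({op, t});
  if (it == Registry().end())
    throw std::runtime_error("no kernel registered for this op/dtype");
  it->second();  // dispatch to the instantiation registered for this dtype
}

int main() {
  Register("conv2d_transpose", DType::kFP32, [] { /* float kernel */ });
  Register("conv2d_transpose", DType::kFP64, [] { /* double kernel */ });
  Run("conv2d_transpose", DType::kFP32);     // ok
  // Run("conv2d_transpose", DType::kFP16);  // throws until an fp16
  //                                         // instantiation is registered,
  //                                         // which is what this PR adds
}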
@@ -14,9 +14,12 @@ limitations under the License. */
 #include "paddle/fluid/operators/expand_op.h"
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_CUDA_KERNEL(
     expand, ops::ExpandKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ExpandKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ExpandKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::ExpandKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ExpandKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::ExpandKernel<paddle::platform::CUDADeviceContext, bool>);
@@ -24,5 +27,6 @@ REGISTER_OP_CUDA_KERNEL(
     expand_grad,
     ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -461,8 +461,12 @@ class Pad2dGradCUDAKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(pad2d, ops::Pad2dCUDAKernel<float>,
+namespace plat = paddle::platform;
+REGISTER_OP_CUDA_KERNEL(pad2d, ops::Pad2dCUDAKernel<plat::float16>,
+                        ops::Pad2dCUDAKernel<float>,
                         ops::Pad2dCUDAKernel<double>, ops::Pad2dCUDAKernel<int>,
                         ops::Pad2dCUDAKernel<int64_t>);
-REGISTER_OP_CUDA_KERNEL(pad2d_grad, ops::Pad2dGradCUDAKernel<float>,
+REGISTER_OP_CUDA_KERNEL(pad2d_grad, ops::Pad2dGradCUDAKernel<plat::float16>,
+                        ops::Pad2dGradCUDAKernel<float>,
                         ops::Pad2dGradCUDAKernel<double>);
@@ -15,10 +15,12 @@ limitations under the License. */
 #include "paddle/fluid/operators/squeeze_op.h"
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_CUDA_KERNEL(
     squeeze, ops::SqueezeKernel<paddle::platform::CUDADeviceContext, float>,
     ops::SqueezeKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int>,
     ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -26,12 +28,14 @@ REGISTER_OP_CUDA_KERNEL(
     squeeze_grad,
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     squeeze2, ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, float>,
     ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, double>,
+    ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int>,
     ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -39,6 +43,7 @@ REGISTER_OP_CUDA_KERNEL(
     squeeze2_grad,
     ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -15,10 +15,12 @@ limitations under the License. */
 #include "paddle/fluid/operators/unsqueeze_op.h"
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 
 REGISTER_OP_CUDA_KERNEL(
     unsqueeze, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -26,6 +28,8 @@ REGISTER_OP_CUDA_KERNEL(
     unsqueeze_grad,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext,
+                             plat::float16>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -33,6 +37,7 @@ REGISTER_OP_CUDA_KERNEL(
     unsqueeze2,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -40,6 +45,8 @@ REGISTER_OP_CUDA_KERNEL(
     unsqueeze2_grad,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
+                              plat::float16>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
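Worth noting why most of these changes are so mechanical: pad2d, expand, squeeze, and unsqueeze are data-movement ops whose templated kernels never do arithmetic on T, so fp16 support is just one more instantiation of the existing template. A minimal sketch of that property (half_t is a stand-in type, not plat::float16):

// Sketch: a data-movement kernel is element-type agnostic, so adding a
// 16-bit type costs only a new template instantiation.
#include <cstdint>
#include <vector>

struct half_t { uint16_t bits; };  // 16-bit stand-in for plat::float16

template <typename T>
std::vector<T> Pad1D(const std::vector<T>& in, size_t left, size_t right,
                     T value) {
  std::vector<T> out(left + in.size() + right, value);  // fill with pad value
  for (size_t i = 0; i < in.size(); ++i) out[left + i] = in[i];  // plain copy
  return out;
}

int main() {
  // The same template body serves float, double, and the 16-bit type:
  // padding only copies elements, it never computes on them.
  Pad1D<float>({1.f, 2.f}, 1, 1, 0.f);
  Pad1D<double>({1.0, 2.0}, 1, 1, 0.0);
  Pad1D<half_t>({half_t{0x3C00}, half_t{0x4000}}, 1, 1, half_t{0});
}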