未验证 提交 0f38bb45 编写于 作者: K Kexin Zhao 提交者: GitHub

add fp16 support to activation op (#9769)

上级 22b9d4e6
...@@ -662,14 +662,3 @@ REGISTER_OP(swish, ops::ActivationOp, ops::SwishOpMaker, swish_grad, ...@@ -662,14 +662,3 @@ REGISTER_OP(swish, ops::ActivationOp, ops::SwishOpMaker, swish_grad,
ops::grad_functor<double>>); ops::grad_functor<double>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
REGISTER_OP_CPU_KERNEL(relu,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ReluFunctor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ReluFunctor<double>>);
REGISTER_OP_CPU_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ReluGradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::ReluGradFunctor<double>>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -17,31 +14,19 @@ limitations under the License. */ ...@@ -17,31 +14,19 @@ limitations under the License. */
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CUDA_KERNEL( \ #define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor) \
act_type, ops::ActivationKernel<paddle::platform::CUDADeviceContext, \ REGISTER_OP_CUDA_KERNEL( \
ops::functor<float>>, \ act_type, \
ops::ActivationKernel<paddle::platform::CUDADeviceContext, \ ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<float>>, \
ops::functor<double>>); \ ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<double>>, \
REGISTER_OP_CUDA_KERNEL( \ ops::ActivationKernel<plat::CUDADeviceContext, \
act_type##_grad, \ ops::functor<plat::float16>>); \
ops::ActivationGradKernel<paddle::platform::CUDADeviceContext, \ REGISTER_OP_CUDA_KERNEL( \
ops::grad_functor<float>>, \ act_type##_grad, ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::ActivationGradKernel<paddle::platform::CUDADeviceContext, \ ops::grad_functor<float>>, \
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<double>>); ops::grad_functor<double>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
REGISTER_OP_CUDA_KERNEL(
relu, ops::ActivationKernel<paddle::platform::CUDADeviceContext,
ops::ReluFunctor<float>>,
ops::ActivationKernel<paddle::platform::CUDADeviceContext,
ops::ReluFunctor<double>>,
ops::ActivationKernel<paddle::platform::CUDADeviceContext,
ops::ReluFunctor<paddle::platform::float16>>);
REGISTER_OP_CUDA_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradFunctor<double>>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -15,9 +12,11 @@ limitations under the License. */ ...@@ -15,9 +12,11 @@ limitations under the License. */
#pragma once #pragma once
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/float16.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
...@@ -338,11 +337,25 @@ struct Sine { ...@@ -338,11 +337,25 @@ struct Sine {
HOSTDEVICE T operator()(const T& val) const { return sin(val); } HOSTDEVICE T operator()(const T& val) const { return sin(val); }
}; };
template <>
struct Sine<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(sin(static_cast<float>(val)));
}
};
template <typename T> template <typename T>
struct Cosine { struct Cosine {
HOSTDEVICE T operator()(const T& val) const { return cos(val); } HOSTDEVICE T operator()(const T& val) const { return cos(val); }
}; };
template <>
struct Cosine<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(cos(static_cast<float>(val)));
}
};
// cosine'(x) = -sin(x) // cosine'(x) = -sin(x)
template <typename T> template <typename T>
struct CosGradFunctor : public BaseActivationFunctor<T> { struct CosGradFunctor : public BaseActivationFunctor<T> {
...@@ -826,6 +839,7 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> { ...@@ -826,6 +839,7 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \ __macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \ __macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \ __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
......
...@@ -1003,6 +1003,46 @@ HOSTDEVICE inline float16 exp(const float16& a) { ...@@ -1003,6 +1003,46 @@ HOSTDEVICE inline float16 exp(const float16& a) {
return float16(::expf(static_cast<float>(a))); return float16(::expf(static_cast<float>(a)));
} }
template <>
HOSTDEVICE inline float16 log(const float16& a) {
return float16(::logf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 tanh(const float16& a) {
return float16(::tanhf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 sqrt(const float16& a) {
return float16(::sqrtf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 ceil(const float16& a) {
return float16(::ceilf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 floor(const float16& a) {
return float16(::floorf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 round(const float16& a) {
return float16(::roundf(static_cast<float>(a)));
}
template <>
HOSTDEVICE inline float16 pow(const float16& a, const float16& b) {
return float16(::powf(static_cast<float>(a), static_cast<float>(b)));
}
template <>
HOSTDEVICE inline float16 abs(const float16& a) {
return float16(::fabs(static_cast<float>(a)));
}
} // namespace numext } // namespace numext
} // namespace Eigen } // namespace Eigen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册