未验证 提交 5271c32d 编写于 作者: K Kexin Zhao 提交者: GitHub

Merge pull request #9223 from kexinzhao/dropout_fp16

Add float16 support to dropout operator
......@@ -35,7 +35,6 @@ class DropoutOp : public framework::OperatorWithKernel {
}
};
template <typename AttrType>
class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DropoutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......@@ -73,7 +72,6 @@ are set equal to their corresponding inputs.
}
};
template <typename AttrType>
class DropoutOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -103,11 +101,10 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker<float>, dropout_grad,
ops::DropoutOpGrad<float>);
REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad,
ops::DropoutOpGrad);
REGISTER_OP_CPU_KERNEL(
dropout,
ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float, float>);
dropout, ops::CPUDropoutKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
dropout_grad,
ops::DropoutGradKernel<paddle::platform::CPUDeviceContext, float>);
......@@ -18,17 +18,18 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T, typename AttrType>
template <typename T>
__global__ void RandomGenerator(const size_t n, const int seed,
const AttrType dropout_prob, const T* src,
const float dropout_prob, const T* src,
T* mask_data, T* dst) {
thrust::minstd_rand rng;
rng.seed(seed);
thrust::uniform_real_distribution<AttrType> dist(0, 1);
thrust::uniform_real_distribution<float> dist(0, 1);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < n; idx += blockDim.x * gridDim.x) {
......@@ -44,14 +45,14 @@ __global__ void RandomGenerator(const size_t n, const int seed,
// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename Place, typename T, typename AttrType>
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* y = context.Output<Tensor>("Out");
y->mutable_data<T>(context.GetPlace());
AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
float dropout_prob = context.Attr<float>("dropout_prob");
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
......@@ -70,11 +71,11 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
int threads = 512;
int grid = (x->numel() + threads - 1) / threads;
RandomGenerator<T, AttrType><<<grid, threads, 0,
context.cuda_device_context().stream()>>>(
RandomGenerator<
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, seed, dropout_prob, x_data, mask_data, y_data);
} else {
Y.device(place) = X * (1.0f - dropout_prob);
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
}
}
};
......@@ -83,9 +84,9 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
dropout,
ops::GPUDropoutKernel<paddle::platform::CUDADeviceContext, float, float>);
REGISTER_OP_CUDA_KERNEL(
dropout_grad,
ops::DropoutGradKernel<paddle::platform::CUDADeviceContext, float>);
dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(dropout_grad,
ops::DropoutGradKernel<plat::CUDADeviceContext, float>);
......@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T, typename AttrType>
template <typename DeviceContext, typename T>
class CPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......
......@@ -483,8 +483,123 @@ DEVICE inline bool operator>=(const half& a, const half& b) {
#endif // PADDLE_CUDA_FP16
// Arithmetic operators on ARMv8.2-A CPU
#if defined(PADDLE_WITH_NATIVE_FP16)
// Arithmetic operators for float16 on GPU
#if defined(PADDLE_CUDA_FP16)
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hadd(half(a), half(b)));
#else
return float16(float(a) + float(b));
#endif
}
HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hsub(half(a), half(b)));
#else
return float16(float(a) - float(b));
#endif
}
HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hmul(half(a), half(b)));
#else
return float16(float(a) * float(b));
#endif
}
HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
// TODO(kexinzhao): check which cuda version starts to support __hdiv
float num = __half2float(half(a));
float denom = __half2float(half(b));
return float16(num / denom);
#else
return float16(float(a) / float(b));
#endif
}
HOSTDEVICE inline float16 operator-(const float16& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hneg(half(a)));
#else
float16 res;
res.x = a.x ^ 0x8000;
return res;
#endif
}
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
a = a + b;
return a;
}
HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
a = a - b;
return a;
}
HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
a = a * b;
return a;
}
HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
a = a / b;
return a;
}
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __heq(half(a), half(b));
#else
return float(a) == float(b);
#endif
}
HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hne(half(a), half(b));
#else
return float(a) != float(b);
#endif
}
HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hlt(half(a), half(b));
#else
return float(a) < float(b);
#endif
}
HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hle(half(a), half(b));
#else
return float(a) <= float(b);
#endif
}
HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hgt(half(a), half(b));
#else
return float(a) > float(b);
#endif
}
HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hge(half(a), half(b));
#else
return float(a) >= float(b);
#endif
}
// Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16)
HOST inline float16 operator+(const float16& a, const float16& b) {
float16 res;
asm volatile(
......@@ -668,71 +783,71 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
return (res & 0xffff) != 0;
}
// Arithmetic operators, software emulated on other CPU
// Arithmetic operators for float16, software emulated on other CPU
#else
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
HOST inline float16 operator+(const float16& a, const float16& b) {
return float16(float(a) + float(b));
}
HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
HOST inline float16 operator-(const float16& a, const float16& b) {
return float16(float(a) - float(b));
}
HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
HOST inline float16 operator*(const float16& a, const float16& b) {
return float16(float(a) * float(b));
}
HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
HOST inline float16 operator/(const float16& a, const float16& b) {
return float16(float(a) / float(b));
}
HOSTDEVICE inline float16 operator-(const float16& a) {
HOST inline float16 operator-(const float16& a) {
float16 res;
res.x = a.x ^ 0x8000;
return res;
}
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) {
HOST inline float16& operator+=(float16& a, const float16& b) {
a = float16(float(a) + float(b));
return a;
}
HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) {
HOST inline float16& operator-=(float16& a, const float16& b) {
a = float16(float(a) - float(b));
return a;
}
HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) {
HOST inline float16& operator*=(float16& a, const float16& b) {
a = float16(float(a) * float(b));
return a;
}
HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {
HOST inline float16& operator/=(float16& a, const float16& b) {
a = float16(float(a) / float(b));
return a;
}
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
HOST inline bool operator==(const float16& a, const float16& b) {
return float(a) == float(b);
}
HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
HOST inline bool operator!=(const float16& a, const float16& b) {
return float(a) != float(b);
}
HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
HOST inline bool operator<(const float16& a, const float16& b) {
return float(a) < float(b);
}
HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
HOST inline bool operator<=(const float16& a, const float16& b) {
return float(a) <= float(b);
}
HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
HOST inline bool operator>(const float16& a, const float16& b) {
return float(a) > float(b);
}
HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
HOST inline bool operator>=(const float16& a, const float16& b) {
return float(a) >= float(b);
}
#endif
......
......@@ -14,6 +14,7 @@
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
......@@ -82,5 +83,37 @@ class TestDropoutOp5(OpTest):
self.check_output()
class TestFP16DropoutOp(OpTest):
def setUp(self):
self.op_type = "dropout"
self.init_test_case()
x = np.random.random(self.input_size).astype("float16")
out = x * (1.0 - self.prob)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {
'dropout_prob': self.prob,
'fix_seed': self.fix_seed,
'is_test': True
}
self.outputs = {'Out': out}
def init_test_case(self):
self.input_size = [32, 64]
self.prob = 0.35
self.fix_seed = True
def test_check_output(self):
if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"):
self.check_output_with_place(core.CUDAPlace(0), atol=1e-3)
class TestFP16DropoutOp2(TestFP16DropoutOp):
def init_test_case(self):
self.input_size = [32, 64, 3]
self.prob = 0.75
self.fix_seed = False
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册