Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
39ac9e39
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
39ac9e39
编写于
7月 30, 2018
作者:
D
dzhwinter
提交者:
GitHub
7月 30, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
float16 type support enhance (#12181)
* cherry picked * "cherry picked platform" * "add comment" * "fix ci"
上级
19ef4bab
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
334 addition
and
7 deletion
+334
-7
paddle/fluid/platform/CMakeLists.txt
paddle/fluid/platform/CMakeLists.txt
+4
-0
paddle/fluid/platform/cuda_device_function.h
paddle/fluid/platform/cuda_device_function.h
+21
-0
paddle/fluid/platform/cuda_helper_test.cu
paddle/fluid/platform/cuda_helper_test.cu
+118
-0
paddle/fluid/platform/cuda_primitives.h
paddle/fluid/platform/cuda_primitives.h
+69
-6
paddle/fluid/platform/float16.h
paddle/fluid/platform/float16.h
+27
-0
paddle/fluid/platform/float16_test.cc
paddle/fluid/platform/float16_test.cc
+26
-0
paddle/fluid/platform/float16_test.cu
paddle/fluid/platform/float16_test.cu
+69
-1
未找到文件。
paddle/fluid/platform/CMakeLists.txt
浏览文件 @
39ac9e39
...
...
@@ -60,3 +60,7 @@ cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
nv_test
(
float16_gpu_test SRCS float16_test.cu DEPS lod_tensor
)
cc_test
(
float16_test SRCS float16_test.cc DEPS lod_tensor
)
IF
(
WITH_GPU
)
nv_test
(
cuda_helper_test SRCS cuda_helper_test.cu
)
ENDIF
()
paddle/fluid/platform/cuda_device_function.h
浏览文件 @
39ac9e39
...
...
@@ -14,6 +14,10 @@ limitations under the License. */
#pragma once
#include <cuda.h>
// NOTE(): support float16 to half in header file.
#define PADDLE_CUDA_FP16
#include <cuda_fp16.h>
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
platform
{
...
...
@@ -36,6 +40,18 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
#endif
}
// CUDA 9.0 have native compatible float16 shfl_down
#if CUDA_VERSION < 9000
template
<
>
__forceinline__
__device__
float16
CudaShuffleDownSync
(
unsigned
mask
,
float16
val
,
int
delta
,
int
width
)
{
half
tmp
=
static_cast
<
half
>
(
val
);
__shfl_down
(
tmp
,
static_cast
<
unsigned
>
(
delta
),
width
);
return
float16
(
tmp
);
}
#endif
template
<
typename
T
>
__forceinline__
__device__
T
CudaShuffleSync
(
unsigned
mask
,
T
val
,
int
src_line
,
int
width
=
32
)
{
...
...
@@ -46,6 +62,11 @@ __forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
#endif
}
template
<
typename
T
>
HOSTDEVICE
T
Infinity
()
{
return
INFINITY
;
}
template
<
typename
T
>
__device__
T
reduceSum
(
T
val
,
int
tid
,
int
len
)
{
// NOTE(zcd): The warp size should be taken from the
...
...
paddle/fluid/platform/cuda_helper_test.cu
0 → 100644
浏览文件 @
39ac9e39
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <bitset>
#include <iostream>
#include <random>
#define PADDLE_CUDA_FP16
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
using
paddle
::
platform
::
PADDLE_CUDA_NUM_THREADS
;
using
paddle
::
platform
::
float16
;
#define CUDA_ATOMIC_KERNEL(op, T) \
__global__ void op##Kernel(const T* data_a, T* data_b, size_t num) { \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; \
i += blockDim.x * gridDim.x) { \
paddle::platform::CudaAtomic##op(&data_b[i], data_a[i]); \
} \
}
template
<
typename
T
>
struct
AddFunctor
{
T
operator
()(
const
T
&
a
,
const
T
&
b
)
{
return
a
+
b
;
}
};
template
<
typename
T
>
struct
SubFunctor
{
T
operator
()(
const
T
&
a
,
const
T
&
b
)
{
return
a
-
b
;
}
};
// NOTE(dzhwinter): the float16 add has small underflow/overflow
// so we use EXPECT_NEAR to check the result.
#define ARITHMETIC_KERNEL_LAUNCH(op, T) \
void Test##T##op(size_t num) { \
T *in1, *in2, *out; \
T *d_in1, *d_in2; \
size_t size = sizeof(T) * num; \
cudaMalloc(reinterpret_cast<void**>(&d_in1), size); \
cudaMalloc(reinterpret_cast<void**>(&d_in2), size); \
in1 = reinterpret_cast<T*>(malloc(size)); \
in2 = reinterpret_cast<T*>(malloc(size)); \
out = reinterpret_cast<T*>(malloc(size)); \
std::minstd_rand engine; \
std::uniform_real_distribution<double> dist(0.0, 1.0); \
for (size_t i = 0; i < num; ++i) { \
in1[i] = static_cast<T>(dist(engine)); \
in2[i] = static_cast<T>(dist(engine)); \
} \
cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); \
cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice); \
op##Kernel<<<1, PADDLE_CUDA_NUM_THREADS>>>(d_in1, d_in2, num); \
cudaDeviceSynchronize(); \
cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost); \
cudaDeviceSynchronize(); \
for (size_t i = 0; i < num; ++i) { \
EXPECT_NEAR(static_cast<float>(out[i]), \
static_cast<float>(op##Functor<T>()(in1[i], in2[i])), \
0.001); \
} \
free(in1); \
free(in2); \
free(out); \
cudaFree(d_in1); \
cudaFree(d_in2); \
}
CUDA_ATOMIC_KERNEL
(
Add
,
float
);
CUDA_ATOMIC_KERNEL
(
Add
,
double
);
CUDA_ATOMIC_KERNEL
(
Add
,
float16
);
ARITHMETIC_KERNEL_LAUNCH
(
Add
,
float
);
ARITHMETIC_KERNEL_LAUNCH
(
Add
,
double
);
ARITHMETIC_KERNEL_LAUNCH
(
Add
,
float16
);
namespace
paddle
{
namespace
platform
{
USE_CUDA_ATOMIC
(
Sub
,
int
);
};
};
CUDA_ATOMIC_KERNEL
(
Sub
,
int
);
ARITHMETIC_KERNEL_LAUNCH
(
Sub
,
int
);
// cuda primitives
TEST
(
CudaAtomic
,
Add
)
{
TestfloatAdd
(
static_cast
<
size_t
>
(
10
));
TestfloatAdd
(
static_cast
<
size_t
>
(
1024
*
1024
));
TestdoubleAdd
(
static_cast
<
size_t
>
(
10
));
TestdoubleAdd
(
static_cast
<
size_t
>
(
1024
*
1024
));
}
TEST
(
CudaAtomic
,
Sub
)
{
TestintSub
(
static_cast
<
size_t
>
(
10
));
TestintSub
(
static_cast
<
size_t
>
(
1024
*
1024
));
}
TEST
(
CudaAtomic
,
float16
)
{
using
paddle
::
platform
::
float16
;
Testfloat16Add
(
static_cast
<
size_t
>
(
1
));
Testfloat16Add
(
static_cast
<
size_t
>
(
2
));
Testfloat16Add
(
static_cast
<
size_t
>
(
3
));
Testfloat16Add
(
static_cast
<
size_t
>
(
10
));
Testfloat16Add
(
static_cast
<
size_t
>
(
1024
*
1024
));
}
paddle/fluid/platform/cuda_primitives.h
浏览文件 @
39ac9e39
...
...
@@ -14,12 +14,14 @@ limitations under the License. */
#pragma once
#include <cuda.h>
#include <stdio.h>
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
platform
{
#define CUDA_ATOMIC_WRAPPER(op, T) \
__device__ __forceinline__ T CudaAtomic##op(T
*
address, const T val)
__device__ __forceinline__ T CudaAtomic##op(T
*
address, const T val)
#define USE_CUDA_ATOMIC(op, T) \
CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
...
...
@@ -42,7 +44,7 @@ CUDA_ATOMIC_WRAPPER(Add, int64_t) {
static_assert
(
sizeof
(
int64_t
)
==
sizeof
(
long
long
int
),
// NOLINT
"long long should be int64"
);
return
CudaAtomicAdd
(
reinterpret_cast
<
unsigned
long
long
int
*>
(
address
),
// NOLINT
reinterpret_cast
<
unsigned
long
long
int
*>
(
address
),
// NOLINT
static_cast
<
unsigned
long
long
int
>
(
val
));
// NOLINT
}
...
...
@@ -50,8 +52,8 @@ CUDA_ATOMIC_WRAPPER(Add, int64_t) {
USE_CUDA_ATOMIC
(
Add
,
double
);
#else
CUDA_ATOMIC_WRAPPER
(
Add
,
double
)
{
unsigned
long
long
int
*
address_as_ull
=
// NOLINT
reinterpret_cast
<
unsigned
long
long
int
*>
(
address
);
// NOLINT
unsigned
long
long
int
*
address_as_ull
=
// NOLINT
reinterpret_cast
<
unsigned
long
long
int
*>
(
address
);
// NOLINT
unsigned
long
long
int
old
=
*
address_as_ull
,
assumed
;
// NOLINT
do
{
...
...
@@ -64,6 +66,67 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
return
__longlong_as_double
(
old
);
}
#endif
#ifdef PADDLE_CUDA_FP16
// NOTE(dzhwinter): cuda do not have atomicCAS for half.
// Just use the half address as a unsigned value address and
// do the atomicCAS. According to the value store at high 16 bits
// or low 16 bits, then do a different sum and CAS.
// Given most warp-threads will failed on the atomicCAS, so this
// implemented should be avoided in high concurrency. It's will be
// slower than the way convert value into 32bits and do a full atomicCAS.
// convert the value into float and do the add arithmetic.
// then store the result into a uint32.
inline
__device__
uint32_t
add_to_low_half
(
uint32_t
val
,
float
x
)
{
float16
low_half
;
// the float16 in lower 16bits
low_half
.
x
=
static_cast
<
uint16_t
>
(
val
&
0xffffu
);
low_half
=
static_cast
<
float16
>
(
static_cast
<
float
>
(
low_half
)
+
x
);
return
(
val
&
0xffff0000u
)
|
low_half
.
x
;
}
inline
__device__
uint32_t
add_to_high_half
(
uint32_t
val
,
float
x
)
{
float16
high_half
;
// the float16 in higher 16bits
high_half
.
x
=
static_cast
<
uint16_t
>
(
val
>>
16
);
high_half
=
static_cast
<
float16
>
(
static_cast
<
float
>
(
high_half
)
+
x
);
return
(
val
&
0xffffu
)
|
(
static_cast
<
uint32_t
>
(
high_half
.
x
)
<<
16
);
}
CUDA_ATOMIC_WRAPPER
(
Add
,
float16
)
{
// concrete packed float16 value may exsits in lower or higher 16bits
// of the 32bits address.
uint32_t
*
address_as_ui
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
char
*>
(
address
)
-
(
reinterpret_cast
<
size_t
>
(
address
)
&
2
));
float
val_f
=
static_cast
<
float
>
(
val
);
uint32_t
old
=
*
address_as_ui
;
uint32_t
sum
;
uint32_t
newval
;
uint32_t
assumed
;
if
(((
size_t
)
address
&
2
)
==
0
)
{
// the float16 value stay at lower 16 bits of the address.
do
{
assumed
=
old
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
add_to_low_half
(
assumed
,
val_f
));
}
while
(
old
!=
assumed
);
float16
ret
;
ret
.
x
=
old
&
0xffffu
;
return
ret
;
}
else
{
// the float16 value stay at higher 16 bits of the address.
do
{
assumed
=
old
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
add_to_high_half
(
assumed
,
val_f
));
}
while
(
old
!=
assumed
);
float16
ret
;
ret
.
x
=
old
>>
16
;
return
ret
;
}
}
#endif
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/float16.h
浏览文件 @
39ac9e39
...
...
@@ -67,8 +67,11 @@ struct float16;
}
// namespace platform
}
// namespace paddle
// NOTE():
// Do not move the eigen.h header, otherwise the eigen_vector<bool> will failed.
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace
paddle
{
namespace
platform
{
...
...
@@ -898,6 +901,30 @@ struct is_pod<paddle::platform::float16> {
is_standard_layout
<
paddle
::
platform
::
float16
>::
value
;
};
template
<
>
struct
is_floating_point
<
paddle
::
platform
::
float16
>
:
std
::
integral_constant
<
bool
,
std
::
is_same
<
paddle
::
platform
::
float16
,
typename
std
::
remove_cv
<
paddle
::
platform
::
float16
>::
type
>::
value
>
{};
template
<
>
struct
is_signed
<
paddle
::
platform
::
float16
>
{
static
const
bool
value
=
true
;
};
template
<
>
struct
is_unsigned
<
paddle
::
platform
::
float16
>
{
static
const
bool
value
=
false
;
};
inline
bool
isnan
(
const
paddle
::
platform
::
float16
&
a
)
{
return
paddle
::
platform
::
isnan
(
a
);
}
inline
bool
isinf
(
const
paddle
::
platform
::
float16
&
a
)
{
return
paddle
::
platform
::
isinf
(
a
);
}
template
<
>
struct
numeric_limits
<
paddle
::
platform
::
float16
>
{
static
const
bool
is_specialized
=
true
;
...
...
paddle/fluid/platform/float16_test.cc
浏览文件 @
39ac9e39
...
...
@@ -141,10 +141,36 @@ TEST(float16, lod_tensor_cpu) {
}
}
TEST
(
float16
,
floating
)
{
// compile time assert.
PADDLE_ASSERT
(
std
::
is_floating_point
<
float16
>::
value
);
}
TEST
(
float16
,
print
)
{
float16
a
=
float16
(
1.0
f
);
std
::
cout
<<
a
<<
std
::
endl
;
}
// CPU test
TEST
(
float16
,
isinf
)
{
float16
a
;
a
.
x
=
0x7c00
;
float16
b
=
float16
(
INFINITY
);
float16
c
=
static_cast
<
float16
>
(
INFINITY
);
EXPECT_EQ
(
std
::
isinf
(
a
),
true
);
EXPECT_EQ
(
std
::
isinf
(
b
),
true
);
EXPECT_EQ
(
std
::
isinf
(
c
),
true
);
}
TEST
(
float16
,
isnan
)
{
float16
a
;
a
.
x
=
0x7fff
;
float16
b
=
float16
(
NAN
);
float16
c
=
static_cast
<
float16
>
(
NAN
);
EXPECT_EQ
(
std
::
isnan
(
a
),
true
);
EXPECT_EQ
(
std
::
isnan
(
b
),
true
);
EXPECT_EQ
(
std
::
isnan
(
c
),
true
);
}
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/float16_test.cu
浏览文件 @
39ac9e39
...
...
@@ -11,11 +11,13 @@ limitations under the License. */
#include "paddle/fluid/platform/float16.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <bitset>
#include <iostream>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/legacy/utils/Logging.h"
#define ARITHMETIC_KERNEL(op_type, sign) \
__global__ void op_type(const half* in1, const half* in2, half* out) { \
...
...
@@ -241,6 +243,72 @@ TEST(float16, lod_tensor_on_gpu) {
}
}
template
<
typename
T
>
struct
Functor
{
bool
operator
()(
const
T
&
val
)
{
return
std
::
type_index
(
typeid
(
T
))
==
std
::
type_index
(
typeid
(
platform
::
float16
));
}
};
TEST
(
float16
,
typeid
)
{
// the framework heavily used typeid hash
Functor
<
float16
>
functor
;
float16
a
=
float16
(
.0
f
);
Functor
<
int
>
functor2
;
int
b
(
0
);
// compile time assert
PADDLE_ASSERT
(
functor
(
a
)
==
true
);
PADDLE_ASSERT
(
functor2
(
b
)
==
false
);
}
// GPU test
TEST
(
float16
,
isinf
)
{
float16
a
;
a
.
x
=
0x7c00
;
float16
b
=
float16
(
INFINITY
);
// underflow to 0
float16
native_a
(
5e-40
f
);
// overflow to inf
float16
native_b
(
5e40
f
);
EXPECT_EQ
(
std
::
isinf
(
a
),
true
);
EXPECT_EQ
(
std
::
isinf
(
b
),
true
);
EXPECT_EQ
(
std
::
isinf
(
native_b
),
true
);
EXPECT_EQ
(
native_a
,
float16
(
0
));
}
TEST
(
float16
,
isnan
)
{
float16
a
;
a
.
x
=
0x7fff
;
float16
b
=
float16
(
NAN
);
float16
c
=
float16
(
5e40
);
// inf * +-0 will get a nan
float16
d
=
c
*
float16
(
0
);
EXPECT_EQ
(
std
::
isnan
(
a
),
true
);
EXPECT_EQ
(
std
::
isnan
(
b
),
true
);
EXPECT_EQ
(
std
::
isnan
(
d
),
true
);
}
TEST
(
float16
,
cast
)
{
float16
a
;
a
.
x
=
0x0070
;
auto
b
=
a
;
{
// change semantic, keep the same value
float16
c
=
reinterpret_cast
<
float16
&>
(
reinterpret_cast
<
unsigned
&>
(
b
));
EXPECT_EQ
(
b
,
c
);
}
{
// use uint32 low 16 bit store float16
uint32_t
c
=
reinterpret_cast
<
uint32_t
&>
(
b
);
float16
d
;
d
.
x
=
c
;
EXPECT_EQ
(
b
,
d
);
}
}
}
// namespace platform
}
// namespace paddle
#endif // PADDLE_CUDA_FP16
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录