Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a5ca2672
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a5ca2672
编写于
3月 29, 2023
作者:
C
chenxujun
提交者:
GitHub
3月 29, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix the type conflicts against the openblas (#52187)
上级
ad01eccd
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
54 addition
and
43 deletion
+54
-43
paddle/phi/backends/gpu/gpu_primitives.h
paddle/phi/backends/gpu/gpu_primitives.h
+54
-43
未找到文件。
paddle/phi/backends/gpu/gpu_primitives.h
浏览文件 @
a5ca2672
...
@@ -28,9 +28,6 @@ limitations under the License. */
...
@@ -28,9 +28,6 @@ limitations under the License. */
template
<
typename
T
>
template
<
typename
T
>
using
complex
=
phi
::
dtype
::
complex
<
T
>
;
using
complex
=
phi
::
dtype
::
complex
<
T
>
;
using
float16
=
phi
::
dtype
::
float16
;
using
bfloat16
=
phi
::
dtype
::
bfloat16
;
namespace
phi
{
namespace
phi
{
#define CUDA_ATOMIC_WRAPPER(op, T) \
#define CUDA_ATOMIC_WRAPPER(op, T) \
...
@@ -94,36 +91,39 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
...
@@ -94,36 +91,39 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
// convert the value into float and do the add arithmetic.
// convert the value into float and do the add arithmetic.
// then store the result into a uint32.
// then store the result into a uint32.
inline
static
__device__
uint32_t
add_to_low_half
(
uint32_t
val
,
float
x
)
{
inline
static
__device__
uint32_t
add_to_low_half
(
uint32_t
val
,
float
x
)
{
float16
low_half
;
phi
::
dtype
::
float16
low_half
;
// the float16 in lower 16bits
// the float16 in lower 16bits
low_half
.
x
=
static_cast
<
uint16_t
>
(
val
&
0xFFFFu
);
low_half
.
x
=
static_cast
<
uint16_t
>
(
val
&
0xFFFFu
);
low_half
=
static_cast
<
float16
>
(
static_cast
<
float
>
(
low_half
)
+
x
);
low_half
=
static_cast
<
phi
::
dtype
::
float16
>
(
static_cast
<
float
>
(
low_half
)
+
x
);
return
(
val
&
0xFFFF0000u
)
|
low_half
.
x
;
return
(
val
&
0xFFFF0000u
)
|
low_half
.
x
;
}
}
inline
static
__device__
uint32_t
add_to_high_half
(
uint32_t
val
,
float
x
)
{
inline
static
__device__
uint32_t
add_to_high_half
(
uint32_t
val
,
float
x
)
{
float16
high_half
;
phi
::
dtype
::
float16
high_half
;
// the float16 in higher 16bits
// the float16 in higher 16bits
high_half
.
x
=
static_cast
<
uint16_t
>
(
val
>>
16
);
high_half
.
x
=
static_cast
<
uint16_t
>
(
val
>>
16
);
high_half
=
static_cast
<
float16
>
(
static_cast
<
float
>
(
high_half
)
+
x
);
high_half
=
static_cast
<
phi
::
dtype
::
float16
>
(
static_cast
<
float
>
(
high_half
)
+
x
);
return
(
val
&
0xFFFFu
)
|
(
static_cast
<
uint32_t
>
(
high_half
.
x
)
<<
16
);
return
(
val
&
0xFFFFu
)
|
(
static_cast
<
uint32_t
>
(
high_half
.
x
)
<<
16
);
}
}
#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
static
__device__
__forceinline__
float16
CUDAFP16ToPDFP16
(
__half
x
)
{
static
__device__
__forceinline__
phi
::
dtype
::
float16
CUDAFP16ToPDFP16
(
return
*
reinterpret_cast
<
float16
*>
(
&
x
);
__half
x
)
{
return
*
reinterpret_cast
<
phi
::
dtype
::
float16
*>
(
&
x
);
}
}
static
__device__
__forceinline__
__half
PDFP16ToCUDAFP16
(
float16
x
)
{
static
__device__
__forceinline__
__half
PDFP16ToCUDAFP16
(
phi
::
dtype
::
float16
x
)
{
return
*
reinterpret_cast
<
__half
*>
(
&
x
);
return
*
reinterpret_cast
<
__half
*>
(
&
x
);
}
}
CUDA_ATOMIC_WRAPPER
(
Add
,
float16
)
{
CUDA_ATOMIC_WRAPPER
(
Add
,
phi
::
dtype
::
float16
)
{
return
CUDAFP16ToPDFP16
(
return
CUDAFP16ToPDFP16
(
atomicAdd
(
reinterpret_cast
<
__half
*>
(
address
),
PDFP16ToCUDAFP16
(
val
)));
atomicAdd
(
reinterpret_cast
<
__half
*>
(
address
),
PDFP16ToCUDAFP16
(
val
)));
}
}
#else
#else
CUDA_ATOMIC_WRAPPER
(
Add
,
float16
)
{
CUDA_ATOMIC_WRAPPER
(
Add
,
phi
::
dtype
::
float16
)
{
// concrete packed float16 value may exsits in lower or higher 16bits
// concrete packed float16 value may exsits in lower or higher 16bits
// of the 32bits address.
// of the 32bits address.
uint32_t
*
address_as_ui
=
reinterpret_cast
<
uint32_t
*>
(
uint32_t
*
address_as_ui
=
reinterpret_cast
<
uint32_t
*>
(
...
@@ -140,7 +140,7 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
...
@@ -140,7 +140,7 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
assumed
=
old
;
assumed
=
old
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
add_to_low_half
(
assumed
,
val_f
));
old
=
atomicCAS
(
address_as_ui
,
assumed
,
add_to_low_half
(
assumed
,
val_f
));
}
while
(
old
!=
assumed
);
}
while
(
old
!=
assumed
);
float16
ret
;
phi
::
dtype
::
float16
ret
;
ret
.
x
=
old
&
0xFFFFu
;
ret
.
x
=
old
&
0xFFFFu
;
return
ret
;
return
ret
;
}
else
{
}
else
{
...
@@ -149,7 +149,7 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
...
@@ -149,7 +149,7 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
assumed
=
old
;
assumed
=
old
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
add_to_high_half
(
assumed
,
val_f
));
old
=
atomicCAS
(
address_as_ui
,
assumed
,
add_to_high_half
(
assumed
,
val_f
));
}
while
(
old
!=
assumed
);
}
while
(
old
!=
assumed
);
float16
ret
;
phi
::
dtype
::
float16
ret
;
ret
.
x
=
old
>>
16
;
ret
.
x
=
old
>>
16
;
return
ret
;
return
ret
;
}
}
...
@@ -168,14 +168,17 @@ struct VecAtomicAddHelper : VecAtomicAddHelperBase<T, false, void, void> {};
...
@@ -168,14 +168,17 @@ struct VecAtomicAddHelper : VecAtomicAddHelperBase<T, false, void, void> {};
#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#if CUDA_VERSION >= 10000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
template
<
>
template
<
>
struct
VecAtomicAddHelper
<
float16
>
struct
VecAtomicAddHelper
<
phi
::
dtype
::
float16
>
:
VecAtomicAddHelperBase
<
float16
,
true
,
__half
,
__half2
>
{};
:
VecAtomicAddHelperBase
<
phi
::
dtype
::
float16
,
true
,
__half
,
__half2
>
{};
#endif
#endif
#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template
<
>
template
<
>
struct
VecAtomicAddHelper
<
bfloat16
>
struct
VecAtomicAddHelper
<
phi
::
dtype
::
bfloat16
>
:
VecAtomicAddHelperBase
<
bfloat16
,
true
,
__nv_bfloat16
,
__nv_bfloat162
>
{};
:
VecAtomicAddHelperBase
<
phi
::
dtype
::
bfloat16
,
true
,
__nv_bfloat16
,
__nv_bfloat162
>
{};
#endif
#endif
// The performance of "atomicAdd(half* )" is bad, but for "atomicAdd(half2* )"
// The performance of "atomicAdd(half* )" is bad, but for "atomicAdd(half2* )"
...
@@ -225,36 +228,40 @@ __device__ __forceinline__ void fastAtomicAdd(T *arr,
...
@@ -225,36 +228,40 @@ __device__ __forceinline__ void fastAtomicAdd(T *arr,
// NOTE(zhangbo): cuda do not have atomicCAS for __nv_bfloat16.
// NOTE(zhangbo): cuda do not have atomicCAS for __nv_bfloat16.
inline
static
__device__
uint32_t
bf16_add_to_low_half
(
uint32_t
val
,
float
x
)
{
inline
static
__device__
uint32_t
bf16_add_to_low_half
(
uint32_t
val
,
float
x
)
{
bfloat16
low_half
;
phi
::
dtype
::
bfloat16
low_half
;
// the bfloat16 in lower 16bits
// the bfloat16 in lower 16bits
low_half
.
x
=
static_cast
<
uint16_t
>
(
val
&
0xFFFFu
);
low_half
.
x
=
static_cast
<
uint16_t
>
(
val
&
0xFFFFu
);
low_half
=
static_cast
<
bfloat16
>
(
static_cast
<
float
>
(
low_half
)
+
x
);
low_half
=
static_cast
<
phi
::
dtype
::
bfloat16
>
(
static_cast
<
float
>
(
low_half
)
+
x
);
return
(
val
&
0xFFFF0000u
)
|
low_half
.
x
;
return
(
val
&
0xFFFF0000u
)
|
low_half
.
x
;
}
}
inline
static
__device__
uint32_t
bf16_add_to_high_half
(
uint32_t
val
,
float
x
)
{
inline
static
__device__
uint32_t
bf16_add_to_high_half
(
uint32_t
val
,
float
x
)
{
bfloat16
high_half
;
phi
::
dtype
::
bfloat16
high_half
;
// the bfloat16 in higher 16bits
// the bfloat16 in higher 16bits
high_half
.
x
=
static_cast
<
uint16_t
>
(
val
>>
16
);
high_half
.
x
=
static_cast
<
uint16_t
>
(
val
>>
16
);
high_half
=
static_cast
<
bfloat16
>
(
static_cast
<
float
>
(
high_half
)
+
x
);
high_half
=
static_cast
<
phi
::
dtype
::
bfloat16
>
(
static_cast
<
float
>
(
high_half
)
+
x
);
return
(
val
&
0xFFFFu
)
|
(
static_cast
<
uint32_t
>
(
high_half
.
x
)
<<
16
);
return
(
val
&
0xFFFFu
)
|
(
static_cast
<
uint32_t
>
(
high_half
.
x
)
<<
16
);
}
}
#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static
__device__
__forceinline__
bfloat16
CUDABF16ToPDBF16
(
__nv_bfloat16
x
)
{
static
__device__
__forceinline__
phi
::
dtype
::
bfloat16
CUDABF16ToPDBF16
(
return
*
reinterpret_cast
<
bfloat16
*>
(
&
x
);
__nv_bfloat16
x
)
{
return
*
reinterpret_cast
<
phi
::
dtype
::
bfloat16
*>
(
&
x
);
}
}
static
__device__
__forceinline__
__nv_bfloat16
PDBF16ToCUDABF16
(
bfloat16
x
)
{
static
__device__
__forceinline__
__nv_bfloat16
PDBF16ToCUDABF16
(
phi
::
dtype
::
bfloat16
x
)
{
return
*
reinterpret_cast
<
__nv_bfloat16
*>
(
&
x
);
return
*
reinterpret_cast
<
__nv_bfloat16
*>
(
&
x
);
}
}
CUDA_ATOMIC_WRAPPER
(
Add
,
bfloat16
)
{
CUDA_ATOMIC_WRAPPER
(
Add
,
phi
::
dtype
::
bfloat16
)
{
return
CUDABF16ToPDBF16
(
atomicAdd
(
reinterpret_cast
<
__nv_bfloat16
*>
(
address
),
return
CUDABF16ToPDBF16
(
atomicAdd
(
reinterpret_cast
<
__nv_bfloat16
*>
(
address
),
PDBF16ToCUDABF16
(
val
)));
PDBF16ToCUDABF16
(
val
)));
}
}
#else
#else
CUDA_ATOMIC_WRAPPER
(
Add
,
bfloat16
)
{
CUDA_ATOMIC_WRAPPER
(
Add
,
phi
::
dtype
::
bfloat16
)
{
// concrete packed bfloat16 value may exsits in lower or higher 16bits
// concrete packed bfloat16 value may exsits in lower or higher 16bits
// of the 32bits address.
// of the 32bits address.
uint32_t
*
address_as_ui
=
reinterpret_cast
<
uint32_t
*>
(
uint32_t
*
address_as_ui
=
reinterpret_cast
<
uint32_t
*>
(
...
@@ -272,7 +279,7 @@ CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
...
@@ -272,7 +279,7 @@ CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
old
=
atomicCAS
(
old
=
atomicCAS
(
address_as_ui
,
assumed
,
bf16_add_to_low_half
(
assumed
,
val_f
));
address_as_ui
,
assumed
,
bf16_add_to_low_half
(
assumed
,
val_f
));
}
while
(
old
!=
assumed
);
}
while
(
old
!=
assumed
);
bfloat16
ret
;
phi
::
dtype
::
bfloat16
ret
;
ret
.
x
=
old
&
0xFFFFu
;
ret
.
x
=
old
&
0xFFFFu
;
return
ret
;
return
ret
;
}
else
{
}
else
{
...
@@ -282,7 +289,7 @@ CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
...
@@ -282,7 +289,7 @@ CUDA_ATOMIC_WRAPPER(Add, bfloat16) {
old
=
atomicCAS
(
old
=
atomicCAS
(
address_as_ui
,
assumed
,
bf16_add_to_high_half
(
assumed
,
val_f
));
address_as_ui
,
assumed
,
bf16_add_to_high_half
(
assumed
,
val_f
));
}
while
(
old
!=
assumed
);
}
while
(
old
!=
assumed
);
bfloat16
ret
;
phi
::
dtype
::
bfloat16
ret
;
ret
.
x
=
old
>>
16
;
ret
.
x
=
old
>>
16
;
return
ret
;
return
ret
;
}
}
...
@@ -389,22 +396,24 @@ CUDA_ATOMIC_WRAPPER(Max, double) {
...
@@ -389,22 +396,24 @@ CUDA_ATOMIC_WRAPPER(Max, double) {
#ifdef PADDLE_CUDA_FP16
#ifdef PADDLE_CUDA_FP16
inline
static
__device__
uint32_t
max_to_low_half
(
uint32_t
val
,
float
x
)
{
inline
static
__device__
uint32_t
max_to_low_half
(
uint32_t
val
,
float
x
)
{
float16
low_half
;
phi
::
dtype
::
float16
low_half
;
// The float16 in lower 16bits
// The float16 in lower 16bits
low_half
.
x
=
static_cast
<
uint16_t
>
(
val
&
0xFFFFu
);
low_half
.
x
=
static_cast
<
uint16_t
>
(
val
&
0xFFFFu
);
low_half
=
static_cast
<
float16
>
(
max
(
static_cast
<
float
>
(
low_half
),
x
));
low_half
=
static_cast
<
phi
::
dtype
::
float16
>
(
max
(
static_cast
<
float
>
(
low_half
),
x
));
return
(
val
&
0xFFFF0000u
)
|
low_half
.
x
;
return
(
val
&
0xFFFF0000u
)
|
low_half
.
x
;
}
}
inline
static
__device__
uint32_t
max_to_high_half
(
uint32_t
val
,
float
x
)
{
inline
static
__device__
uint32_t
max_to_high_half
(
uint32_t
val
,
float
x
)
{
float16
high_half
;
phi
::
dtype
::
float16
high_half
;
// The float16 in higher 16bits
// The float16 in higher 16bits
high_half
.
x
=
static_cast
<
uint16_t
>
(
val
>>
16
);
high_half
.
x
=
static_cast
<
uint16_t
>
(
val
>>
16
);
high_half
=
static_cast
<
float16
>
(
max
(
static_cast
<
float
>
(
high_half
),
x
));
high_half
=
static_cast
<
phi
::
dtype
::
float16
>
(
max
(
static_cast
<
float
>
(
high_half
),
x
));
return
(
val
&
0xFFFFu
)
|
(
static_cast
<
uint32_t
>
(
high_half
.
x
)
<<
16
);
return
(
val
&
0xFFFFu
)
|
(
static_cast
<
uint32_t
>
(
high_half
.
x
)
<<
16
);
}
}
CUDA_ATOMIC_WRAPPER
(
Max
,
float16
)
{
CUDA_ATOMIC_WRAPPER
(
Max
,
phi
::
dtype
::
float16
)
{
if
(
*
address
>=
val
)
{
if
(
*
address
>=
val
)
{
return
*
address
;
return
*
address
;
}
}
...
@@ -420,7 +429,7 @@ CUDA_ATOMIC_WRAPPER(Max, float16) {
...
@@ -420,7 +429,7 @@ CUDA_ATOMIC_WRAPPER(Max, float16) {
assumed
=
old
;
assumed
=
old
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
max_to_low_half
(
assumed
,
val_f
));
old
=
atomicCAS
(
address_as_ui
,
assumed
,
max_to_low_half
(
assumed
,
val_f
));
}
while
(
old
!=
assumed
);
}
while
(
old
!=
assumed
);
float16
ret
;
phi
::
dtype
::
float16
ret
;
ret
.
x
=
old
&
0xFFFFu
;
ret
.
x
=
old
&
0xFFFFu
;
return
ret
;
return
ret
;
}
else
{
}
else
{
...
@@ -429,7 +438,7 @@ CUDA_ATOMIC_WRAPPER(Max, float16) {
...
@@ -429,7 +438,7 @@ CUDA_ATOMIC_WRAPPER(Max, float16) {
assumed
=
old
;
assumed
=
old
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
max_to_high_half
(
assumed
,
val_f
));
old
=
atomicCAS
(
address_as_ui
,
assumed
,
max_to_high_half
(
assumed
,
val_f
));
}
while
(
old
!=
assumed
);
}
while
(
old
!=
assumed
);
float16
ret
;
phi
::
dtype
::
float16
ret
;
ret
.
x
=
old
>>
16
;
ret
.
x
=
old
>>
16
;
return
ret
;
return
ret
;
}
}
...
@@ -522,22 +531,24 @@ CUDA_ATOMIC_WRAPPER(Min, double) {
...
@@ -522,22 +531,24 @@ CUDA_ATOMIC_WRAPPER(Min, double) {
#ifdef PADDLE_CUDA_FP16
#ifdef PADDLE_CUDA_FP16
inline
static
__device__
uint32_t
min_to_low_half
(
uint32_t
val
,
float
x
)
{
inline
static
__device__
uint32_t
min_to_low_half
(
uint32_t
val
,
float
x
)
{
float16
low_half
;
phi
::
dtype
::
float16
low_half
;
// The float16 in lower 16bits
// The float16 in lower 16bits
low_half
.
x
=
static_cast
<
uint16_t
>
(
val
&
0xFFFFu
);
low_half
.
x
=
static_cast
<
uint16_t
>
(
val
&
0xFFFFu
);
low_half
=
static_cast
<
float16
>
(
min
(
static_cast
<
float
>
(
low_half
),
x
));
low_half
=
static_cast
<
phi
::
dtype
::
float16
>
(
min
(
static_cast
<
float
>
(
low_half
),
x
));
return
(
val
&
0xFFFF0000u
)
|
low_half
.
x
;
return
(
val
&
0xFFFF0000u
)
|
low_half
.
x
;
}
}
inline
static
__device__
uint32_t
min_to_high_half
(
uint32_t
val
,
float
x
)
{
inline
static
__device__
uint32_t
min_to_high_half
(
uint32_t
val
,
float
x
)
{
float16
high_half
;
phi
::
dtype
::
float16
high_half
;
// The float16 in higher 16bits
// The float16 in higher 16bits
high_half
.
x
=
static_cast
<
uint16_t
>
(
val
>>
16
);
high_half
.
x
=
static_cast
<
uint16_t
>
(
val
>>
16
);
high_half
=
static_cast
<
float16
>
(
min
(
static_cast
<
float
>
(
high_half
),
x
));
high_half
=
static_cast
<
phi
::
dtype
::
float16
>
(
min
(
static_cast
<
float
>
(
high_half
),
x
));
return
(
val
&
0xFFFFu
)
|
(
static_cast
<
uint32_t
>
(
high_half
.
x
)
<<
16
);
return
(
val
&
0xFFFFu
)
|
(
static_cast
<
uint32_t
>
(
high_half
.
x
)
<<
16
);
}
}
CUDA_ATOMIC_WRAPPER
(
Min
,
float16
)
{
CUDA_ATOMIC_WRAPPER
(
Min
,
phi
::
dtype
::
float16
)
{
if
(
*
address
<=
val
)
{
if
(
*
address
<=
val
)
{
return
*
address
;
return
*
address
;
}
}
...
@@ -553,7 +564,7 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
...
@@ -553,7 +564,7 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
assumed
=
old
;
assumed
=
old
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
min_to_low_half
(
assumed
,
val_f
));
old
=
atomicCAS
(
address_as_ui
,
assumed
,
min_to_low_half
(
assumed
,
val_f
));
}
while
(
old
!=
assumed
);
}
while
(
old
!=
assumed
);
float16
ret
;
phi
::
dtype
::
float16
ret
;
ret
.
x
=
old
&
0xFFFFu
;
ret
.
x
=
old
&
0xFFFFu
;
return
ret
;
return
ret
;
}
else
{
}
else
{
...
@@ -562,7 +573,7 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
...
@@ -562,7 +573,7 @@ CUDA_ATOMIC_WRAPPER(Min, float16) {
assumed
=
old
;
assumed
=
old
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
min_to_high_half
(
assumed
,
val_f
));
old
=
atomicCAS
(
address_as_ui
,
assumed
,
min_to_high_half
(
assumed
,
val_f
));
}
while
(
old
!=
assumed
);
}
while
(
old
!=
assumed
);
float16
ret
;
phi
::
dtype
::
float16
ret
;
ret
.
x
=
old
>>
16
;
ret
.
x
=
old
>>
16
;
return
ret
;
return
ret
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录