MindSpore / mindspore

Commit 6719169a
Authored Aug 18, 2020 by Peilin Wang

added type support for atomic add and scatternd

fix ci

Parent: 0e27a04d
Showing 7 changed files with 224 additions and 115 deletions:
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.cc                  +39  -33
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu                 +4   -4
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu                      +1   -0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu   +1   -1
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu                      +4   -4
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu                          +80  -70
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh                               +95  -3
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.cc  (view file @ 6719169a)

Registrations for int16 and uint8 are added alongside the existing float32, float16, and int32 ones. After the change the file reads:

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
  ScatterNd,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  ScatterNdGpuFwdKernel, float, int)
MS_REG_GPU_KERNEL_TWO(
  ScatterNd,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  ScatterNdGpuFwdKernel, half, int)
MS_REG_GPU_KERNEL_TWO(
  ScatterNd,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  ScatterNdGpuFwdKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
  ScatterNd,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
  ScatterNdGpuFwdKernel, short, int)  // NOLINT
MS_REG_GPU_KERNEL_TWO(
  ScatterNd,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
  ScatterNdGpuFwdKernel, uchar, int)
}  // namespace kernel
}  // namespace mindspore
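Extending coverage to yet another dtype takes one more registration of the same shape. A hypothetical float64 entry, not part of this commit, would look like:

// Hypothetical sketch, not in the commit: a float64 registration would bind
// the op's value dtype to kNumberTypeFloat64 and the template parameter T to
// double, mirroring the entries above.
MS_REG_GPU_KERNEL_TWO(
  ScatterNd,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
  ScatterNdGpuFwdKernel, double, int)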
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu  (view file @ 6719169a)

The Minimum/Maximum gradient functors switch from ms_atomic_add to the renamed MsAtomicAdd:

@@ -23,9 +23,9 @@ struct MinimumGradFunc {
   __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
                                              const T &dy, T *dx1, T *dx2) {
     if (grad_x1 && x1 < x2) {
-      ms_atomic_add(dx1, dy);
+      MsAtomicAdd(dx1, dy);
     } else if (grad_x2 && x1 >= x2) {
-      ms_atomic_add(dx2, dy);
+      MsAtomicAdd(dx2, dy);
     }
   }
 };

@@ -35,9 +35,9 @@ struct MaximumGradFunc {
   __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
                                              const T &dy, T *dx1, T *dx2) {
     if (grad_x1 && x1 > x2) {
-      ms_atomic_add(dx1, dy);
+      MsAtomicAdd(dx1, dy);
     } else if (grad_x2 && x1 <= x2) {
-      ms_atomic_add(dx2, dy);
+      MsAtomicAdd(dx2, dy);
     }
   }
 };
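The atomics matter here because these functors run once per output element of the broadcast result, and many of those outputs fold their gradient back into the same dx slot. A minimal standalone sketch of that access pattern (float only, invented toy kernel, error checks omitted, not part of the commit):

// Why the functors accumulate through an atomic add: when an input was
// broadcast, many dy elements map back to the same dx element, so a plain
// "+=" would race.
#include <cstdio>

__device__ static inline float MsAtomicAdd(float *address, const float val) { return atomicAdd(address, val); }

// Toy analogue of MinimumGradFunc's dx2 branch with a scalar (broadcast) x2:
// every thread whose x1[i] >= x2 contributes its dy[i] to the single dx2 cell.
__global__ void MinimumGradToy(const float *x1, const float x2, const float *dy, float *dx2, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n && x1[i] >= x2) {
    MsAtomicAdd(dx2, dy[i]);
  }
}

int main() {
  const int n = 1024;
  float h_x1[n], h_dy[n];
  for (int i = 0; i < n; ++i) { h_x1[i] = 1.0f; h_dy[i] = 1.0f; }  // all pass the x1 >= x2 test
  float *d_x1, *d_dy, *d_dx2;
  cudaMalloc(&d_x1, n * sizeof(float));
  cudaMalloc(&d_dy, n * sizeof(float));
  cudaMalloc(&d_dx2, sizeof(float));
  cudaMemcpy(d_x1, h_x1, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_dy, h_dy, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemset(d_dx2, 0, sizeof(float));
  MinimumGradToy<<<(n + 255) / 256, 256>>>(d_x1, 0.0f, d_dy, d_dx2, n);
  float h_dx2 = 0.0f;
  cudaMemcpy(&h_dx2, d_dx2, sizeof(float), cudaMemcpyDeviceToHost);
  printf("%g\n", h_dx2);  // expected: 1024, with no contributions lost to races
  return 0;
}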
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu  (view file @ 6719169a)

The only change here (+1/-0) is one added include; after the change the include block at the top of the file reads:

@@ -15,6 +15,7 @@
 */
#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu  (view file @ 6719169a)

@@ -61,7 +61,7 @@ __global__ void ResizeNearestNeighborGrad(const int input_size, const T *input,
                        out_width - 1);
     // pos_array[0] N, pos_array[1] C, out_y H, out_x W
     output_pos = pos_array[0] * d2 * d3 * d4 + pos_array[1] * d3 * d4 + out_y * d4 + out_x;
-    ms_atomic_add(&output[output_pos], input[pos]);
+    MsAtomicAdd(&output[output_pos], input[pos]);
   }
 }
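The output_pos expression is the usual NCHW linearization, n * C*H*W + c * H*W + y * W + x, with d2, d3, d4 as the channel, height, and width extents. A quick worked check with invented values, not part of the commit:

// Worked check of the NCHW offset formula above, for an output of shape
// 2 x 8 x 16 x 16 and the element (n, c, y, x) = (1, 2, 3, 4).
#include <cstdio>

int main() {
  const int d2 = 8, d3 = 16, d4 = 16;
  const int n = 1, c = 2, out_y = 3, out_x = 4;
  int output_pos = n * d2 * d3 * d4 + c * d3 * d4 + out_y * d4 + out_x;
  printf("%d\n", output_pos);  // 1*2048 + 2*256 + 3*16 + 4 = 2612
  return 0;
}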
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/roi_align_impl.cu  (view file @ 6719169a)

@@ -218,10 +218,10 @@ __global__ void ROIAlignGradKernel(size_t size, const T *dy, const T *roi_boxes,
     T *dx_3 = dx + offset + y_high * width + x_low;
     T *dx_4 = dx + offset + y_high * width + x_high;
     if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
-      ms_atomic_add(dx_1, g1);
-      ms_atomic_add(dx_2, g2);
-      ms_atomic_add(dx_3, g3);
-      ms_atomic_add(dx_4, g4);
+      MsAtomicAdd(dx_1, g1);
+      MsAtomicAdd(dx_2, g2);
+      MsAtomicAdd(dx_3, g3);
+      MsAtomicAdd(dx_4, g4);
     }
   }
 }
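Here g1 through g4 are the bilinear weights with which one upstream gradient value is split among the four integer pixels surrounding its fractional sample point; because ROIs overlap, several threads can target the same pixel, hence the atomic adds. A host-side sketch of the standard bilinear split with invented values (the kernel's actual weight computation sits above this hunk):

// Standard ROI Align gradient split: a gradient dy at fractional (y, x) is
// divided among the four neighbouring pixels; the four weights sum to dy.
#include <cstdio>

int main() {
  const float y = 2.25f, x = 5.75f, dy = 1.0f;
  const int y_low = 2, x_low = 5;               // floor coordinates
  const float ly = y - y_low, lx = x - x_low;   // 0.25, 0.75
  const float hy = 1.0f - ly, hx = 1.0f - lx;   // 0.75, 0.25
  const float g1 = dy * hy * hx;                // -> dx_1 at (y_low,  x_low)
  const float g2 = dy * hy * lx;                // -> dx_2 at (y_low,  x_high)
  const float g3 = dy * ly * hx;                // -> dx_3 at (y_high, x_low)
  const float g4 = dy * ly * lx;                // -> dx_4 at (y_high, x_high)
  printf("%g %g %g %g\n", g1, g2, g3, g4);      // 0.1875 0.5625 0.0625 0.1875
  return 0;
}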
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu  (view file @ 6719169a)

The kernel's accumulation switches to the renamed MsAtomicAdd, and explicit instantiations for short and unsigned char are added after the existing float, half, and int ones. After the change the file reads:

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"

template <typename T, typename S>
__global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t block_size, const size_t input_size,
                                const size_t output_size, const size_t indices_dim_0, const size_t indices_dim_1,
                                S *indices_stride, S *work_shape) {
  int i, j;
  for (int read_index = blockIdx.x * blockDim.x + threadIdx.x; read_index < input_size;
       read_index += blockDim.x * gridDim.x) {
    int write_index = 0;
    bool out_bound = false;

    i = read_index / block_size;
    j = read_index % block_size;

    for (size_t k = 0; k < indices_dim_1; k++) {
      S indices_i = indices[i * indices_dim_1 + k];
      out_bound |= indices_i >= work_shape[k];
      write_index += indices_i * indices_stride[k];
    }

    write_index += j;
    out_bound |= write_index >= output_size;

    if (!out_bound) {
      MsAtomicAdd(&output[write_index], update[read_index]);
    }
  }
}

template <typename T, typename S>
void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
               const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
               S *work_shape, cudaStream_t stream) {
  ScatterNdKernel<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(indices, update, output, block_size, input_size,
                                                                       output_size, indices_dim_0, indices_dim_1,
                                                                       indices_stride, work_shape);
  return;
}

template void ScatterNd<float, int>(int *indices, float *update, float *output, const size_t &block_size,
                                    const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                                    const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                    cudaStream_t stream);
template void ScatterNd<half, int>(int *indices, half *update, half *output, const size_t &block_size,
                                   const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                                   const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                   cudaStream_t stream);
template void ScatterNd<int, int>(int *indices, int *update, int *output, const size_t &block_size,
                                  const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                                  const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                  cudaStream_t stream);
// NOLINTNEXTLINE
template void ScatterNd<short, int>(int *indices, short *update, short *output, const size_t &block_size,
                                    const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0,
                                    const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                    cudaStream_t stream);
template void ScatterNd<unsigned char, int>(int *indices, unsigned char *update, unsigned char *output,
                                            const size_t &block_size, const size_t &input_size,
                                            const size_t &output_size, const size_t &indices_dim_0,
                                            const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                            cudaStream_t stream);
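A host-side usage sketch for the launcher (assumptions: linked against the file above so the ScatterNd<float, int> instantiation is available, default stream, no error checking; not part of the commit). Note the duplicate index 2: both of its updates are accumulated atomically rather than one overwriting the other:

#include <cstdio>
#include <cuda_runtime.h>

// Forward declaration matching scatter_nd.cu's template, instantiated there
// for <float, int>.
template <typename T, typename S>
void ScatterNd(S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size,
               const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride,
               S *work_shape, cudaStream_t stream);

int main() {
  // 1-D ScatterNd: indices shape (4, 1), updates shape (4,), output shape (8,).
  const size_t indices_dim_0 = 4, indices_dim_1 = 1, block_size = 1;
  const size_t input_size = indices_dim_0 * block_size, output_size = 8;
  int h_indices[4] = {0, 2, 2, 7};              // index 2 appears twice
  float h_update[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  int h_stride[1] = {1};                        // element stride of output dimension 0
  int h_shape[1] = {8};                         // work_shape bound used by the out_bound check

  int *d_indices, *d_stride, *d_shape;
  float *d_update, *d_output;
  cudaMalloc(&d_indices, sizeof(h_indices));
  cudaMalloc(&d_update, sizeof(h_update));
  cudaMalloc(&d_output, output_size * sizeof(float));
  cudaMalloc(&d_stride, sizeof(h_stride));
  cudaMalloc(&d_shape, sizeof(h_shape));
  cudaMemcpy(d_indices, h_indices, sizeof(h_indices), cudaMemcpyHostToDevice);
  cudaMemcpy(d_update, h_update, sizeof(h_update), cudaMemcpyHostToDevice);
  cudaMemcpy(d_stride, h_stride, sizeof(h_stride), cudaMemcpyHostToDevice);
  cudaMemcpy(d_shape, h_shape, sizeof(h_shape), cudaMemcpyHostToDevice);
  cudaMemset(d_output, 0, output_size * sizeof(float));

  ScatterNd<float, int>(d_indices, d_update, d_output, block_size, input_size, output_size, indices_dim_0,
                        indices_dim_1, d_stride, d_shape, 0);

  float h_output[8];
  cudaMemcpy(h_output, d_output, sizeof(h_output), cudaMemcpyDeviceToHost);
  for (int k = 0; k < 8; ++k) printf("%g ", h_output[k]);  // expected: 1 0 5 0 0 0 0 4
  printf("\n");
  return 0;
}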
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/util.cuh  (view file @ 6719169a)

The existing float, int, and half helpers are renamed from ms_atomic_add to MsAtomicAdd, and new overloads are added for unsigned int, short, unsigned char, and char. CUDA's atomicAdd has no sub-word overloads, so the 1- and 2-byte variants emulate the add with atomicCAS on the enclosing 4-byte-aligned word.

@@ -19,11 +19,41 @@
 #include <cuda_fp16.h>

-inline __device__ float ms_atomic_add(float *address, float val) { return atomicAdd(address, val); }
+__device__ static inline float MsAtomicAdd(float *address, const float val) { return atomicAdd(address, val); }

-inline __device__ int ms_atomic_add(int *address, int val) { return atomicAdd(address, val); }
+__device__ static inline int MsAtomicAdd(int *address, int val) { return atomicAdd(address, val); }

-inline __device__ half ms_atomic_add(half *address, half val) {
+__device__ static inline unsigned int MsAtomicAdd(unsigned int *address, unsigned int val) {
+  return atomicAdd(address, val);
+}
+
+__device__ static inline short MsAtomicAdd(short *address, short val) {  // NOLINT
+  bool is_4_byte_aligned = ((size_t)address & 2) == 0;
+  unsigned int *aligned = (unsigned int *)((size_t)address & ~2);
+  unsigned int old = *aligned;
+  unsigned int assumed;
+
+  do {
+    assumed = old;
+    unsigned int replacement;
+
+    if (is_4_byte_aligned) {
+      replacement = (old & 0xffff0000) | (((old & 0xffff) + val) & 0xffff);
+    } else {
+      replacement = old + ((unsigned int)val << 16);
+    }
+    old = atomicCAS(aligned, assumed, replacement);
+  } while (assumed != old);
+
+  if (is_4_byte_aligned) {
+    return (short)(old & 0xffff);  // NOLINT
+  } else {
+    return (short)(old >> 16);  // NOLINT
+  }
+}
+
+__device__ static inline half MsAtomicAdd(half *address, half val) {
   unsigned int *aligned =
     reinterpret_cast<unsigned int *>(reinterpret_cast<size_t>(address) - (reinterpret_cast<size_t>(address) & 2));
   unsigned int old = *aligned;
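A quick sanity check for the new halfword path is a sketch like the following (assumptions: compiled with nvcc from inside the source tree so the include resolves; not part of the commit). 256 threads bump one 2-byte counter; the CAS loop on the enclosing 4-byte word keeps the increments from racing:

#include <cstdio>
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"

// 256 threads all add 1 to the same 2-byte counter through the short overload.
__global__ void BumpShort(short *counter) { MsAtomicAdd(counter, static_cast<short>(1)); }

int main() {
  short *d_counter;
  cudaMalloc(&d_counter, sizeof(short));  // cudaMalloc returns aligned memory
  cudaMemset(d_counter, 0, sizeof(short));
  BumpShort<<<1, 256>>>(d_counter);
  short h_counter = 0;
  cudaMemcpy(&h_counter, d_counter, sizeof(short), cudaMemcpyDeviceToHost);
  printf("%d\n", h_counter);  // expected: 256
  cudaFree(d_counter);
  return 0;
}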
@@ -42,4 +72,66 @@ inline __device__ half ms_atomic_add(half *address, half val) {
   return half(raw);
 }

+__device__ static inline unsigned char MsAtomicAdd(unsigned char *address, unsigned char val) {
+  // We use cuda's atomicCAS(unsigned int*, unsigned int, unsigned int) to
+  // implement MsAtomicAdd. An unsigned char may not be 4 byte aligned, but
+  // unsigned int* must be 4 byte aligned. This variable contains the offset,
+  // in bytes, of the beginning of address within the 4 byte aligned space that
+  // contains it.
+  size_t address_offset = (size_t)address & 3;
+
+  // Address of the 4 byte aligned space that contains address.
+  unsigned int *aligned = (unsigned int *)((unsigned char *)address - address_offset);
+
+  // Constants which will be used later with __byte_perm. __byte_perm is a cuda
+  // function which takes 3 unsigned ints (x, y, selector) as parameters and
+  // returns an int. __byte_perm returns an integer by selecting bytes from x
+  // and y based on the given selector. The selector 0x3210 will select all
+  // four bytes from x, preserving their original order. The position of the
+  // "4" in the selector indicates the position in the output where the first
+  // byte of y will end up.
+  unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210};
+
+  // Gets the selector that will select the byte at address from aligned.
+  unsigned int selector = selectors[address_offset];
+
+  unsigned int old = *aligned;
+  unsigned int assumed = 0;
+
+  do {
+    assumed = old;
+
+    // Selects the byte associated with address and puts it as the first byte of
+    // this variable, so that we can add val to the value at address.
+    unsigned int sum = val + __byte_perm(old, 0, address_offset);
+
+    // Takes old and replaces the byte corresponding to address with the sum.
+    unsigned int replacement = __byte_perm(old, sum, selector);
+
+    // Try to replace the old value with the new value.
+    old = atomicCAS(aligned, assumed, replacement);
+  } while (old != assumed);
+
+  // Select the single byte corresponding to address and return it.
+  return __byte_perm(old, 0, address_offset);
+}
+
+__device__ static inline char MsAtomicAdd(char *address, char val) {
+  size_t address_offset = (size_t)address & 3;
+  unsigned int *aligned = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - address_offset);
+  unsigned int selectors[] = {0x3214, 0x3240, 0x3410, 0x4210};
+  unsigned int selector = selectors[address_offset];
+  unsigned int old = *aligned;
+  unsigned int assumed = 0;
+
+  do {
+    assumed = old;
+    unsigned int sum = val + __byte_perm(old, 0, address_offset);
+    unsigned int replacement = __byte_perm(old, sum, selector);
+    old = atomicCAS(aligned, assumed, replacement);
+  } while (old != assumed);
+
+  return __byte_perm(old, 0, address_offset);
+}
+
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UTIL_H_
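To see the selector arithmetic concretely, here is a host-only emulation of __byte_perm with a worked example (invented values, not part of the commit). With address_offset == 1, __byte_perm(old, 0, 1) pulls byte 1 of old into the low byte, and selector 0x3240 writes the low byte of sum back into position 1 while keeping bytes 3, 2, and 0 of old:

#include <cstdint>
#include <cstdio>

// Host emulation of CUDA's __byte_perm(x, y, s) for simple selectors: hex
// digit i of s picks a byte from the 8-byte pool {x bytes 0-3, y bytes 4-7}
// into output byte i.
static uint32_t byte_perm(uint32_t x, uint32_t y, uint32_t s) {
  uint64_t pool = ((uint64_t)y << 32) | x;
  uint32_t out = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t sel = (s >> (4 * i)) & 0x7;
    out |= (uint32_t)((pool >> (8 * sel)) & 0xff) << (8 * i);
  }
  return out;
}

int main() {
  uint32_t old = 0x44332211;  // aligned word; byte 1 (0x22) is the target char
  uint32_t val = 5;
  uint32_t address_offset = 1;
  uint32_t sum = val + byte_perm(old, 0, address_offset);  // low byte: 0x22 + 5
  uint32_t replacement = byte_perm(old, sum, 0x3240);      // patch byte 1 only
  printf("0x%08x\n", replacement);                         // expected: 0x44332711
  return 0;
}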