Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
25168309
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
25168309
编写于
7月 16, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3045 Gpu support TopK kernel
Merge pull request !3045 from chenweifeng/sort
上级
e249197c
c10e0773
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
420 addition
and
1 deletion
+420
-1
mindspore/ccsrc/CMakeLists.txt
mindspore/ccsrc/CMakeLists.txt
+1
-1
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.cc
...src/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.cc
+29
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h
...csrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h
+110
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu
.../ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu
+162
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh
...ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh
+32
-0
mindspore/ccsrc/runtime/device/gpu/cuda_common.h
mindspore/ccsrc/runtime/device/gpu/cuda_common.h
+4
-0
tests/st/ops/gpu/test_topk_op.py
tests/st/ops/gpu/test_topk_op.py
+82
-0
未找到文件。
mindspore/ccsrc/CMakeLists.txt
浏览文件 @
25168309
...
@@ -44,7 +44,7 @@ if(ENABLE_GPU)
...
@@ -44,7 +44,7 @@ if(ENABLE_GPU)
"backend/kernel_compiler/akg/akg_kernel_attrs_process.cc"
"backend/kernel_compiler/akg/akg_kernel_attrs_process.cc"
)
)
list
(
APPEND CUDA_NVCC_FLAGS -arch=sm_53
)
list
(
APPEND CUDA_NVCC_FLAGS -arch=sm_53
--expt-relaxed-constexpr
)
list
(
REMOVE_ITEM GPU_SRC_LIST
"runtime/device/gpu/blocking_queue.cc"
"runtime/device/gpu/gpu_buffer_mgr.cc"
)
list
(
REMOVE_ITEM GPU_SRC_LIST
"runtime/device/gpu/blocking_queue.cc"
"runtime/device/gpu/gpu_buffer_mgr.cc"
)
list
(
REMOVE_ITEM GPU_SRC_LIST
"runtime/device/gpu/mpi/mpi_initializer.cc"
list
(
REMOVE_ITEM GPU_SRC_LIST
"runtime/device/gpu/mpi/mpi_initializer.cc"
"runtime/device/gpu/distribution/collective_wrapper.cc"
"runtime/device/gpu/distribution/collective_wrapper.cc"
...
...
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.cc
0 → 100644
浏览文件 @
25168309
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h"
namespace
mindspore
{
namespace
kernel
{
MS_REG_GPU_KERNEL_TWO
(
TopK
,
KernelAttr
()
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeInt32
)
.
AddOutputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeInt32
),
TopKGpuKernel
,
float
,
int
)
}
}
// namespace mindspore
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/topk_gpu_kernel.h
0 → 100644
浏览文件 @
25168309
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_TOPK_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_TOPK_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"
namespace
mindspore
{
namespace
kernel
{
template
<
typename
T
,
typename
S
>
class
TopKGpuKernel
:
public
GpuKernel
{
public:
TopKGpuKernel
()
:
sorted_
(
false
),
outer_size_
(
1
),
inner_size_
(
1
),
k_
(
1
),
use_share_mem_
(
true
),
ceil_power2_
(
0
)
{}
~
TopKGpuKernel
()
override
=
default
;
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
override
{
return
workspace_size_list_
;
}
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspaces
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
override
{
T
*
input_addr
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
S
*
k
=
GetDeviceAddress
<
S
>
(
inputs
,
1
);
T
*
output_addr
=
GetDeviceAddress
<
T
>
(
outputs
,
0
);
S
*
indices
=
GetDeviceAddress
<
S
>
(
outputs
,
1
);
T
*
data_buff
=
nullptr
;
S
*
index_buff
=
nullptr
;
if
(
use_share_mem_
==
false
)
{
data_buff
=
GetDeviceAddress
<
T
>
(
workspaces
,
0
);
index_buff
=
GetDeviceAddress
<
S
>
(
workspaces
,
1
);
}
TopK
(
outer_size_
,
inner_size_
,
input_addr
,
k
,
output_addr
,
indices
,
data_buff
,
index_buff
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
if
(
sorted_
==
false
)
{
std
::
cout
<<
"================BitonicSortByKey"
<<
std
::
endl
;
BitonicSortByKey
(
outer_size_
,
k_
,
output_addr
,
indices
,
data_buff
,
index_buff
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
}
return
true
;
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
auto
input_shapes
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
auto
output_shapes
=
AnfAlgo
::
GetOutputInferShape
(
kernel_node
,
0
);
for
(
size_t
i
=
0
;
i
<
input_shapes
.
size
()
-
1
;
i
++
)
{
outer_size_
*=
input_shapes
[
i
];
}
inner_size_
=
input_shapes
[
input_shapes
.
size
()
-
1
];
k_
=
output_shapes
[
output_shapes
.
size
()
-
1
];
sorted_
=
GetAttr
<
bool
>
(
kernel_node
,
"sorted"
);
ceil_power2_
=
RoundUpPower2
(
inner_size_
);
size_t
buffer_size
=
ceil_power2_
*
(
sizeof
(
T
)
+
sizeof
(
S
));
if
(
buffer_size
>
SHARED_MEM_PER_BLOCK
)
{
use_share_mem_
=
false
;
MS_LOG
(
WARNING
)
<<
"CUDA share memory not enough, sort with RAM"
;
}
InitSizeLists
();
return
true
;
}
protected:
void
InitSizeLists
()
override
{
input_size_list_
.
push_back
(
outer_size_
*
inner_size_
*
sizeof
(
T
));
input_size_list_
.
push_back
(
sizeof
(
S
));
output_size_list_
.
push_back
(
outer_size_
*
k_
*
sizeof
(
T
));
output_size_list_
.
push_back
(
outer_size_
*
k_
*
sizeof
(
S
));
if
(
use_share_mem_
==
false
)
{
workspace_size_list_
.
push_back
(
outer_size_
*
ceil_power2_
*
sizeof
(
T
));
workspace_size_list_
.
push_back
(
outer_size_
*
ceil_power2_
*
sizeof
(
S
));
}
}
private:
bool
sorted_
;
int
outer_size_
;
int
inner_size_
;
int
k_
;
bool
use_share_mem_
;
int
ceil_power2_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
};
}
// namespace kernel
}
// namespace mindspore
#endif // TopKpuKernel
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cu
0 → 100644
浏览文件 @
25168309
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh"
#include <limits>
#include <algorithm>
int
RoundUpPower2
(
int
v
)
{
v
--
;
v
|=
v
>>
1
;
v
|=
v
>>
2
;
v
|=
v
>>
4
;
v
|=
v
>>
8
;
v
|=
v
>>
16
;
v
++
;
return
v
;
}
template
<
typename
T
>
__inline__
__device__
void
Swap
(
T
*
lhs
,
T
*
rhs
)
{
T
tmp
=
lhs
[
0
];
lhs
[
0
]
=
rhs
[
0
];
rhs
[
0
]
=
tmp
;
}
template
<
typename
T
,
typename
S
>
__global__
void
TopkKernel
(
const
int
outer
,
const
int
inner
,
const
int
ceil_power2
,
const
T
*
input
,
const
S
*
k
,
T
*
output
,
S
*
indices
,
T
*
data_buff
,
S
*
index_buff
)
{
// default: sort with share memory
extern
__shared__
T
share_mem
[];
T
*
data_arr
=
share_mem
;
S
*
index_arr
=
reinterpret_cast
<
S
*>
(
data_arr
+
ceil_power2
);
// sort with RAM
if
(
data_buff
!=
nullptr
&&
index_buff
!=
nullptr
)
{
data_arr
=
data_buff
+
blockIdx
.
x
*
ceil_power2
;
index_arr
=
index_buff
+
blockIdx
.
x
*
ceil_power2
;
}
for
(
int
i
=
threadIdx
.
x
;
i
<
ceil_power2
;
i
+=
blockDim
.
x
)
{
data_arr
[
i
]
=
(
i
<
inner
)
?
input
[
blockIdx
.
x
*
inner
+
i
]
:
std
::
numeric_limits
<
T
>::
max
();
index_arr
[
i
]
=
i
;
}
__syncthreads
();
for
(
size_t
i
=
2
;
i
<=
ceil_power2
;
i
<<=
1
)
{
for
(
size_t
j
=
(
i
>>
1
);
j
>
0
;
j
>>=
1
)
{
for
(
size_t
tid
=
threadIdx
.
x
;
tid
<
ceil_power2
;
tid
+=
blockDim
.
x
)
{
size_t
tid_comp
=
tid
^
j
;
if
(
tid_comp
>
tid
)
{
if
((
tid
&
i
)
==
0
)
{
if
(
data_arr
[
tid
]
>
data_arr
[
tid_comp
])
{
Swap
(
&
data_arr
[
tid
],
&
data_arr
[
tid_comp
]);
Swap
(
&
index_arr
[
tid
],
&
index_arr
[
tid_comp
]);
}
}
else
{
if
(
data_arr
[
tid
]
<
data_arr
[
tid_comp
])
{
Swap
(
&
data_arr
[
tid
],
&
data_arr
[
tid_comp
]);
Swap
(
&
index_arr
[
tid
],
&
index_arr
[
tid_comp
]);
}
}
}
}
__syncthreads
();
}
}
for
(
size_t
tid
=
threadIdx
.
x
;
tid
<
k
[
0
];
tid
+=
blockDim
.
x
)
{
output
[
blockIdx
.
x
*
k
[
0
]
+
tid
]
=
data_arr
[
inner
-
tid
-
1
];
indices
[
blockIdx
.
x
*
k
[
0
]
+
tid
]
=
index_arr
[
inner
-
tid
-
1
];
}
}
template
<
typename
T
,
typename
S
>
void
TopK
(
const
int
&
outer
,
const
int
&
inner
,
const
T
*
input
,
const
S
*
k
,
T
*
output
,
S
*
indices
,
T
*
data_buff
,
S
*
index_buff
,
cudaStream_t
stream
)
{
int
ceil_power2
=
RoundUpPower2
(
inner
);
int
share_mem
=
(
data_buff
==
nullptr
)
?
ceil_power2
*
(
sizeof
(
T
)
+
sizeof
(
S
))
:
0
;
int
thread
=
std
::
min
(
ceil_power2
,
GET_THREADS
);
TopkKernel
<<<
outer
,
thread
,
share_mem
,
stream
>>>
(
outer
,
inner
,
ceil_power2
,
input
,
k
,
output
,
indices
,
data_buff
,
index_buff
);
}
template
<
typename
T
,
typename
S
>
__global__
void
BitonicSortByKeyKernel
(
const
int
outer
,
const
int
inner
,
const
int
ceil_power2
,
T
*
input
,
S
*
indices
,
T
*
data_buff
,
S
*
index_buff
)
{
// default: sort with share memory
extern
__shared__
T
share_mem
[];
T
*
data_arr
=
share_mem
;
S
*
index_arr
=
reinterpret_cast
<
S
*>
(
data_arr
+
ceil_power2
);
// sort with RAM
if
(
data_buff
!=
nullptr
&&
index_buff
!=
nullptr
)
{
data_arr
=
data_buff
+
blockIdx
.
x
*
ceil_power2
;
index_arr
=
index_buff
+
blockIdx
.
x
*
ceil_power2
;
}
for
(
int
i
=
threadIdx
.
x
;
i
<
ceil_power2
;
i
+=
blockDim
.
x
)
{
data_arr
[
i
]
=
(
i
<
inner
)
?
input
[
blockIdx
.
x
*
inner
+
i
]
:
std
::
numeric_limits
<
T
>::
max
();
index_arr
[
i
]
=
(
i
<
inner
)
?
indices
[
blockIdx
.
x
*
inner
+
i
]
:
std
::
numeric_limits
<
S
>::
max
();;
}
__syncthreads
();
for
(
size_t
i
=
2
;
i
<=
ceil_power2
;
i
<<=
1
)
{
for
(
size_t
j
=
(
i
>>
1
);
j
>
0
;
j
>>=
1
)
{
for
(
size_t
tid
=
threadIdx
.
x
;
tid
<
ceil_power2
;
tid
+=
blockDim
.
x
)
{
size_t
tid_comp
=
tid
^
j
;
if
(
tid_comp
>
tid
)
{
if
((
tid
&
i
)
==
0
)
{
if
(
index_arr
[
tid
]
>
index_arr
[
tid_comp
])
{
Swap
(
&
data_arr
[
tid
],
&
data_arr
[
tid_comp
]);
Swap
(
&
index_arr
[
tid
],
&
index_arr
[
tid_comp
]);
}
}
else
{
if
(
index_arr
[
tid
]
<
index_arr
[
tid_comp
])
{
Swap
(
&
data_arr
[
tid
],
&
data_arr
[
tid_comp
]);
Swap
(
&
index_arr
[
tid
],
&
index_arr
[
tid_comp
]);
}
}
}
}
__syncthreads
();
}
}
for
(
size_t
tid
=
threadIdx
.
x
;
tid
<
inner
;
tid
+=
blockDim
.
x
)
{
input
[
blockIdx
.
x
*
inner
+
tid
]
=
data_arr
[
tid
];
indices
[
blockIdx
.
x
*
inner
+
tid
]
=
index_arr
[
tid
];
}
}
template
<
typename
T
,
typename
S
>
void
BitonicSortByKey
(
const
int
&
outer
,
const
int
&
inner
,
T
*
input
,
S
*
indices
,
T
*
data_buff
,
S
*
index_buff
,
cudaStream_t
stream
)
{
int
ceil_power2
=
RoundUpPower2
(
inner
);
size_t
share_mem
=
ceil_power2
*
(
sizeof
(
T
)
+
sizeof
(
S
));
if
(
share_mem
>
SHARED_MEM_PER_BLOCK
)
{
share_mem
=
0
;
}
else
{
data_buff
=
nullptr
;
index_buff
=
nullptr
;
}
int
thread
=
std
::
min
(
ceil_power2
,
GET_THREADS
);
BitonicSortByKeyKernel
<<<
outer
,
thread
,
share_mem
,
stream
>>>
(
outer
,
inner
,
ceil_power2
,
input
,
indices
,
data_buff
,
index_buff
);
}
template
void
TopK
(
const
int
&
outer
,
const
int
&
inner
,
const
float
*
input_addr
,
const
int
*
k
,
float
*
output
,
int
*
indices
,
float
*
data_buff
,
int
*
index_buff
,
cudaStream_t
stream
);
template
void
BitonicSortByKey
(
const
int
&
outer
,
const
int
&
inner
,
float
*
input
,
int
*
indices
,
float
*
data_buff
,
int
*
index_buff
,
cudaStream_t
stream
);
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/topk_impl.cuh
0 → 100644
浏览文件 @
25168309
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_
#include <cuda_runtime.h>
#include "runtime/device/gpu/cuda_common.h"
template
<
typename
T
,
typename
S
>
void
TopK
(
const
int
&
outer
,
const
int
&
inner
,
const
T
*
input_addr
,
const
S
*
k
,
T
*
output
,
S
*
indices
,
T
*
data_buff
,
S
*
index_buff
,
cudaStream_t
stream
);
template
<
typename
T
,
typename
S
>
void
BitonicSortByKey
(
const
int
&
outer
,
const
int
&
inner
,
T
*
input
,
S
*
indices
,
T
*
data_buff
,
S
*
index_buff
,
cudaStream_t
stream
);
int
RoundUpPower2
(
int
v
);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_TOPK_H_
mindspore/ccsrc/runtime/device/gpu/cuda_common.h
浏览文件 @
25168309
...
@@ -30,6 +30,7 @@ class CudaCommon {
...
@@ -30,6 +30,7 @@ class CudaCommon {
inline
int
blocks_num
(
const
int
total_threads
)
const
{
inline
int
blocks_num
(
const
int
total_threads
)
const
{
return
std
::
min
(((
total_threads
-
1
)
/
threads_per_block_
)
+
1
,
max_blocks_
);
return
std
::
min
(((
total_threads
-
1
)
/
threads_per_block_
)
+
1
,
max_blocks_
);
}
}
size_t
share_memory_size
()
const
{
return
max_share_memory_
;
}
static
CudaCommon
&
GetInstance
()
{
static
CudaCommon
&
GetInstance
()
{
static
CudaCommon
instance
;
static
CudaCommon
instance
;
...
@@ -44,6 +45,7 @@ class CudaCommon {
...
@@ -44,6 +45,7 @@ class CudaCommon {
threads_per_block_
=
prop
.
maxThreadsPerBlock
;
threads_per_block_
=
prop
.
maxThreadsPerBlock
;
max_blocks_
=
prop
.
multiProcessorCount
;
max_blocks_
=
prop
.
multiProcessorCount
;
major_sm_
=
prop
.
major
;
major_sm_
=
prop
.
major
;
max_share_memory_
=
prop
.
sharedMemPerBlock
;
}
}
~
CudaCommon
()
=
default
;
~
CudaCommon
()
=
default
;
CudaCommon
(
const
CudaCommon
&
)
=
delete
;
CudaCommon
(
const
CudaCommon
&
)
=
delete
;
...
@@ -52,10 +54,12 @@ class CudaCommon {
...
@@ -52,10 +54,12 @@ class CudaCommon {
int
max_blocks_
;
int
max_blocks_
;
int
threads_per_block_
;
int
threads_per_block_
;
int
major_sm_
;
int
major_sm_
;
size_t
max_share_memory_
;
};
};
#define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads)
#define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads)
#define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num()
#define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num()
#define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm()
#define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm()
#define SHARED_MEM_PER_BLOCK mindspore::device::gpu::CudaCommon::GetInstance().share_memory_size()
#define MINIUM_SM 6
#define MINIUM_SM 6
#define RECOMMEND_SM 7
#define RECOMMEND_SM 7
}
// namespace gpu
}
// namespace gpu
...
...
tests/st/ops/gpu/test_topk_op.py
0 → 100644
浏览文件 @
25168309
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
import
mindspore.context
as
context
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_topk
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"GPU"
)
x_np
=
np
.
random
.
rand
(
3
,
4
).
astype
(
np
.
float32
)
k
=
4
ms_output
=
P
.
TopK
(
True
)(
Tensor
(
x_np
),
k
)
np_output
=
np
.
sort
(
x_np
,
axis
=-
1
)[...,
::
-
1
][...,
0
:
k
]
assert
np
.
allclose
(
ms_output
[
0
].
asnumpy
(),
np_output
)
x_np
=
np
.
random
.
rand
(
3
,
4
).
astype
(
np
.
float32
)
k
=
4
ms_output
=
P
.
TopK
(
False
)(
Tensor
(
x_np
),
k
)
assert
np
.
allclose
(
ms_output
[
0
].
asnumpy
(),
x_np
)
x_np
=
np
.
random
.
rand
(
2
,
3
,
4
).
astype
(
np
.
float32
)
k
=
2
ms_output
=
P
.
TopK
(
True
)(
Tensor
(
x_np
),
k
)
np_output
=
np
.
sort
(
x_np
,
axis
=-
1
)[...,
::
-
1
][...,
0
:
k
]
assert
np
.
allclose
(
ms_output
[
0
].
asnumpy
(),
np_output
)
x_np
=
np
.
random
.
rand
(
512
,
1024
).
astype
(
np
.
float32
)
k
=
512
ms_output
=
P
.
TopK
(
True
)(
Tensor
(
x_np
),
k
)
np_output
=
np
.
sort
(
x_np
,
axis
=-
1
)[...,
::
-
1
][...,
0
:
k
]
assert
np
.
allclose
(
ms_output
[
0
].
asnumpy
(),
np_output
)
# sorted elements num greater than max thread per block
x_np
=
np
.
random
.
rand
(
512
,
2048
).
astype
(
np
.
float32
)
k
=
1
ms_output
=
P
.
TopK
(
True
)(
Tensor
(
x_np
),
k
)
np_output
=
np
.
sort
(
x_np
,
axis
=-
1
)[...,
::
-
1
][...,
0
:
k
]
assert
np
.
allclose
(
ms_output
[
0
].
asnumpy
(),
np_output
)
x_np
=
np
.
random
.
rand
(
512
,
2048
).
astype
(
np
.
float32
)
k
=
2048
ms_output
=
P
.
TopK
(
True
)(
Tensor
(
x_np
),
k
)
np_output
=
np
.
sort
(
x_np
,
axis
=-
1
)[...,
::
-
1
][...,
0
:
k
]
assert
np
.
allclose
(
ms_output
[
0
].
asnumpy
(),
np_output
)
# sorted elements num greater than max share memory per block
x_np
=
np
.
random
.
rand
(
512
,
40960
).
astype
(
np
.
float32
)
k
=
1
ms_output
=
P
.
TopK
(
True
)(
Tensor
(
x_np
),
k
)
np_output
=
np
.
sort
(
x_np
,
axis
=-
1
)[...,
::
-
1
][...,
0
:
k
]
assert
np
.
allclose
(
ms_output
[
0
].
asnumpy
(),
np_output
)
x_np
=
np
.
random
.
rand
(
512
,
40960
).
astype
(
np
.
float32
)
k
=
40960
ms_output
=
P
.
TopK
(
True
)(
Tensor
(
x_np
),
k
)
np_output
=
np
.
sort
(
x_np
,
axis
=-
1
)[...,
::
-
1
][...,
0
:
k
]
assert
np
.
allclose
(
ms_output
[
0
].
asnumpy
(),
np_output
)
x_np
=
np
.
random
.
rand
(
512
,
40960
).
astype
(
np
.
float32
)
k
=
40960
ms_output
=
P
.
TopK
(
False
)(
Tensor
(
x_np
),
k
)
assert
np
.
allclose
(
ms_output
[
0
].
asnumpy
(),
x_np
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录