Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b8f7fa97
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b8f7fa97
编写于
5月 02, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
replace __shfl with __shfl_sync
上级
7c90d7a3
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
18 addition
and
5 deletion
+18
-5
paddle/cuda/src/hl_top_k.cu
paddle/cuda/src/hl_top_k.cu
+5
-4
paddle/fluid/operators/top_k_op.cu
paddle/fluid/operators/top_k_op.cu
+6
-1
paddle/fluid/platform/cuda_primitives.h
paddle/fluid/platform/cuda_primitives.h
+7
-0
未找到文件。
paddle/cuda/src/hl_top_k.cu
浏览文件 @
b8f7fa97
...
...
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_base.h"
#include "hl_sparse.ph"
#include "hl_top_k.h"
#include "
paddle/cuda/include/
hl_base.h"
#include "
paddle/cuda/include/
hl_sparse.ph"
#include "
paddle/cuda/include/
hl_top_k.h"
#include "paddle/utils/Logging.h"
// using namespace hppl;
...
...
@@ -244,8 +244,9 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
if
(
--
beamSize
==
0
)
break
;
__syncthreads
();
// temporary solution
unsigned
mask
=
0u
;
// CREATE_SHFL_MASK(mask, tid < len
);
CREATE_SHFL_MASK
(
mask
,
true
);
if
(
tid
==
maxId
[
0
])
{
if
(
beam
<
maxLength
)
{
...
...
paddle/fluid/operators/top_k_op.cu
浏览文件 @
b8f7fa97
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -235,8 +236,12 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
sh_topk
[
tid
]
=
topk
[
*
beam
];
}
}
// temporary solution
unsigned
mask
=
0u
;
CREATE_SHFL_MASK
(
mask
,
true
);
if
(
maxid
[
0
]
/
32
==
warp
)
{
if
(
__shfl
(
*
beam
,
(
maxid
[
0
])
%
32
,
32
)
==
MaxLength
)
break
;
if
(
__shfl
_sync
(
mask
,
*
beam
,
(
maxid
[
0
])
%
32
,
32
)
==
MaxLength
)
break
;
}
}
}
...
...
paddle/fluid/platform/cuda_primitives.h
浏览文件 @
b8f7fa97
...
...
@@ -72,6 +72,13 @@ template <typename T>
__forceinline__
__device__
T
__shfl_down_sync
(
unsigned
,
T
val
,
int
delta
)
{
return
__shfl_down
(
val
,
delta
);
}
template
<
typename
T
>
__forceinline__
__device__
T
__shfl_sync
(
unsigned
,
T
val
,
int
src_line
,
int
width
)
{
return
__shfl
(
val
,
src_line
,
width
);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录