Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
d1258513
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
d1258513
编写于
8月 11, 2020
作者:
W
Wilber
提交者:
GitHub
8月 11, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Kernel] Update gru op and batch_norm op. (#4088)
上级
0ac48fcb
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
15 addition
and
6 deletion
+15
-6
lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu
lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu
+6
-3
lite/operators/dropout_op.cc
lite/operators/dropout_op.cc
+6
-2
lite/operators/gru_op.cc
lite/operators/gru_op.cc
+3
-1
未找到文件。
lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu
浏览文件 @
d1258513
...
...
@@ -55,7 +55,9 @@ __global__ void topk_avg_pooling_kernel_by_row_improve(
output_data
+
(
gpu_input_offset_l
[
blockIdx
.
x
]
+
idx
)
*
feat_map_num
*
topk_size
+
blockIdx
.
y
*
topk_size
;
for
(
int
i
=
0
;
i
<
topk_size
;
++
i
)
{
fm_row_out_data
[
i
]
=
0
;
}
Dtype
*
smem_start_col
=
smem
+
idx
*
col_max
;
int
counter
=
max_k
;
// topk_size;
...
...
@@ -151,6 +153,9 @@ __global__ void topk_avg_pooling_kernel_for_big_data(
blockIdx
.
z
*
actual_row_in_shared_mem
+
idx
)
*
feat_map_num
*
topk_size
+
blockIdx
.
y
*
topk_size
;
for
(
int
i
=
0
;
i
<
topk_size
;
++
i
)
{
fm_row_out_data
[
i
]
=
0
;
}
Dtype
*
smem_start_col
=
smem
+
idx
*
col_max
;
...
...
@@ -239,8 +244,6 @@ void SequenceTopkAvgPoolingCompute<T>::Run() {
Tensor
*
out_tensor
=
param
.
Out
;
const
T
*
in_data
=
x_tensor
->
data
<
T
>
();
T
*
out_data
=
out_tensor
->
mutable_data
<
T
>
(
TARGET
(
kCUDA
));
TargetWrapperCuda
::
MemsetAsync
(
out_data
,
0
,
sizeof
(
T
)
*
param
.
Out
->
numel
(),
cuda_stream
);
int
topk_num
=
param
.
topks
.
size
();
lite
::
DDim
top_ks_shape
(
std
::
vector
<
int64_t
>
{
topk_num
,
1
,
1
,
1
});
...
...
lite/operators/dropout_op.cc
浏览文件 @
d1258513
...
...
@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/dropout_op.h"
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
...
...
@@ -66,8 +68,10 @@ bool DropoutOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
param_
.
fix_seed
=
op_desc
.
GetAttr
<
bool
>
(
"fix_seed"
);
param_
.
seed
=
op_desc
.
GetAttr
<
int
>
(
"seed"
);
param_
.
dropout_implementation
=
op_desc
.
GetAttr
<
std
::
string
>
(
"dropout_implementation"
);
if
(
op_desc
.
HasAttr
(
"dropout_implementation"
))
{
param_
.
dropout_implementation
=
op_desc
.
GetAttr
<
std
::
string
>
(
"dropout_implementation"
);
}
return
true
;
}
...
...
lite/operators/gru_op.cc
浏览文件 @
d1258513
...
...
@@ -97,7 +97,9 @@ bool GRUOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
param_
.
gate_activation
=
op_desc
.
GetAttr
<
std
::
string
>
(
"gate_activation"
);
param_
.
activation
=
op_desc
.
GetAttr
<
std
::
string
>
(
"activation"
);
param_
.
is_reverse
=
op_desc
.
GetAttr
<
bool
>
(
"is_reverse"
);
param_
.
origin_mode
=
op_desc
.
GetAttr
<
bool
>
(
"origin_mode"
);
if
(
op_desc
.
HasAttr
(
"origin_mode"
))
{
param_
.
origin_mode
=
op_desc
.
GetAttr
<
bool
>
(
"origin_mode"
);
}
return
true
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录