Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ca90356b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ca90356b
编写于
1月 08, 2018
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add back priority
上级
0f353ab4
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
17 addition
and
4 deletion
+17
-4
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+3
-3
paddle/framework/operator.cc
paddle/framework/operator.cc
+14
-1
未找到文件。
paddle/framework/op_registry_test.cc
浏览文件 @
ca90356b
...
@@ -376,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
...
@@ -376,16 +376,16 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
paddle
::
framework
::
UseCPU
();
paddle
::
framework
::
UseCPU
();
op
->
Run
(
scope
,
cpu_place
);
op
->
Run
(
scope
,
cpu_place
);
EXPECT_EQ
(
op_test_value
,
-
20
);
EXPECT_EQ
(
op_test_value
,
-
9
);
// add cuda kernels
// add cuda kernels
paddle
::
framework
::
UseCUDA
();
paddle
::
framework
::
UseCUDA
();
op
->
Run
(
scope
,
cuda_place
);
op
->
Run
(
scope
,
cuda_place
);
EXPECT_EQ
(
op_test_value
,
-
3
0
);
EXPECT_EQ
(
op_test_value
,
-
1
0
);
// use cudnn kernel
// use cudnn kernel
paddle
::
framework
::
UseCUDNN
();
paddle
::
framework
::
UseCUDNN
();
op
->
Run
(
scope
,
cuda_place
);
op
->
Run
(
scope
,
cuda_place
);
EXPECT_EQ
(
op_test_value
,
-
4
0
);
EXPECT_EQ
(
op_test_value
,
-
2
0
);
}
}
paddle/framework/operator.cc
浏览文件 @
ca90356b
...
@@ -474,6 +474,20 @@ void OperatorWithKernel::Run(const Scope& scope,
...
@@ -474,6 +474,20 @@ void OperatorWithKernel::Run(const Scope& scope,
ExecutionContext
ctx
(
*
this
,
scope
,
*
dev_ctx
);
ExecutionContext
ctx
(
*
this
,
scope
,
*
dev_ctx
);
auto
expected_kernel_key
=
this
->
GetExpectedKernelType
(
ctx
);
auto
expected_kernel_key
=
this
->
GetExpectedKernelType
(
ctx
);
OpKernelMap
&
kernels
=
kernels_iter
->
second
;
for
(
auto
&
candidate
:
kKernelPriority
)
{
auto
candidate_key
=
OpKernelType
(
expected_kernel_key
.
data_type_
,
std
::
get
<
0
>
(
candidate
),
expected_kernel_key
.
data_layout_
,
std
::
get
<
1
>
(
candidate
));
if
((
candidate_key
==
expected_kernel_key
)
||
(
kernels
.
count
(
candidate_key
)))
{
expected_kernel_key
=
candidate_key
;
break
;
}
}
Scope
&
new_scope
=
scope
.
NewScope
();
Scope
&
new_scope
=
scope
.
NewScope
();
for
(
auto
&
var_name_item
:
this
->
Inputs
())
{
for
(
auto
&
var_name_item
:
this
->
Inputs
())
{
...
@@ -504,7 +518,6 @@ void OperatorWithKernel::Run(const Scope& scope,
...
@@ -504,7 +518,6 @@ void OperatorWithKernel::Run(const Scope& scope,
}
}
}
}
OpKernelMap
&
kernels
=
kernels_iter
->
second
;
auto
kernel_iter
=
kernels
.
find
(
expected_kernel_key
);
auto
kernel_iter
=
kernels
.
find
(
expected_kernel_key
);
kernel_iter
->
second
->
Compute
(
ExecutionContext
(
*
this
,
new_scope
,
*
dev_ctx
));
kernel_iter
->
second
->
Compute
(
ExecutionContext
(
*
this
,
new_scope
,
*
dev_ctx
));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录