Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
94096ae5
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
94096ae5
编写于
12月 27, 2017
作者:
Q
QI JUN
提交者:
GitHub
12月 27, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add memory switch mechanism in operator kernel switch (#6991)
* add memory switch mechanism in operator kernel switch
上级
bff0cbfc
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
44 addition
and
13 deletion
+44
-13
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+4
-4
paddle/framework/data_transform.h
paddle/framework/data_transform.h
+7
-8
paddle/framework/operator.cc
paddle/framework/operator.cc
+33
-1
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
94096ae5
...
...
@@ -21,6 +21,8 @@ cc_test(variable_test SRCS variable_test.cc)
cc_library
(
scope SRCS scope.cc DEPS glog
)
cc_test
(
scope_test SRCS scope_test.cc DEPS scope
)
cc_library
(
data_transform SRCS data_transform.cc DEPS tensor framework_proto
)
cc_test
(
data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context
)
cc_library
(
attribute SRCS attribute.cc DEPS framework_proto
)
cc_test
(
program_desc_test SRCS program_desc_test.cc DEPS proto_desc
...
...
@@ -29,7 +31,8 @@ cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test
(
op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker
)
cc_library
(
op_info SRCS op_info.cc DEPS attribute framework_proto
)
cc_library
(
shape_inference SRCS shape_inference.cc DEPS ddim attribute
)
cc_library
(
operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference
)
cc_library
(
operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry init
)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog
)
...
...
@@ -65,6 +68,3 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
cc_test
(
init_test SRCS init_test.cc DEPS init
)
cc_test
(
op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto
)
cc_library
(
data_transform SRCS data_transform.cc DEPS tensor framework_proto
)
cc_test
(
data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context
)
paddle/framework/data_transform.h
浏览文件 @
94096ae5
...
...
@@ -32,17 +32,16 @@ using DataTransformFN =
const
Variable
&
in
,
Variable
*
out
)
>
;
using
KernelTypePair
=
std
::
pair
<
OpKernelType
,
OpKernelType
>
;
static
void
hash_combine
(
std
::
size_t
&
seed
,
const
OpKernelType
&
t
)
{
OpKernelType
::
Hash
kernel_type_hasher
;
seed
^=
kernel_type_hasher
(
t
)
+
0x9e3779b9
+
(
seed
<<
6
)
+
(
seed
>>
2
);
}
struct
KernelTypePairHash
{
static
void
HashCombine
(
const
OpKernelType
&
t
,
std
::
size_t
*
seed
)
{
OpKernelType
::
Hash
kernel_type_hasher
;
(
*
seed
)
^=
kernel_type_hasher
(
t
)
+
0x9e3779b9
+
(
*
seed
<<
6
)
+
(
*
seed
>>
2
);
}
size_t
operator
()(
const
KernelTypePair
&
kernel_pair
)
const
{
std
::
size_t
seed
=
0
;
hash_combine
(
seed
,
kernel_pair
.
first
);
hash_combine
(
seed
,
kernel_pair
.
second
);
HashCombine
(
kernel_pair
.
first
,
&
seed
);
HashCombine
(
kernel_pair
.
second
,
&
seed
);
return
seed
;
}
};
...
...
paddle/framework/operator.cc
浏览文件 @
94096ae5
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include <algorithm>
#include <atomic>
#include "paddle/framework/data_transform.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/operator.h"
...
...
@@ -411,7 +412,38 @@ void OperatorWithKernel::Run(const Scope& scope,
expected_kernel_key
);
}
kernel_iter
->
second
->
Compute
(
ctx
);
if
(
actual_kernel_key
==
expected_kernel_key
)
{
kernel_iter
->
second
->
Compute
(
ctx
);
}
else
{
Scope
&
op_scope
=
scope
.
NewScope
();
auto
input_vars
=
this
->
InputVars
();
for
(
auto
var_name
:
input_vars
)
{
op_scope
.
Var
(
var_name
);
}
// TODO(qijun) get appropriate DeviceContext from DeviceContext pool
platform
::
DeviceContext
*
trans_dev_ctx
=
nullptr
;
std
::
vector
<
platform
::
DeviceContext
*>
trans_dev_ctx_vec
{
trans_dev_ctx
};
// TODO(qijun) get appropriate DataTransformFN from global map
framework
::
DataTransformFN
trans_fun
=
nullptr
;
// Wait for transform starting
dev_ctx
->
Wait
();
for
(
auto
var_name
:
input_vars
)
{
trans_fun
(
trans_dev_ctx_vec
,
*
(
scope
.
FindVar
(
var_name
)),
op_scope
.
FindVar
(
var_name
));
}
// Wait for data transform finishing
for
(
auto
ctx
:
trans_dev_ctx_vec
)
{
ctx
->
Wait
();
}
// Create a new ExecutionContext
ExecutionContext
op_ctx
(
*
this
,
op_scope
,
*
dev_ctx
);
kernel_iter
->
second
->
Compute
(
op_ctx
);
}
}
OpKernelType
OperatorWithKernel
::
GetActualKernelType
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录