Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
840e6729
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
840e6729
编写于
12月 17, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
inject context
test=develop
上级
bbff0df3
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
31 addition
and
31 deletion
+31
-31
paddle/fluid/framework/ngraph_operator.cc
paddle/fluid/framework/ngraph_operator.cc
+1
-13
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+22
-14
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+6
-3
paddle/fluid/operators/beam_search_decode_op.cc
paddle/fluid/operators/beam_search_decode_op.cc
+2
-1
未找到文件。
paddle/fluid/framework/ngraph_operator.cc
浏览文件 @
840e6729
...
...
@@ -278,19 +278,7 @@ std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
ngraph
::
runtime
::
Backend
::
create
(
"CPU"
);
void
NgraphEngine
::
GetNgInputShape
(
std
::
shared_ptr
<
OperatorBase
>
op
)
{
RuntimeContext
ctx
;
for
(
auto
&
var_name_item
:
op
->
Inputs
())
{
std
::
vector
<
Variable
*>
input_vars
=
ctx
.
inputs
[
var_name_item
.
first
];
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
input_vars
.
push_back
(
scope_
.
FindVar
(
var_name
));
}
}
for
(
auto
&
var_name_item
:
op
->
Outputs
())
{
std
::
vector
<
Variable
*>
output_vars
=
ctx
.
outputs
[
var_name_item
.
first
];
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
output_vars
.
push_back
(
scope_
.
FindVar
(
var_name
));
}
}
RuntimeContext
ctx
(
op
->
Inputs
(),
op
->
Outputs
(),
scope_
);
op
->
RuntimeInferShape
(
scope_
,
place_
,
ctx
);
for
(
auto
&
var_name_item
:
op
->
Inputs
())
{
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
840e6729
...
...
@@ -137,6 +137,23 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
}
}
RuntimeContext
::
RuntimeContext
(
const
VariableNameMap
&
innames
,
const
VariableNameMap
&
outnames
,
const
Scope
&
scope
)
{
for
(
auto
&
var_name_item
:
innames
)
{
std
::
vector
<
Variable
*>&
input_vars
=
inputs
[
var_name_item
.
first
];
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
input_vars
.
push_back
(
scope
.
FindVar
(
var_name
));
}
}
for
(
auto
&
var_name_item
:
outnames
)
{
std
::
vector
<
Variable
*>&
output_vars
=
outputs
[
var_name_item
.
first
];
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
output_vars
.
push_back
(
scope
.
FindVar
(
var_name
));
}
}
}
void
OperatorBase
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
{
VLOG
(
4
)
<<
place
<<
" "
<<
DebugStringEx
(
&
scope
);
if
(
platform
::
is_gpu_place
(
place
))
{
...
...
@@ -704,6 +721,7 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
void
OperatorWithKernel
::
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
RuntimeContext
ctx
(
Inputs
(),
Outputs
(),
scope
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
*
dev_ctx
=
pool
.
Get
(
place
);
...
...
@@ -717,15 +735,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
OpKernelMap
&
kernels
=
kernels_iter
->
second
;
// TODO(dzhwinter) : kernel fallback mechanism will be added when all the
// transform functions are ready.
// for (auto& candidate : kKernelPriority) {
// Do selection
// }
auto
expected_kernel_key
=
this
->
GetExpectedKernelType
(
ExecutionContext
(
*
this
,
scope
,
*
dev_ctx
));
auto
expected_kernel_key
=
this
->
GetExpectedKernelType
(
ExecutionContext
(
*
this
,
scope
,
*
dev_ctx
,
ctx
));
VLOG
(
3
)
<<
"expected_kernel_key:"
<<
expected_kernel_key
;
auto
kernel_iter
=
kernels
.
find
(
expected_kernel_key
);
...
...
@@ -744,7 +755,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
KernelTypeToString
(
expected_kernel_key
));
}
RuntimeContext
ctx
;
// do data transformScope &transfer_scope;
std
::
vector
<
std
::
string
>
transfered_inplace_vars
;
auto
*
transfer_scope
=
...
...
@@ -760,7 +770,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
RuntimeInferShapeContext
infer_shape_ctx
(
*
this
,
exec_scope
,
ctx
);
this
->
InferShape
(
&
infer_shape_ctx
);
kernel_iter
->
second
(
ExecutionContext
(
*
this
,
exec_scope
,
*
dev_ctx
));
kernel_iter
->
second
(
ExecutionContext
(
*
this
,
exec_scope
,
*
dev_ctx
,
ctx
));
if
(
!
transfered_inplace_vars
.
empty
())
{
// there is inplace variable has been transfered.
...
...
@@ -784,6 +794,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
}
}
void
OperatorWithKernel
::
TransferInplaceVarsBack
(
const
Scope
&
scope
,
const
std
::
vector
<
std
::
string
>&
inplace_vars
,
const
Scope
&
transfer_scope
)
const
{
...
...
@@ -806,7 +817,6 @@ Scope* OperatorWithKernel::PrepareData(
Scope
*
new_scope
=
nullptr
;
for
(
auto
&
var_name_item
:
Inputs
())
{
std
::
vector
<
Variable
*>&
input_vars
=
ctx
->
inputs
[
var_name_item
.
first
];
input_vars
.
resize
(
var_name_item
.
second
.
size
());
for
(
size_t
i
=
0
;
i
<
var_name_item
.
second
.
size
();
++
i
)
{
auto
&
var_name
=
var_name_item
.
second
[
i
];
...
...
@@ -869,8 +879,6 @@ Scope* OperatorWithKernel::PrepareData(
}
for
(
auto
&
var_name_item
:
Outputs
())
{
std
::
vector
<
Variable
*>&
output_vars
=
ctx
->
outputs
[
var_name_item
.
first
];
output_vars
.
resize
(
var_name_item
.
second
.
size
());
for
(
size_t
i
=
0
;
i
<
var_name_item
.
second
.
size
();
++
i
)
{
auto
&
var_name
=
var_name_item
.
second
[
i
];
output_vars
[
i
]
=
scope
.
FindVar
(
var_name
);
...
...
paddle/fluid/framework/operator.h
浏览文件 @
840e6729
...
...
@@ -72,7 +72,8 @@ class ExecutionContext;
class
RuntimeContext
{
public:
RuntimeContext
()
{}
RuntimeContext
(
const
VariableNameMap
&
innames
,
const
VariableNameMap
&
outnames
,
const
Scope
&
scope
);
VariableValueMap
inputs
;
VariableValueMap
outputs
;
...
...
@@ -165,8 +166,9 @@ class OperatorBase {
class
ExecutionContext
{
public:
ExecutionContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
device_context
)
:
op_
(
op
),
scope_
(
scope
),
device_context_
(
device_context
)
{}
const
platform
::
DeviceContext
&
device_context
,
const
RuntimeContext
&
ctx
)
:
op_
(
op
),
scope_
(
scope
),
device_context_
(
device_context
),
ctx_
(
ctx
)
{}
const
OperatorBase
&
op
()
const
{
return
op_
;
}
...
...
@@ -295,6 +297,7 @@ class ExecutionContext {
const
OperatorBase
&
op_
;
const
Scope
&
scope_
;
const
platform
::
DeviceContext
&
device_context_
;
const
RuntimeContext
&
ctx_
;
};
template
<>
...
...
paddle/fluid/operators/beam_search_decode_op.cc
浏览文件 @
840e6729
...
...
@@ -122,7 +122,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
framework
::
ExecutionContext
ctx
(
*
this
,
scope
,
dev_ctx
);
framework
::
RuntimeContext
run_ctx
(
Inputs
(),
Outputs
(),
scope
);
framework
::
ExecutionContext
ctx
(
*
this
,
scope
,
dev_ctx
,
run_ctx
);
const
LoDTensorArray
*
ids
=
ctx
.
Input
<
LoDTensorArray
>
(
"Ids"
);
const
LoDTensorArray
*
scores
=
ctx
.
Input
<
LoDTensorArray
>
(
"Scores"
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录