Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a0767228
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a0767228
编写于
10月 07, 2017
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
merge InferShapeContext and ExecutionContext
上级
c3b46d16
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
24 addition
and
33 deletion
+24
-33
paddle/framework/operator.cc
paddle/framework/operator.cc
+4
-4
paddle/framework/operator.h
paddle/framework/operator.h
+20
-29
未找到文件。
paddle/framework/operator.cc
浏览文件 @
a0767228
...
@@ -205,13 +205,13 @@ void OperatorBase::GenerateTemporaryNames() {
...
@@ -205,13 +205,13 @@ void OperatorBase::GenerateTemporaryNames() {
}
}
template
<
>
template
<
>
const
Tensor
*
InferShape
Context
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
const
Tensor
*
Execution
Context
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
auto
*
var
=
InputVar
(
name
);
auto
*
var
=
InputVar
(
name
);
return
var
==
nullptr
?
nullptr
:
GetTensorFromVar
(
var
);
return
var
==
nullptr
?
nullptr
:
GetTensorFromVar
(
var
);
}
}
template
<
>
template
<
>
const
std
::
vector
<
const
Tensor
*>
InferShape
Context
::
MultiInput
<
Tensor
>
(
const
std
::
vector
<
const
Tensor
*>
Execution
Context
::
MultiInput
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
auto
names
=
op
().
Inputs
(
name
);
auto
names
=
op
().
Inputs
(
name
);
std
::
vector
<
const
Tensor
*>
res
;
std
::
vector
<
const
Tensor
*>
res
;
...
@@ -225,13 +225,13 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
...
@@ -225,13 +225,13 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
}
}
template
<
>
template
<
>
Tensor
*
InferShape
Context
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
Tensor
*
Execution
Context
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
auto
var
=
OutputVar
(
name
);
auto
var
=
OutputVar
(
name
);
return
var
==
nullptr
?
nullptr
:
var
->
GetMutable
<
LoDTensor
>
();
return
var
==
nullptr
?
nullptr
:
var
->
GetMutable
<
LoDTensor
>
();
}
}
template
<
>
template
<
>
std
::
vector
<
Tensor
*>
InferShape
Context
::
MultiOutput
<
Tensor
>
(
std
::
vector
<
Tensor
*>
Execution
Context
::
MultiOutput
<
Tensor
>
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
auto
names
=
op
().
Outputs
(
name
);
auto
names
=
op
().
Outputs
(
name
);
std
::
vector
<
Tensor
*>
res
;
std
::
vector
<
Tensor
*>
res
;
...
...
paddle/framework/operator.h
浏览文件 @
a0767228
...
@@ -57,7 +57,6 @@ inline std::string GradVarName(const std::string& var_name) {
...
@@ -57,7 +57,6 @@ inline std::string GradVarName(const std::string& var_name) {
}
}
class
OperatorBase
;
class
OperatorBase
;
class
InferShapeContext
;
class
ExecutionContext
;
class
ExecutionContext
;
extern
const
Tensor
*
GetTensorFromVar
(
const
Variable
*
var
);
extern
const
Tensor
*
GetTensorFromVar
(
const
Variable
*
var
);
...
@@ -169,10 +168,11 @@ class NOP : public OperatorBase {
...
@@ -169,10 +168,11 @@ class NOP : public OperatorBase {
}
}
};
};
class
InferShape
Context
{
class
Execution
Context
{
public:
public:
InferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
ExecutionContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
,
:
op_
(
op
),
scope_
(
scope
)
{}
const
platform
::
DeviceContext
&
device_context
)
:
op_
(
op
),
scope_
(
scope
),
device_context_
(
device_context
)
{}
const
OperatorBase
&
op
()
const
{
return
op_
;
}
const
OperatorBase
&
op
()
const
{
return
op_
;
}
...
@@ -278,31 +278,6 @@ class InferShapeContext {
...
@@ -278,31 +278,6 @@ class InferShapeContext {
out_tensor
->
set_lod
(
in_tensor
.
lod
());
out_tensor
->
set_lod
(
in_tensor
.
lod
());
}
}
private:
const
OperatorBase
&
op_
;
const
Scope
&
scope_
;
};
template
<>
const
Tensor
*
InferShapeContext
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
const
std
::
vector
<
const
Tensor
*>
InferShapeContext
::
MultiInput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
Tensor
*
InferShapeContext
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
std
::
vector
<
Tensor
*>
InferShapeContext
::
MultiOutput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
class
ExecutionContext
:
public
InferShapeContext
{
public:
ExecutionContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
device_context
)
:
InferShapeContext
(
op
,
scope
),
device_context_
(
device_context
)
{}
template
<
typename
PlaceType
,
template
<
typename
PlaceType
,
typename
DeviceType
=
typename
platform
::
EigenDeviceConverter
<
typename
DeviceType
=
typename
platform
::
EigenDeviceConverter
<
PlaceType
>::
EigenDeviceType
>
PlaceType
>::
EigenDeviceType
>
...
@@ -315,9 +290,25 @@ class ExecutionContext : public InferShapeContext {
...
@@ -315,9 +290,25 @@ class ExecutionContext : public InferShapeContext {
}
}
private:
private:
const
OperatorBase
&
op_
;
const
Scope
&
scope_
;
const
platform
::
DeviceContext
&
device_context_
;
const
platform
::
DeviceContext
&
device_context_
;
};
};
template
<>
const
Tensor
*
ExecutionContext
::
Input
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
const
std
::
vector
<
const
Tensor
*>
ExecutionContext
::
MultiInput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
Tensor
*
ExecutionContext
::
Output
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
template
<>
std
::
vector
<
Tensor
*>
ExecutionContext
::
MultiOutput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
class
CompileTimeInferShapeContext
:
public
InferShapeContextBase
{
class
CompileTimeInferShapeContext
:
public
InferShapeContextBase
{
public:
public:
CompileTimeInferShapeContext
(
const
OpDescBind
&
op
,
const
BlockDescBind
&
block
)
CompileTimeInferShapeContext
(
const
OpDescBind
&
op
,
const
BlockDescBind
&
block
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录