Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a872eb90
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a872eb90
编写于
12月 20, 2018
作者:
X
Xin Pan
提交者:
GitHub
12月 20, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14959 from panyx0718/clean2
Further op RunImpl refactor
上级
550e7e41
1fe3ac35
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
295 addition
and
209 deletion
+295
-209
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+108
-24
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+142
-46
paddle/fluid/framework/shape_inference.cc
paddle/fluid/framework/shape_inference.cc
+0
-98
paddle/fluid/framework/shape_inference.h
paddle/fluid/framework/shape_inference.h
+16
-27
paddle/fluid/operators/controlflow/while_op.cc
paddle/fluid/operators/controlflow/while_op.cc
+29
-14
未找到文件。
paddle/fluid/framework/op_desc.cc
浏览文件 @
a872eb90
...
@@ -110,22 +110,125 @@ class CompileTimeInferShapeContext : public InferShapeContext {
...
@@ -110,22 +110,125 @@ class CompileTimeInferShapeContext : public InferShapeContext {
}
}
}
}
std
::
vector
<
InferShapeVarPtr
>
GetInputVarPtrs
(
const
std
::
string
&
name
)
override
{
const
std
::
vector
<
std
::
string
>
arg_names
=
Inputs
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
arg_names
.
size
());
std
::
transform
(
arg_names
.
begin
(),
arg_names
.
end
(),
std
::
back_inserter
(
res
),
[
this
](
const
std
::
string
&
name
)
{
return
block_
.
FindVarRecursive
(
name
);
});
return
res
;
}
std
::
vector
<
InferShapeVarPtr
>
GetOutputVarPtrs
(
const
std
::
string
&
name
)
override
{
const
std
::
vector
<
std
::
string
>
arg_names
=
Outputs
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
arg_names
.
size
());
std
::
transform
(
arg_names
.
begin
(),
arg_names
.
end
(),
std
::
back_inserter
(
res
),
[
this
](
const
std
::
string
&
name
)
{
return
block_
.
FindVarRecursive
(
name
);
});
return
res
;
}
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Inputs
(
name
);
PADDLE_ENFORCE_EQ
(
arg_names
.
size
(),
1UL
,
"Input(%s) should hold one element, but now it holds %d"
,
name
,
arg_names
.
size
());
return
this
->
GetDim
(
arg_names
[
0
]);
}
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Inputs
(
name
);
return
GetDims
(
arg_names
);
}
bool
IsRuntime
()
const
override
;
bool
IsRuntime
()
const
override
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetInputsVarType
(
const
std
::
string
&
name
)
const
override
{
return
GetVarTypes
(
Inputs
(
name
));
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetOutputsVarType
(
const
std
::
string
&
name
)
const
override
{
return
GetVarTypes
(
Outputs
(
name
));
}
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
auto
&
arg_names
=
Outputs
(
name
);
PADDLE_ENFORCE_EQ
(
arg_names
.
size
(),
1UL
,
"Output(%s) should hold one element, but now it holds %d"
,
name
,
arg_names
.
size
());
SetDim
(
arg_names
[
0
],
dim
);
}
void
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
override
{
auto
&
names
=
Outputs
(
name
);
SetDims
(
names
,
dims
);
}
protected:
protected:
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
override
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetVarTypes
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
std
::
vector
<
proto
::
VarType
::
Type
>
retv
;
retv
.
resize
(
names
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
retv
.
begin
(),
std
::
bind
(
std
::
mem_fn
(
&
CompileTimeInferShapeContext
::
GetVarType
),
this
,
std
::
placeholders
::
_1
));
return
retv
;
}
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
;
DDim
GetDim
(
const
std
::
string
&
name
)
const
{
auto
var
=
block_
.
FindVarRecursive
(
name
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s"
,
name
);
DDim
res
;
try
{
auto
shape
=
var
->
GetShape
();
res
=
shape
.
empty
()
?
make_ddim
({
0UL
})
:
make_ddim
(
shape
);
}
catch
(...)
{
VLOG
(
5
)
<<
"GetDim of variable "
<<
name
<<
" error"
;
std
::
rethrow_exception
(
std
::
current_exception
());
}
return
res
;
}
DDim
GetDim
(
const
std
::
string
&
name
)
const
override
;
std
::
vector
<
DDim
>
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
std
::
vector
<
DDim
>
ret
;
ret
.
reserve
(
names
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
std
::
back_inserter
(
ret
),
[
this
](
const
std
::
string
&
name
)
{
return
this
->
GetDim
(
name
);
});
return
ret
;
}
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
);
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
;
void
SetDims
(
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
DDim
>
&
dims
)
{
size_t
length
=
names
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
dims
.
size
());
for
(
size_t
i
=
0
;
i
<
length
;
++
i
)
{
if
(
names
[
i
]
==
framework
::
kEmptyVarName
)
{
continue
;
}
SetDim
(
names
[
i
],
dims
[
i
]);
}
}
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
override
;
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
override
;
void
SetRepeatedDims
(
const
std
::
string
&
name
,
void
SetRepeatedDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
override
;
const
std
::
vector
<
DDim
>
&
dims
)
override
;
InferShapeVarPtr
GetVarPtr
(
const
std
::
string
&
name
)
override
;
const
OpDesc
&
op_
;
const
OpDesc
&
op_
;
const
BlockDesc
&
block_
;
const
BlockDesc
&
block_
;
};
};
...
@@ -644,20 +747,6 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
...
@@ -644,20 +747,6 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
return
op_
.
Output
(
name
);
return
op_
.
Output
(
name
);
}
}
DDim
CompileTimeInferShapeContext
::
GetDim
(
const
std
::
string
&
name
)
const
{
auto
var
=
block_
.
FindVarRecursive
(
name
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s"
,
name
);
DDim
res
;
try
{
auto
shape
=
var
->
GetShape
();
res
=
shape
.
empty
()
?
make_ddim
({
0UL
})
:
make_ddim
(
shape
);
}
catch
(...)
{
VLOG
(
5
)
<<
"GetDim of variable "
<<
name
<<
" error"
;
std
::
rethrow_exception
(
std
::
current_exception
());
}
return
res
;
}
std
::
vector
<
DDim
>
CompileTimeInferShapeContext
::
GetRepeatedDims
(
std
::
vector
<
DDim
>
CompileTimeInferShapeContext
::
GetRepeatedDims
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
auto
var
=
block_
.
FindVarRecursive
(
name
);
auto
var
=
block_
.
FindVarRecursive
(
name
);
...
@@ -696,10 +785,5 @@ proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
...
@@ -696,10 +785,5 @@ proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
return
block_
.
FindVarRecursive
(
name
)
->
GetType
();
return
block_
.
FindVarRecursive
(
name
)
->
GetType
();
}
}
InferShapeVarPtr
CompileTimeInferShapeContext
::
GetVarPtr
(
const
std
::
string
&
name
)
{
return
block_
.
FindVarRecursive
(
name
);
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/operator.cc
浏览文件 @
a872eb90
...
@@ -142,12 +142,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames,
...
@@ -142,12 +142,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames,
const
Scope
&
scope
)
{
const
Scope
&
scope
)
{
for
(
auto
&
var_name_item
:
innames
)
{
for
(
auto
&
var_name_item
:
innames
)
{
std
::
vector
<
Variable
*>&
input_vars
=
inputs
[
var_name_item
.
first
];
std
::
vector
<
Variable
*>&
input_vars
=
inputs
[
var_name_item
.
first
];
input_vars
.
reserve
(
var_name_item
.
second
.
size
());
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
input_vars
.
push_back
(
scope
.
FindVar
(
var_name
));
input_vars
.
push_back
(
scope
.
FindVar
(
var_name
));
}
}
}
}
for
(
auto
&
var_name_item
:
outnames
)
{
for
(
auto
&
var_name_item
:
outnames
)
{
std
::
vector
<
Variable
*>&
output_vars
=
outputs
[
var_name_item
.
first
];
std
::
vector
<
Variable
*>&
output_vars
=
outputs
[
var_name_item
.
first
];
output_vars
.
reserve
(
var_name_item
.
second
.
size
());
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
output_vars
.
push_back
(
scope
.
FindVar
(
var_name
));
output_vars
.
push_back
(
scope
.
FindVar
(
var_name
));
}
}
...
@@ -556,30 +558,28 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -556,30 +558,28 @@ class RuntimeInferShapeContext : public InferShapeContext {
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
// has only one output
// has only one output
const
auto
&
outs
=
op_
.
Outputs
()
;
const
auto
&
outs
=
ctx_
.
outputs
;
auto
it
=
outs
.
find
(
name
);
auto
it
=
outs
.
find
(
name
);
if
(
it
==
outs
.
end
())
{
if
(
it
==
outs
.
end
())
{
return
false
;
return
false
;
}
}
const
auto
&
out
=
it
->
second
;
const
auto
&
out
=
it
->
second
;
if
(
out
.
size
()
==
0
||
out
[
0
]
==
kEmptyVarName
)
{
if
(
out
.
size
()
==
0
)
{
return
false
;
return
false
;
}
}
PADDLE_ENFORCE_EQ
(
out
.
size
(),
1UL
,
PADDLE_ENFORCE_EQ
(
out
.
size
(),
1UL
,
"Output %s should not have more than one outputs"
,
name
);
"Output %s should not have more than one outputs"
,
name
);
return
scope_
.
FindVar
(
out
[
0
])
!=
nullptr
;
return
out
[
0
]
!=
nullptr
;
}
}
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
{
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
{
if
(
!
op_
.
HasInputs
(
name
))
{
const
auto
&
ins
=
ctx_
.
inputs
;
return
false
;
auto
it
=
ins
.
find
(
name
);
}
if
(
it
==
ins
.
end
()
||
it
->
second
.
empty
())
{
auto
inputs
=
op_
.
Inputs
(
name
);
if
(
inputs
.
empty
())
{
return
false
;
return
false
;
}
}
for
(
auto
&
input
:
i
nputs
)
{
for
(
auto
&
input
:
i
t
->
second
)
{
if
(
scope_
.
FindVar
(
input
)
==
nullptr
)
{
if
(
input
==
nullptr
)
{
return
false
;
return
false
;
}
}
}
}
...
@@ -587,15 +587,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -587,15 +587,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
}
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
if
(
!
op_
.
HasOutputs
(
name
))
{
const
auto
&
outs
=
ctx_
.
outputs
;
return
false
;
auto
it
=
outs
.
find
(
name
);
}
if
(
it
==
outs
.
end
()
||
it
->
second
.
empty
())
{
auto
outputs
=
op_
.
Outputs
(
name
);
if
(
outputs
.
empty
())
{
return
false
;
return
false
;
}
}
for
(
auto
&
output
:
outputs
)
{
for
(
auto
&
output
:
it
->
second
)
{
if
(
scope_
.
FindVar
(
output
)
==
nullptr
)
{
if
(
output
==
nullptr
)
{
return
false
;
return
false
;
}
}
}
}
...
@@ -616,16 +614,18 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -616,16 +614,18 @@ class RuntimeInferShapeContext : public InferShapeContext {
void
ShareDim
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
void
ShareDim
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
override
{
size_t
j
=
0
)
override
{
PADDLE_ENFORCE_LT
(
i
,
Inputs
(
in
).
size
());
auto
in_it
=
ctx_
.
inputs
.
find
(
in
);
PADDLE_ENFORCE_LT
(
j
,
Outputs
(
out
).
size
());
auto
out_it
=
ctx_
.
outputs
.
find
(
out
);
const
std
::
string
&
input_n
=
Inputs
(
in
)[
i
];
PADDLE_ENFORCE
(
in_it
!=
ctx_
.
inputs
.
end
()
&&
in_it
->
second
.
size
()
>
i
,
const
std
::
string
&
output_n
=
Outputs
(
out
)[
j
];
"Inputs %s should have %llu argument"
,
in
,
i
);
PADDLE_ENFORCE
(
out_it
!=
ctx_
.
outputs
.
end
()
&&
out_it
->
second
.
size
()
>
j
,
"Outputs %s should have %llu argument"
,
out
,
j
);
Variable
*
in_var
=
in_it
->
second
[
i
];
Variable
*
out_var
=
out_it
->
second
[
j
];
Variable
*
in_var
=
scope_
.
FindVar
(
input_n
);
Variable
*
out_var
=
scope_
.
FindVar
(
output_n
);
PADDLE_ENFORCE
(
in_var
->
Type
()
==
out_var
->
Type
(),
PADDLE_ENFORCE
(
in_var
->
Type
()
==
out_var
->
Type
(),
"The type of %s and %s is not the same."
,
output_n
,
"The type of %s and %s is not the same."
,
in
,
out
);
GetDim
(
input_n
));
if
(
in_var
->
IsType
<
framework
::
SelectedRows
>
())
{
if
(
in_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
&
in_sele_rows
=
in_var
->
Get
<
framework
::
SelectedRows
>
();
auto
&
in_sele_rows
=
in_var
->
Get
<
framework
::
SelectedRows
>
();
...
@@ -646,13 +646,16 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -646,13 +646,16 @@ class RuntimeInferShapeContext : public InferShapeContext {
void
ShareLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
void
ShareLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
const
override
{
size_t
j
=
0
)
const
override
{
const
std
::
vector
<
std
::
string
>&
inputs
=
Inputs
(
in
);
auto
in_it
=
ctx_
.
inputs
.
find
(
in
);
const
std
::
vector
<
std
::
string
>&
outputs
=
Outputs
(
out
);
auto
out_it
=
ctx_
.
outputs
.
find
(
out
);
PADDLE_ENFORCE_LT
(
i
,
inputs
.
size
());
PADDLE_ENFORCE
(
in_it
!=
ctx_
.
inputs
.
end
()
&&
in_it
->
second
.
size
()
>
i
,
PADDLE_ENFORCE_LT
(
j
,
outputs
.
size
());
"Inputs %s should have %llu argument"
,
in
,
i
);
Variable
*
in_var
=
scope_
.
FindVar
(
inputs
.
at
(
i
));
PADDLE_ENFORCE
(
out_it
!=
ctx_
.
outputs
.
end
()
&&
out_it
->
second
.
size
()
>
j
,
"Outputs %s should have %llu argument"
,
out
,
j
);
Variable
*
in_var
=
in_it
->
second
.
at
(
i
);
if
(
!
in_var
->
IsType
<
LoDTensor
>
())
return
;
if
(
!
in_var
->
IsType
<
LoDTensor
>
())
return
;
Variable
*
out_var
=
scope_
.
FindVar
(
outputs
.
at
(
j
)
);
Variable
*
out_var
=
out_it
->
second
.
at
(
j
);
PADDLE_ENFORCE
(
out_var
->
IsType
<
LoDTensor
>
(),
PADDLE_ENFORCE
(
out_var
->
IsType
<
LoDTensor
>
(),
"The %d-th output of Output(%s) must be LoDTensor."
,
j
,
out
);
"The %d-th output of Output(%s) must be LoDTensor."
,
j
,
out
);
auto
in_tensor
=
in_var
->
Get
<
LoDTensor
>
();
auto
in_tensor
=
in_var
->
Get
<
LoDTensor
>
();
...
@@ -687,9 +690,64 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -687,9 +690,64 @@ class RuntimeInferShapeContext : public InferShapeContext {
bool
IsRuntime
()
const
override
{
return
true
;
}
bool
IsRuntime
()
const
override
{
return
true
;
}
// TODO(paddle-dev): Can this be template?
std
::
vector
<
InferShapeVarPtr
>
GetInputVarPtrs
(
const
std
::
string
&
name
)
override
{
const
std
::
vector
<
Variable
*>&
vars
=
InputVars
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
vars
.
size
());
res
.
insert
(
res
.
begin
(),
vars
.
begin
(),
vars
.
end
());
return
res
;
}
std
::
vector
<
InferShapeVarPtr
>
GetOutputVarPtrs
(
const
std
::
string
&
name
)
override
{
const
std
::
vector
<
Variable
*>&
vars
=
OutputVars
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
vars
.
size
());
res
.
insert
(
res
.
begin
(),
vars
.
begin
(),
vars
.
end
());
return
res
;
}
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
Variable
*>&
vars
=
InputVars
(
name
);
PADDLE_ENFORCE_EQ
(
vars
.
size
(),
1UL
,
"Input(%s) should hold one element, but now it holds %d"
,
name
,
vars
.
size
());
return
this
->
GetDim
(
vars
[
0
]);
}
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
Variable
*>&
vars
=
InputVars
(
name
);
return
GetDims
(
vars
);
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetInputsVarType
(
const
std
::
string
&
name
)
const
override
{
return
GetVarTypes
(
InputVars
(
name
));
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetOutputsVarType
(
const
std
::
string
&
name
)
const
override
{
return
GetVarTypes
(
OutputVars
(
name
));
}
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
auto
&
vars
=
OutputVars
(
name
);
PADDLE_ENFORCE_EQ
(
vars
.
size
(),
1UL
,
"Output(%s) should hold one element, but now it holds %d"
,
name
,
vars
.
size
());
SetDim
(
vars
[
0
],
dim
);
}
void
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>&
dims
)
override
{
auto
&
vars
=
OutputVars
(
name
);
SetDims
(
vars
,
dims
);
}
protected:
protected:
DDim
GetDim
(
const
std
::
string
&
name
)
const
override
{
DDim
GetDim
(
Variable
*
var
)
const
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
PADDLE_ENFORCE_NOT_NULL
(
var
);
PADDLE_ENFORCE_NOT_NULL
(
var
);
if
(
var
->
IsType
<
LoDTensor
>
())
{
if
(
var
->
IsType
<
LoDTensor
>
())
{
return
var
->
Get
<
LoDTensor
>
().
dims
();
return
var
->
Get
<
LoDTensor
>
().
dims
();
...
@@ -697,25 +755,44 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -697,25 +755,44 @@ class RuntimeInferShapeContext : public InferShapeContext {
return
var
->
Get
<
SelectedRows
>
().
GetCompleteDims
();
return
var
->
Get
<
SelectedRows
>
().
GetCompleteDims
();
}
else
{
}
else
{
PADDLE_THROW
(
PADDLE_THROW
(
"Only LoDTensor/SelectedRows support 'GetDim', but Variable
%s'
s "
"Only LoDTensor/SelectedRows support 'GetDim', but Variables "
"type_id is %s."
,
"type_id is %s."
,
name
,
var
->
Type
().
name
());
var
->
Type
().
name
());
}
}
}
}
std
::
vector
<
DDim
>
GetDims
(
const
std
::
vector
<
Variable
*>&
vars
)
const
{
std
::
vector
<
DDim
>
ret
;
ret
.
reserve
(
vars
.
size
());
std
::
transform
(
vars
.
begin
(),
vars
.
end
(),
std
::
back_inserter
(
ret
),
[
this
](
Variable
*
var
)
{
return
this
->
GetDim
(
var
);
});
return
ret
;
}
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
override
{
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
"Only compile time support this method"
);
PADDLE_THROW
(
"Only compile time support this method"
);
}
}
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
void
SetDim
(
Variable
*
var
,
const
DDim
&
dim
)
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
if
(
var
->
IsType
<
LoDTensor
>
())
{
if
(
var
->
IsType
<
LoDTensor
>
())
{
var
->
GetMutable
<
LoDTensor
>
()
->
Resize
(
dim
);
var
->
GetMutable
<
LoDTensor
>
()
->
Resize
(
dim
);
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
var
->
GetMutable
<
SelectedRows
>
()
->
set_height
(
dim
[
0
]);
var
->
GetMutable
<
SelectedRows
>
()
->
set_height
(
dim
[
0
]);
}
else
{
}
else
{
PADDLE_THROW
(
"Variable %s type_id %s, expect LoDTensor/SelectedRows."
,
PADDLE_THROW
(
"Variable type_id %s, expect LoDTensor/SelectedRows."
,
name
,
var
->
Type
().
name
());
var
->
Type
().
name
());
}
}
void
SetDims
(
const
std
::
vector
<
Variable
*>&
vars
,
const
std
::
vector
<
DDim
>&
dims
)
{
size_t
length
=
vars
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
dims
.
size
());
for
(
size_t
i
=
0
;
i
<
length
;
++
i
)
{
if
(
vars
[
i
]
==
nullptr
)
{
continue
;
}
SetDim
(
vars
[
i
],
dims
[
i
]);
}
}
}
}
...
@@ -724,16 +801,36 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -724,16 +801,36 @@ class RuntimeInferShapeContext : public InferShapeContext {
PADDLE_THROW
(
"Only compile time support this method"
);
PADDLE_THROW
(
"Only compile time support this method"
);
}
}
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
override
{
std
::
vector
<
proto
::
VarType
::
Type
>
GetVarTypes
(
auto
*
var
=
scope_
.
FindVar
(
name
);
const
std
::
vector
<
Variable
*>&
vars
)
const
{
return
ToVarType
(
var
->
Type
());
std
::
vector
<
proto
::
VarType
::
Type
>
retv
;
retv
.
resize
(
vars
.
size
());
std
::
transform
(
vars
.
begin
(),
vars
.
end
(),
retv
.
begin
(),
std
::
bind
(
std
::
mem_fn
(
&
RuntimeInferShapeContext
::
GetVarType
),
this
,
std
::
placeholders
::
_1
));
return
retv
;
}
}
InferShapeVarPtr
GetVarPtr
(
const
std
::
string
&
name
)
override
{
proto
::
VarType
::
Type
GetVarType
(
Variable
*
var
)
const
{
return
scope_
.
FindVar
(
name
);
return
ToVarType
(
var
->
Type
()
);
}
}
private:
private:
const
std
::
vector
<
Variable
*>&
InputVars
(
const
std
::
string
&
name
)
const
{
auto
it
=
ctx_
.
inputs
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
ctx_
.
inputs
.
end
(),
"Operator %s does not have the input %s."
,
op_
.
Type
(),
name
);
return
it
->
second
;
}
const
std
::
vector
<
Variable
*>&
OutputVars
(
const
std
::
string
&
name
)
const
{
auto
it
=
ctx_
.
outputs
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
ctx_
.
outputs
.
end
(),
"Operator %s does not have the outputs %s."
,
op_
.
Type
(),
name
);
return
it
->
second
;
}
const
OperatorBase
&
op_
;
const
OperatorBase
&
op_
;
const
Scope
&
scope_
;
const
Scope
&
scope_
;
const
RuntimeContext
&
ctx_
;
const
RuntimeContext
&
ctx_
;
...
@@ -864,8 +961,7 @@ Scope* OperatorWithKernel::PrepareData(
...
@@ -864,8 +961,7 @@ Scope* OperatorWithKernel::PrepareData(
for
(
size_t
i
=
0
;
i
<
var_name_item
.
second
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
var_name_item
.
second
.
size
();
++
i
)
{
auto
&
var_name
=
var_name_item
.
second
[
i
];
auto
&
var_name
=
var_name_item
.
second
[
i
];
auto
*
var
=
scope
.
FindVar
(
var_name
);
auto
*
var
=
input_vars
[
i
];
input_vars
[
i
]
=
var
;
// Only tensor can be tranfer to another device.
// Only tensor can be tranfer to another device.
if
(
var
==
nullptr
||
!
VarIsTensor
(
*
var
))
{
if
(
var
==
nullptr
||
!
VarIsTensor
(
*
var
))
{
...
...
paddle/fluid/framework/shape_inference.cc
浏览文件 @
a872eb90
...
@@ -22,20 +22,6 @@ limitations under the License. */
...
@@ -22,20 +22,6 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
DDim
InferShapeContext
::
GetInputDim
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Inputs
(
name
);
PADDLE_ENFORCE_EQ
(
arg_names
.
size
(),
1UL
,
"Input(%s) should hold one element, but now it holds %d"
,
name
,
arg_names
.
size
());
return
this
->
GetDim
(
arg_names
[
0
]);
}
std
::
vector
<
DDim
>
InferShapeContext
::
GetInputsDim
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Inputs
(
name
);
return
GetDims
(
arg_names
);
}
std
::
vector
<
DDim
>
InferShapeContext
::
GetReaderDims
(
std
::
vector
<
DDim
>
InferShapeContext
::
GetReaderDims
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Inputs
(
name
);
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Inputs
(
name
);
...
@@ -46,26 +32,6 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
...
@@ -46,26 +32,6 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
return
this
->
GetRepeatedDims
(
arg_names
[
0
]);
return
this
->
GetRepeatedDims
(
arg_names
[
0
]);
}
}
DDim
InferShapeContext
::
GetInputsElementDim
(
const
std
::
string
&
name
,
int
idx
)
const
{
const
std
::
vector
<
std
::
string
>
&
names
=
Inputs
(
name
);
return
this
->
GetDim
(
names
[
idx
]);
}
void
InferShapeContext
::
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
{
auto
&
arg_names
=
Outputs
(
name
);
PADDLE_ENFORCE_EQ
(
arg_names
.
size
(),
1UL
,
"Output(%s) should hold one element, but now it holds %d"
,
name
,
arg_names
.
size
());
SetDim
(
arg_names
[
0
],
dim
);
}
void
InferShapeContext
::
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
{
auto
&
names
=
Outputs
(
name
);
SetDims
(
names
,
dims
);
}
void
InferShapeContext
::
SetReaderDims
(
const
std
::
string
&
name
,
void
InferShapeContext
::
SetReaderDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
{
const
std
::
vector
<
DDim
>
&
dims
)
{
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Outputs
(
name
);
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Outputs
(
name
);
...
@@ -76,69 +42,5 @@ void InferShapeContext::SetReaderDims(const std::string &name,
...
@@ -76,69 +42,5 @@ void InferShapeContext::SetReaderDims(const std::string &name,
return
this
->
SetRepeatedDims
(
arg_names
[
0
],
dims
);
return
this
->
SetRepeatedDims
(
arg_names
[
0
],
dims
);
}
}
std
::
vector
<
InferShapeVarPtr
>
InferShapeContext
::
GetInputVarPtrs
(
const
std
::
string
&
name
)
{
const
std
::
vector
<
std
::
string
>
arg_names
=
Inputs
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
arg_names
.
size
());
std
::
transform
(
arg_names
.
begin
(),
arg_names
.
end
(),
std
::
back_inserter
(
res
),
[
this
](
const
std
::
string
&
name
)
{
return
this
->
GetVarPtr
(
name
);
});
return
res
;
}
std
::
vector
<
InferShapeVarPtr
>
InferShapeContext
::
GetOutputVarPtrs
(
const
std
::
string
&
name
)
{
const
std
::
vector
<
std
::
string
>
arg_names
=
Outputs
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
arg_names
.
size
());
std
::
transform
(
arg_names
.
begin
(),
arg_names
.
end
(),
std
::
back_inserter
(
res
),
[
this
](
const
std
::
string
&
name
)
{
return
this
->
GetVarPtr
(
name
);
});
return
res
;
}
std
::
vector
<
DDim
>
InferShapeContext
::
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
std
::
vector
<
DDim
>
ret
;
ret
.
reserve
(
names
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
std
::
back_inserter
(
ret
),
[
this
](
const
std
::
string
&
name
)
{
return
this
->
GetDim
(
name
);
});
return
ret
;
}
void
InferShapeContext
::
SetDims
(
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
DDim
>
&
dims
)
{
size_t
length
=
names
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
dims
.
size
());
for
(
size_t
i
=
0
;
i
<
length
;
++
i
)
{
if
(
names
[
i
]
==
framework
::
kEmptyVarName
)
{
continue
;
}
SetDim
(
names
[
i
],
dims
[
i
]);
}
}
std
::
vector
<
proto
::
VarType
::
Type
>
InferShapeContext
::
GetInputsVarType
(
const
std
::
string
&
name
)
const
{
return
GetVarTypes
(
Inputs
(
name
));
}
std
::
vector
<
proto
::
VarType
::
Type
>
InferShapeContext
::
GetOutputsVarType
(
const
std
::
string
&
name
)
const
{
return
GetVarTypes
(
Outputs
(
name
));
}
std
::
vector
<
proto
::
VarType
::
Type
>
InferShapeContext
::
GetVarTypes
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
std
::
vector
<
proto
::
VarType
::
Type
>
retv
;
retv
.
resize
(
names
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
retv
.
begin
(),
std
::
bind
(
std
::
mem_fn
(
&
InferShapeContext
::
GetVarType
),
this
,
std
::
placeholders
::
_1
));
return
retv
;
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/shape_inference.h
浏览文件 @
a872eb90
...
@@ -33,22 +33,23 @@ class InferShapeContext {
...
@@ -33,22 +33,23 @@ class InferShapeContext {
virtual
bool
HasInput
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasInput
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutput
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutput
(
const
std
::
string
&
name
)
const
=
0
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetInputsVarType
(
virtual
std
::
vector
<
proto
::
VarType
::
Type
>
GetInputsVarType
(
const
std
::
string
&
name
)
const
;
const
std
::
string
&
name
)
const
=
0
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetOutputsVarType
(
virtual
std
::
vector
<
proto
::
VarType
::
Type
>
GetOutputsVarType
(
const
std
::
string
&
name
)
const
;
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasInputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasInputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutputs
(
const
std
::
string
&
name
)
const
=
0
;
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
;
virtual
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
=
0
;
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
;
virtual
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
=
0
;
std
::
vector
<
DDim
>
GetReaderDims
(
const
std
::
string
&
name
)
const
;
virtual
std
::
vector
<
DDim
>
GetReaderDims
(
const
std
::
string
&
name
)
const
;
DDim
GetInputsElementDim
(
const
std
::
string
&
name
,
int
idx
)
const
;
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
);
virtual
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
=
0
;
void
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
);
virtual
void
SetOutputsDim
(
const
std
::
string
&
name
,
void
SetReaderDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
);
const
std
::
vector
<
DDim
>
&
dims
)
=
0
;
virtual
void
SetReaderDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
);
virtual
AttrReader
Attrs
()
const
=
0
;
virtual
AttrReader
Attrs
()
const
=
0
;
virtual
const
std
::
vector
<
std
::
string
>
&
Inputs
(
virtual
const
std
::
vector
<
std
::
string
>
&
Inputs
(
...
@@ -67,27 +68,15 @@ class InferShapeContext {
...
@@ -67,27 +68,15 @@ class InferShapeContext {
virtual
bool
IsRuntime
()
const
=
0
;
virtual
bool
IsRuntime
()
const
=
0
;
std
::
vector
<
InferShapeVarPtr
>
GetInputVarPtrs
(
const
std
::
string
&
name
);
virtual
std
::
vector
<
InferShapeVarPtr
>
GetInputVarPtrs
(
std
::
vector
<
InferShapeVarPtr
>
GetOutputVarPtrs
(
const
std
::
string
&
name
);
const
std
::
string
&
name
)
=
0
;
virtual
InferShapeVarPtr
GetVarPtr
(
const
std
::
string
&
name
)
=
0
;
virtual
std
::
vector
<
InferShapeVarPtr
>
GetOutputVarPtrs
(
const
std
::
string
&
name
)
=
0
;
// Note: In while op, we need this to be public
void
SetDims
(
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
DDim
>
&
dims
);
protected:
protected:
virtual
DDim
GetDim
(
const
std
::
string
&
name
)
const
=
0
;
virtual
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
=
0
;
virtual
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
=
0
;
virtual
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
=
0
;
virtual
void
SetRepeatedDims
(
const
std
::
string
&
name
,
virtual
void
SetRepeatedDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
=
0
;
const
std
::
vector
<
DDim
>
&
dims
)
=
0
;
std
::
vector
<
DDim
>
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetVarTypes
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
;
virtual
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
=
0
;
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/operators/controlflow/while_op.cc
浏览文件 @
a872eb90
...
@@ -399,26 +399,41 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
...
@@ -399,26 +399,41 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
ctx
->
HasInputs
(
kOutputs
);
ctx
->
HasInputs
(
kOutputs
);
ctx
->
HasInputs
(
framework
::
GradVarName
(
kOutputs
));
ctx
->
HasInputs
(
framework
::
GradVarName
(
kOutputs
));
auto
p_names
=
ctx
->
Inputs
(
kX
);
auto
pg_ig_names
=
ctx
->
Outputs
(
kXGRAD
);
auto
pg_ig_names
=
ctx
->
Outputs
(
kXGRAD
);
auto
var_types
=
ctx
->
GetInputsVarType
(
kX
);
std
::
vector
<
framework
::
InferShapeVarPtr
>
in_var_ptrs
=
std
::
vector
<
std
::
string
>
names_to_set
;
ctx
->
GetInputVarPtrs
(
kX
);
std
::
vector
<
framework
::
DDim
>
dims_to_set
;
std
::
vector
<
framework
::
InferShapeVarPtr
>
out_var_ptrs
=
for
(
size_t
i
=
0
;
i
<
p_names
.
size
();
++
i
)
{
ctx
->
GetOutputVarPtrs
(
kXGRAD
);
PADDLE_ENFORCE
(
in_var_ptrs
.
size
()
==
out_var_ptrs
.
size
());
for
(
size_t
i
=
0
;
i
<
in_var_ptrs
.
size
();
++
i
)
{
if
(
pg_ig_names
[
i
]
==
framework
::
kEmptyVarName
)
{
if
(
pg_ig_names
[
i
]
==
framework
::
kEmptyVarName
)
{
continue
;
continue
;
}
}
auto
dims
=
ctx
->
GetInputsElementDim
(
kX
,
i
);
if
(
ctx
->
IsRuntime
())
{
if
(
var_types
[
i
]
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
framework
::
Variable
*
in_var
=
names_to_set
.
push_back
(
pg_ig_names
[
i
]);
boost
::
get
<
framework
::
Variable
*>
(
in_var_ptrs
[
i
]);
dims_to_set
.
push_back
(
dims
);
framework
::
Variable
*
out_var
=
}
else
if
(
var_types
[
i
]
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
boost
::
get
<
framework
::
Variable
*>
(
out_var_ptrs
[
i
]);
// not sure how to set the dim of LOD_TENSOR_ARRAY
names_to_set
.
push_back
(
pg_ig_names
[
i
]);
auto
type
=
framework
::
ToVarType
(
in_var
->
Type
());
dims_to_set
.
push_back
(
dims
);
if
(
type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
out_var
->
GetMutable
<
LoDTensor
>
()
->
Resize
(
in_var
->
Get
<
framework
::
LoDTensor
>
().
dims
());
}
else
if
(
type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
out_var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
set_height
(
in_var
->
Get
<
framework
::
SelectedRows
>
().
GetCompleteDims
()[
0
]);
}
else
if
(
type
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
PADDLE_THROW
(
"WhileGradOp doesn't support type %d"
,
static_cast
<
int
>
(
type
));
}
}
else
{
framework
::
VarDesc
*
in_var
=
boost
::
get
<
framework
::
VarDesc
*>
(
in_var_ptrs
[
i
]);
boost
::
get
<
framework
::
VarDesc
*>
(
out_var_ptrs
[
i
])
->
SetShape
(
in_var
->
GetShape
());
}
}
}
}
ctx
->
SetDims
(
names_to_set
,
dims_to_set
);
}
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录