Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
a872eb90
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a872eb90
编写于
12月 20, 2018
作者:
X
Xin Pan
提交者:
GitHub
12月 20, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14959 from panyx0718/clean2
Further op RunImpl refactor
上级
550e7e41
1fe3ac35
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
295 addition
and
209 deletion
+295
-209
paddle/fluid/framework/op_desc.cc
paddle/fluid/framework/op_desc.cc
+108
-24
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+142
-46
paddle/fluid/framework/shape_inference.cc
paddle/fluid/framework/shape_inference.cc
+0
-98
paddle/fluid/framework/shape_inference.h
paddle/fluid/framework/shape_inference.h
+16
-27
paddle/fluid/operators/controlflow/while_op.cc
paddle/fluid/operators/controlflow/while_op.cc
+29
-14
未找到文件。
paddle/fluid/framework/op_desc.cc
浏览文件 @
a872eb90
...
...
@@ -110,22 +110,125 @@ class CompileTimeInferShapeContext : public InferShapeContext {
}
}
std
::
vector
<
InferShapeVarPtr
>
GetInputVarPtrs
(
const
std
::
string
&
name
)
override
{
const
std
::
vector
<
std
::
string
>
arg_names
=
Inputs
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
arg_names
.
size
());
std
::
transform
(
arg_names
.
begin
(),
arg_names
.
end
(),
std
::
back_inserter
(
res
),
[
this
](
const
std
::
string
&
name
)
{
return
block_
.
FindVarRecursive
(
name
);
});
return
res
;
}
std
::
vector
<
InferShapeVarPtr
>
GetOutputVarPtrs
(
const
std
::
string
&
name
)
override
{
const
std
::
vector
<
std
::
string
>
arg_names
=
Outputs
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
arg_names
.
size
());
std
::
transform
(
arg_names
.
begin
(),
arg_names
.
end
(),
std
::
back_inserter
(
res
),
[
this
](
const
std
::
string
&
name
)
{
return
block_
.
FindVarRecursive
(
name
);
});
return
res
;
}
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Inputs
(
name
);
PADDLE_ENFORCE_EQ
(
arg_names
.
size
(),
1UL
,
"Input(%s) should hold one element, but now it holds %d"
,
name
,
arg_names
.
size
());
return
this
->
GetDim
(
arg_names
[
0
]);
}
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Inputs
(
name
);
return
GetDims
(
arg_names
);
}
bool
IsRuntime
()
const
override
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetInputsVarType
(
const
std
::
string
&
name
)
const
override
{
return
GetVarTypes
(
Inputs
(
name
));
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetOutputsVarType
(
const
std
::
string
&
name
)
const
override
{
return
GetVarTypes
(
Outputs
(
name
));
}
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
auto
&
arg_names
=
Outputs
(
name
);
PADDLE_ENFORCE_EQ
(
arg_names
.
size
(),
1UL
,
"Output(%s) should hold one element, but now it holds %d"
,
name
,
arg_names
.
size
());
SetDim
(
arg_names
[
0
],
dim
);
}
void
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
override
{
auto
&
names
=
Outputs
(
name
);
SetDims
(
names
,
dims
);
}
protected:
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
override
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetVarTypes
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
std
::
vector
<
proto
::
VarType
::
Type
>
retv
;
retv
.
resize
(
names
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
retv
.
begin
(),
std
::
bind
(
std
::
mem_fn
(
&
CompileTimeInferShapeContext
::
GetVarType
),
this
,
std
::
placeholders
::
_1
));
return
retv
;
}
DDim
GetDim
(
const
std
::
string
&
name
)
const
override
;
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
;
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
;
DDim
GetDim
(
const
std
::
string
&
name
)
const
{
auto
var
=
block_
.
FindVarRecursive
(
name
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s"
,
name
);
DDim
res
;
try
{
auto
shape
=
var
->
GetShape
();
res
=
shape
.
empty
()
?
make_ddim
({
0UL
})
:
make_ddim
(
shape
);
}
catch
(...)
{
VLOG
(
5
)
<<
"GetDim of variable "
<<
name
<<
" error"
;
std
::
rethrow_exception
(
std
::
current_exception
());
}
return
res
;
}
std
::
vector
<
DDim
>
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
std
::
vector
<
DDim
>
ret
;
ret
.
reserve
(
names
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
std
::
back_inserter
(
ret
),
[
this
](
const
std
::
string
&
name
)
{
return
this
->
GetDim
(
name
);
});
return
ret
;
}
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
);
void
SetDims
(
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
DDim
>
&
dims
)
{
size_t
length
=
names
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
dims
.
size
());
for
(
size_t
i
=
0
;
i
<
length
;
++
i
)
{
if
(
names
[
i
]
==
framework
::
kEmptyVarName
)
{
continue
;
}
SetDim
(
names
[
i
],
dims
[
i
]);
}
}
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
override
;
void
SetRepeatedDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
override
;
InferShapeVarPtr
GetVarPtr
(
const
std
::
string
&
name
)
override
;
const
OpDesc
&
op_
;
const
BlockDesc
&
block_
;
};
...
...
@@ -644,20 +747,6 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
return
op_
.
Output
(
name
);
}
DDim
CompileTimeInferShapeContext
::
GetDim
(
const
std
::
string
&
name
)
const
{
auto
var
=
block_
.
FindVarRecursive
(
name
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s"
,
name
);
DDim
res
;
try
{
auto
shape
=
var
->
GetShape
();
res
=
shape
.
empty
()
?
make_ddim
({
0UL
})
:
make_ddim
(
shape
);
}
catch
(...)
{
VLOG
(
5
)
<<
"GetDim of variable "
<<
name
<<
" error"
;
std
::
rethrow_exception
(
std
::
current_exception
());
}
return
res
;
}
std
::
vector
<
DDim
>
CompileTimeInferShapeContext
::
GetRepeatedDims
(
const
std
::
string
&
name
)
const
{
auto
var
=
block_
.
FindVarRecursive
(
name
);
...
...
@@ -696,10 +785,5 @@ proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
return
block_
.
FindVarRecursive
(
name
)
->
GetType
();
}
InferShapeVarPtr
CompileTimeInferShapeContext
::
GetVarPtr
(
const
std
::
string
&
name
)
{
return
block_
.
FindVarRecursive
(
name
);
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/operator.cc
浏览文件 @
a872eb90
...
...
@@ -142,12 +142,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames,
const
Scope
&
scope
)
{
for
(
auto
&
var_name_item
:
innames
)
{
std
::
vector
<
Variable
*>&
input_vars
=
inputs
[
var_name_item
.
first
];
input_vars
.
reserve
(
var_name_item
.
second
.
size
());
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
input_vars
.
push_back
(
scope
.
FindVar
(
var_name
));
}
}
for
(
auto
&
var_name_item
:
outnames
)
{
std
::
vector
<
Variable
*>&
output_vars
=
outputs
[
var_name_item
.
first
];
output_vars
.
reserve
(
var_name_item
.
second
.
size
());
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
output_vars
.
push_back
(
scope
.
FindVar
(
var_name
));
}
...
...
@@ -556,30 +558,28 @@ class RuntimeInferShapeContext : public InferShapeContext {
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
// has only one output
const
auto
&
outs
=
op_
.
Outputs
()
;
const
auto
&
outs
=
ctx_
.
outputs
;
auto
it
=
outs
.
find
(
name
);
if
(
it
==
outs
.
end
())
{
return
false
;
}
const
auto
&
out
=
it
->
second
;
if
(
out
.
size
()
==
0
||
out
[
0
]
==
kEmptyVarName
)
{
if
(
out
.
size
()
==
0
)
{
return
false
;
}
PADDLE_ENFORCE_EQ
(
out
.
size
(),
1UL
,
"Output %s should not have more than one outputs"
,
name
);
return
scope_
.
FindVar
(
out
[
0
])
!=
nullptr
;
return
out
[
0
]
!=
nullptr
;
}
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
{
if
(
!
op_
.
HasInputs
(
name
))
{
return
false
;
}
auto
inputs
=
op_
.
Inputs
(
name
);
if
(
inputs
.
empty
())
{
const
auto
&
ins
=
ctx_
.
inputs
;
auto
it
=
ins
.
find
(
name
);
if
(
it
==
ins
.
end
()
||
it
->
second
.
empty
())
{
return
false
;
}
for
(
auto
&
input
:
i
nputs
)
{
if
(
scope_
.
FindVar
(
input
)
==
nullptr
)
{
for
(
auto
&
input
:
i
t
->
second
)
{
if
(
input
==
nullptr
)
{
return
false
;
}
}
...
...
@@ -587,15 +587,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
if
(
!
op_
.
HasOutputs
(
name
))
{
return
false
;
}
auto
outputs
=
op_
.
Outputs
(
name
);
if
(
outputs
.
empty
())
{
const
auto
&
outs
=
ctx_
.
outputs
;
auto
it
=
outs
.
find
(
name
);
if
(
it
==
outs
.
end
()
||
it
->
second
.
empty
())
{
return
false
;
}
for
(
auto
&
output
:
outputs
)
{
if
(
scope_
.
FindVar
(
output
)
==
nullptr
)
{
for
(
auto
&
output
:
it
->
second
)
{
if
(
output
==
nullptr
)
{
return
false
;
}
}
...
...
@@ -616,16 +614,18 @@ class RuntimeInferShapeContext : public InferShapeContext {
void
ShareDim
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
override
{
PADDLE_ENFORCE_LT
(
i
,
Inputs
(
in
).
size
());
PADDLE_ENFORCE_LT
(
j
,
Outputs
(
out
).
size
());
const
std
::
string
&
input_n
=
Inputs
(
in
)[
i
];
const
std
::
string
&
output_n
=
Outputs
(
out
)[
j
];
auto
in_it
=
ctx_
.
inputs
.
find
(
in
);
auto
out_it
=
ctx_
.
outputs
.
find
(
out
);
PADDLE_ENFORCE
(
in_it
!=
ctx_
.
inputs
.
end
()
&&
in_it
->
second
.
size
()
>
i
,
"Inputs %s should have %llu argument"
,
in
,
i
);
PADDLE_ENFORCE
(
out_it
!=
ctx_
.
outputs
.
end
()
&&
out_it
->
second
.
size
()
>
j
,
"Outputs %s should have %llu argument"
,
out
,
j
);
Variable
*
in_var
=
in_it
->
second
[
i
];
Variable
*
out_var
=
out_it
->
second
[
j
];
Variable
*
in_var
=
scope_
.
FindVar
(
input_n
);
Variable
*
out_var
=
scope_
.
FindVar
(
output_n
);
PADDLE_ENFORCE
(
in_var
->
Type
()
==
out_var
->
Type
(),
"The type of %s and %s is not the same."
,
output_n
,
GetDim
(
input_n
));
"The type of %s and %s is not the same."
,
in
,
out
);
if
(
in_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
&
in_sele_rows
=
in_var
->
Get
<
framework
::
SelectedRows
>
();
...
...
@@ -646,13 +646,16 @@ class RuntimeInferShapeContext : public InferShapeContext {
void
ShareLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
const
override
{
const
std
::
vector
<
std
::
string
>&
inputs
=
Inputs
(
in
);
const
std
::
vector
<
std
::
string
>&
outputs
=
Outputs
(
out
);
PADDLE_ENFORCE_LT
(
i
,
inputs
.
size
());
PADDLE_ENFORCE_LT
(
j
,
outputs
.
size
());
Variable
*
in_var
=
scope_
.
FindVar
(
inputs
.
at
(
i
));
auto
in_it
=
ctx_
.
inputs
.
find
(
in
);
auto
out_it
=
ctx_
.
outputs
.
find
(
out
);
PADDLE_ENFORCE
(
in_it
!=
ctx_
.
inputs
.
end
()
&&
in_it
->
second
.
size
()
>
i
,
"Inputs %s should have %llu argument"
,
in
,
i
);
PADDLE_ENFORCE
(
out_it
!=
ctx_
.
outputs
.
end
()
&&
out_it
->
second
.
size
()
>
j
,
"Outputs %s should have %llu argument"
,
out
,
j
);
Variable
*
in_var
=
in_it
->
second
.
at
(
i
);
if
(
!
in_var
->
IsType
<
LoDTensor
>
())
return
;
Variable
*
out_var
=
scope_
.
FindVar
(
outputs
.
at
(
j
)
);
Variable
*
out_var
=
out_it
->
second
.
at
(
j
);
PADDLE_ENFORCE
(
out_var
->
IsType
<
LoDTensor
>
(),
"The %d-th output of Output(%s) must be LoDTensor."
,
j
,
out
);
auto
in_tensor
=
in_var
->
Get
<
LoDTensor
>
();
...
...
@@ -687,9 +690,64 @@ class RuntimeInferShapeContext : public InferShapeContext {
bool
IsRuntime
()
const
override
{
return
true
;
}
// TODO(paddle-dev): Can this be template?
std
::
vector
<
InferShapeVarPtr
>
GetInputVarPtrs
(
const
std
::
string
&
name
)
override
{
const
std
::
vector
<
Variable
*>&
vars
=
InputVars
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
vars
.
size
());
res
.
insert
(
res
.
begin
(),
vars
.
begin
(),
vars
.
end
());
return
res
;
}
std
::
vector
<
InferShapeVarPtr
>
GetOutputVarPtrs
(
const
std
::
string
&
name
)
override
{
const
std
::
vector
<
Variable
*>&
vars
=
OutputVars
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
vars
.
size
());
res
.
insert
(
res
.
begin
(),
vars
.
begin
(),
vars
.
end
());
return
res
;
}
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
Variable
*>&
vars
=
InputVars
(
name
);
PADDLE_ENFORCE_EQ
(
vars
.
size
(),
1UL
,
"Input(%s) should hold one element, but now it holds %d"
,
name
,
vars
.
size
());
return
this
->
GetDim
(
vars
[
0
]);
}
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
Variable
*>&
vars
=
InputVars
(
name
);
return
GetDims
(
vars
);
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetInputsVarType
(
const
std
::
string
&
name
)
const
override
{
return
GetVarTypes
(
InputVars
(
name
));
}
std
::
vector
<
proto
::
VarType
::
Type
>
GetOutputsVarType
(
const
std
::
string
&
name
)
const
override
{
return
GetVarTypes
(
OutputVars
(
name
));
}
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
auto
&
vars
=
OutputVars
(
name
);
PADDLE_ENFORCE_EQ
(
vars
.
size
(),
1UL
,
"Output(%s) should hold one element, but now it holds %d"
,
name
,
vars
.
size
());
SetDim
(
vars
[
0
],
dim
);
}
void
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>&
dims
)
override
{
auto
&
vars
=
OutputVars
(
name
);
SetDims
(
vars
,
dims
);
}
protected:
DDim
GetDim
(
const
std
::
string
&
name
)
const
override
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
DDim
GetDim
(
Variable
*
var
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
var
);
if
(
var
->
IsType
<
LoDTensor
>
())
{
return
var
->
Get
<
LoDTensor
>
().
dims
();
...
...
@@ -697,25 +755,44 @@ class RuntimeInferShapeContext : public InferShapeContext {
return
var
->
Get
<
SelectedRows
>
().
GetCompleteDims
();
}
else
{
PADDLE_THROW
(
"Only LoDTensor/SelectedRows support 'GetDim', but Variable
%s'
s "
"Only LoDTensor/SelectedRows support 'GetDim', but Variables "
"type_id is %s."
,
name
,
var
->
Type
().
name
());
var
->
Type
().
name
());
}
}
std
::
vector
<
DDim
>
GetDims
(
const
std
::
vector
<
Variable
*>&
vars
)
const
{
std
::
vector
<
DDim
>
ret
;
ret
.
reserve
(
vars
.
size
());
std
::
transform
(
vars
.
begin
(),
vars
.
end
(),
std
::
back_inserter
(
ret
),
[
this
](
Variable
*
var
)
{
return
this
->
GetDim
(
var
);
});
return
ret
;
}
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
override
{
PADDLE_THROW
(
"Only compile time support this method"
);
}
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
void
SetDim
(
Variable
*
var
,
const
DDim
&
dim
)
{
if
(
var
->
IsType
<
LoDTensor
>
())
{
var
->
GetMutable
<
LoDTensor
>
()
->
Resize
(
dim
);
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
var
->
GetMutable
<
SelectedRows
>
()
->
set_height
(
dim
[
0
]);
}
else
{
PADDLE_THROW
(
"Variable %s type_id %s, expect LoDTensor/SelectedRows."
,
name
,
var
->
Type
().
name
());
PADDLE_THROW
(
"Variable type_id %s, expect LoDTensor/SelectedRows."
,
var
->
Type
().
name
());
}
}
void
SetDims
(
const
std
::
vector
<
Variable
*>&
vars
,
const
std
::
vector
<
DDim
>&
dims
)
{
size_t
length
=
vars
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
dims
.
size
());
for
(
size_t
i
=
0
;
i
<
length
;
++
i
)
{
if
(
vars
[
i
]
==
nullptr
)
{
continue
;
}
SetDim
(
vars
[
i
],
dims
[
i
]);
}
}
...
...
@@ -724,16 +801,36 @@ class RuntimeInferShapeContext : public InferShapeContext {
PADDLE_THROW
(
"Only compile time support this method"
);
}
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
override
{
auto
*
var
=
scope_
.
FindVar
(
name
);
return
ToVarType
(
var
->
Type
());
std
::
vector
<
proto
::
VarType
::
Type
>
GetVarTypes
(
const
std
::
vector
<
Variable
*>&
vars
)
const
{
std
::
vector
<
proto
::
VarType
::
Type
>
retv
;
retv
.
resize
(
vars
.
size
());
std
::
transform
(
vars
.
begin
(),
vars
.
end
(),
retv
.
begin
(),
std
::
bind
(
std
::
mem_fn
(
&
RuntimeInferShapeContext
::
GetVarType
),
this
,
std
::
placeholders
::
_1
));
return
retv
;
}
InferShapeVarPtr
GetVarPtr
(
const
std
::
string
&
name
)
override
{
return
scope_
.
FindVar
(
name
);
proto
::
VarType
::
Type
GetVarType
(
Variable
*
var
)
const
{
return
ToVarType
(
var
->
Type
()
);
}
private:
const
std
::
vector
<
Variable
*>&
InputVars
(
const
std
::
string
&
name
)
const
{
auto
it
=
ctx_
.
inputs
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
ctx_
.
inputs
.
end
(),
"Operator %s does not have the input %s."
,
op_
.
Type
(),
name
);
return
it
->
second
;
}
const
std
::
vector
<
Variable
*>&
OutputVars
(
const
std
::
string
&
name
)
const
{
auto
it
=
ctx_
.
outputs
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
ctx_
.
outputs
.
end
(),
"Operator %s does not have the outputs %s."
,
op_
.
Type
(),
name
);
return
it
->
second
;
}
const
OperatorBase
&
op_
;
const
Scope
&
scope_
;
const
RuntimeContext
&
ctx_
;
...
...
@@ -864,8 +961,7 @@ Scope* OperatorWithKernel::PrepareData(
for
(
size_t
i
=
0
;
i
<
var_name_item
.
second
.
size
();
++
i
)
{
auto
&
var_name
=
var_name_item
.
second
[
i
];
auto
*
var
=
scope
.
FindVar
(
var_name
);
input_vars
[
i
]
=
var
;
auto
*
var
=
input_vars
[
i
];
// Only tensor can be tranfer to another device.
if
(
var
==
nullptr
||
!
VarIsTensor
(
*
var
))
{
...
...
paddle/fluid/framework/shape_inference.cc
浏览文件 @
a872eb90
...
...
@@ -22,20 +22,6 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
DDim
InferShapeContext
::
GetInputDim
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Inputs
(
name
);
PADDLE_ENFORCE_EQ
(
arg_names
.
size
(),
1UL
,
"Input(%s) should hold one element, but now it holds %d"
,
name
,
arg_names
.
size
());
return
this
->
GetDim
(
arg_names
[
0
]);
}
std
::
vector
<
DDim
>
InferShapeContext
::
GetInputsDim
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Inputs
(
name
);
return
GetDims
(
arg_names
);
}
std
::
vector
<
DDim
>
InferShapeContext
::
GetReaderDims
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Inputs
(
name
);
...
...
@@ -46,26 +32,6 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
return
this
->
GetRepeatedDims
(
arg_names
[
0
]);
}
DDim
InferShapeContext
::
GetInputsElementDim
(
const
std
::
string
&
name
,
int
idx
)
const
{
const
std
::
vector
<
std
::
string
>
&
names
=
Inputs
(
name
);
return
this
->
GetDim
(
names
[
idx
]);
}
void
InferShapeContext
::
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
{
auto
&
arg_names
=
Outputs
(
name
);
PADDLE_ENFORCE_EQ
(
arg_names
.
size
(),
1UL
,
"Output(%s) should hold one element, but now it holds %d"
,
name
,
arg_names
.
size
());
SetDim
(
arg_names
[
0
],
dim
);
}
void
InferShapeContext
::
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
{
auto
&
names
=
Outputs
(
name
);
SetDims
(
names
,
dims
);
}
void
InferShapeContext
::
SetReaderDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
{
const
std
::
vector
<
std
::
string
>
&
arg_names
=
Outputs
(
name
);
...
...
@@ -76,69 +42,5 @@ void InferShapeContext::SetReaderDims(const std::string &name,
return
this
->
SetRepeatedDims
(
arg_names
[
0
],
dims
);
}
std
::
vector
<
InferShapeVarPtr
>
InferShapeContext
::
GetInputVarPtrs
(
const
std
::
string
&
name
)
{
const
std
::
vector
<
std
::
string
>
arg_names
=
Inputs
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
arg_names
.
size
());
std
::
transform
(
arg_names
.
begin
(),
arg_names
.
end
(),
std
::
back_inserter
(
res
),
[
this
](
const
std
::
string
&
name
)
{
return
this
->
GetVarPtr
(
name
);
});
return
res
;
}
std
::
vector
<
InferShapeVarPtr
>
InferShapeContext
::
GetOutputVarPtrs
(
const
std
::
string
&
name
)
{
const
std
::
vector
<
std
::
string
>
arg_names
=
Outputs
(
name
);
std
::
vector
<
InferShapeVarPtr
>
res
;
res
.
reserve
(
arg_names
.
size
());
std
::
transform
(
arg_names
.
begin
(),
arg_names
.
end
(),
std
::
back_inserter
(
res
),
[
this
](
const
std
::
string
&
name
)
{
return
this
->
GetVarPtr
(
name
);
});
return
res
;
}
std
::
vector
<
DDim
>
InferShapeContext
::
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
std
::
vector
<
DDim
>
ret
;
ret
.
reserve
(
names
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
std
::
back_inserter
(
ret
),
[
this
](
const
std
::
string
&
name
)
{
return
this
->
GetDim
(
name
);
});
return
ret
;
}
void
InferShapeContext
::
SetDims
(
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
DDim
>
&
dims
)
{
size_t
length
=
names
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
dims
.
size
());
for
(
size_t
i
=
0
;
i
<
length
;
++
i
)
{
if
(
names
[
i
]
==
framework
::
kEmptyVarName
)
{
continue
;
}
SetDim
(
names
[
i
],
dims
[
i
]);
}
}
std
::
vector
<
proto
::
VarType
::
Type
>
InferShapeContext
::
GetInputsVarType
(
const
std
::
string
&
name
)
const
{
return
GetVarTypes
(
Inputs
(
name
));
}
std
::
vector
<
proto
::
VarType
::
Type
>
InferShapeContext
::
GetOutputsVarType
(
const
std
::
string
&
name
)
const
{
return
GetVarTypes
(
Outputs
(
name
));
}
std
::
vector
<
proto
::
VarType
::
Type
>
InferShapeContext
::
GetVarTypes
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
std
::
vector
<
proto
::
VarType
::
Type
>
retv
;
retv
.
resize
(
names
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
retv
.
begin
(),
std
::
bind
(
std
::
mem_fn
(
&
InferShapeContext
::
GetVarType
),
this
,
std
::
placeholders
::
_1
));
return
retv
;
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/shape_inference.h
浏览文件 @
a872eb90
...
...
@@ -33,22 +33,23 @@ class InferShapeContext {
virtual
bool
HasInput
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutput
(
const
std
::
string
&
name
)
const
=
0
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetInputsVarType
(
const
std
::
string
&
name
)
const
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetOutputsVarType
(
const
std
::
string
&
name
)
const
;
virtual
std
::
vector
<
proto
::
VarType
::
Type
>
GetInputsVarType
(
const
std
::
string
&
name
)
const
=
0
;
virtual
std
::
vector
<
proto
::
VarType
::
Type
>
GetOutputsVarType
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasInputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutputs
(
const
std
::
string
&
name
)
const
=
0
;
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
;
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
;
std
::
vector
<
DDim
>
GetReaderDims
(
const
std
::
string
&
name
)
const
;
DDim
GetInputsElementDim
(
const
std
::
string
&
name
,
int
idx
)
const
;
virtual
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
=
0
;
virtual
std
::
vector
<
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
=
0
;
virtual
std
::
vector
<
DDim
>
GetReaderDims
(
const
std
::
string
&
name
)
const
;
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
);
void
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
);
void
SetReaderDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
);
virtual
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
=
0
;
virtual
void
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
=
0
;
virtual
void
SetReaderDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
);
virtual
AttrReader
Attrs
()
const
=
0
;
virtual
const
std
::
vector
<
std
::
string
>
&
Inputs
(
...
...
@@ -67,27 +68,15 @@ class InferShapeContext {
virtual
bool
IsRuntime
()
const
=
0
;
std
::
vector
<
InferShapeVarPtr
>
GetInputVarPtrs
(
const
std
::
string
&
name
);
std
::
vector
<
InferShapeVarPtr
>
GetOutputVarPtrs
(
const
std
::
string
&
name
);
virtual
InferShapeVarPtr
GetVarPtr
(
const
std
::
string
&
name
)
=
0
;
// Note: In while op, we need this to be public
void
SetDims
(
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
DDim
>
&
dims
);
virtual
std
::
vector
<
InferShapeVarPtr
>
GetInputVarPtrs
(
const
std
::
string
&
name
)
=
0
;
virtual
std
::
vector
<
InferShapeVarPtr
>
GetOutputVarPtrs
(
const
std
::
string
&
name
)
=
0
;
protected:
virtual
DDim
GetDim
(
const
std
::
string
&
name
)
const
=
0
;
virtual
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
=
0
;
virtual
std
::
vector
<
DDim
>
GetRepeatedDims
(
const
std
::
string
&
name
)
const
=
0
;
virtual
void
SetRepeatedDims
(
const
std
::
string
&
name
,
const
std
::
vector
<
DDim
>
&
dims
)
=
0
;
std
::
vector
<
DDim
>
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
;
std
::
vector
<
proto
::
VarType
::
Type
>
GetVarTypes
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
;
virtual
proto
::
VarType
::
Type
GetVarType
(
const
std
::
string
&
name
)
const
=
0
;
};
}
// namespace framework
...
...
paddle/fluid/operators/controlflow/while_op.cc
浏览文件 @
a872eb90
...
...
@@ -399,26 +399,41 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
ctx
->
HasInputs
(
kOutputs
);
ctx
->
HasInputs
(
framework
::
GradVarName
(
kOutputs
));
auto
p_names
=
ctx
->
Inputs
(
kX
);
auto
pg_ig_names
=
ctx
->
Outputs
(
kXGRAD
);
auto
var_types
=
ctx
->
GetInputsVarType
(
kX
);
std
::
vector
<
std
::
string
>
names_to_set
;
std
::
vector
<
framework
::
DDim
>
dims_to_set
;
for
(
size_t
i
=
0
;
i
<
p_names
.
size
();
++
i
)
{
std
::
vector
<
framework
::
InferShapeVarPtr
>
in_var_ptrs
=
ctx
->
GetInputVarPtrs
(
kX
);
std
::
vector
<
framework
::
InferShapeVarPtr
>
out_var_ptrs
=
ctx
->
GetOutputVarPtrs
(
kXGRAD
);
PADDLE_ENFORCE
(
in_var_ptrs
.
size
()
==
out_var_ptrs
.
size
());
for
(
size_t
i
=
0
;
i
<
in_var_ptrs
.
size
();
++
i
)
{
if
(
pg_ig_names
[
i
]
==
framework
::
kEmptyVarName
)
{
continue
;
}
auto
dims
=
ctx
->
GetInputsElementDim
(
kX
,
i
);
if
(
var_types
[
i
]
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
names_to_set
.
push_back
(
pg_ig_names
[
i
]);
dims_to_set
.
push_back
(
dims
);
}
else
if
(
var_types
[
i
]
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
// not sure how to set the dim of LOD_TENSOR_ARRAY
names_to_set
.
push_back
(
pg_ig_names
[
i
]);
dims_to_set
.
push_back
(
dims
);
if
(
ctx
->
IsRuntime
())
{
framework
::
Variable
*
in_var
=
boost
::
get
<
framework
::
Variable
*>
(
in_var_ptrs
[
i
]);
framework
::
Variable
*
out_var
=
boost
::
get
<
framework
::
Variable
*>
(
out_var_ptrs
[
i
]);
auto
type
=
framework
::
ToVarType
(
in_var
->
Type
());
if
(
type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
out_var
->
GetMutable
<
LoDTensor
>
()
->
Resize
(
in_var
->
Get
<
framework
::
LoDTensor
>
().
dims
());
}
else
if
(
type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
out_var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
set_height
(
in_var
->
Get
<
framework
::
SelectedRows
>
().
GetCompleteDims
()[
0
]);
}
else
if
(
type
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
PADDLE_THROW
(
"WhileGradOp doesn't support type %d"
,
static_cast
<
int
>
(
type
));
}
}
else
{
framework
::
VarDesc
*
in_var
=
boost
::
get
<
framework
::
VarDesc
*>
(
in_var_ptrs
[
i
]);
boost
::
get
<
framework
::
VarDesc
*>
(
out_var_ptrs
[
i
])
->
SetShape
(
in_var
->
GetShape
());
}
}
ctx
->
SetDims
(
names_to_set
,
dims_to_set
);
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录