Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
4130e5fa
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4130e5fa
编写于
10月 14, 2017
作者:
Q
qijun
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'baidu/develop' into add_selected_rows_functor
上级
f59a7c1d
dbb60572
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
160 addition
and
44 deletion
+160
-44
paddle/framework/backward.cc
paddle/framework/backward.cc
+26
-20
paddle/framework/block_desc.cc
paddle/framework/block_desc.cc
+6
-1
paddle/framework/block_desc.h
paddle/framework/block_desc.h
+4
-3
paddle/framework/details/op_registry.h
paddle/framework/details/op_registry.h
+3
-2
paddle/framework/grad_op_desc_maker.h
paddle/framework/grad_op_desc_maker.h
+15
-8
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+3
-3
paddle/framework/op_desc.h
paddle/framework/op_desc.h
+2
-2
paddle/framework/program_desc.cc
paddle/framework/program_desc.cc
+1
-1
paddle/pybind/protobuf.cc
paddle/pybind/protobuf.cc
+45
-4
python/paddle/v2/framework/framework.py
python/paddle/v2/framework/framework.py
+52
-0
python/paddle/v2/framework/tests/test_operator_desc.py
python/paddle/v2/framework/tests/test_operator_desc.py
+2
-0
python/paddle/v2/framework/tests/test_variable.py
python/paddle/v2/framework/tests/test_variable.py
+1
-0
未找到文件。
paddle/framework/backward.cc
浏览文件 @
4130e5fa
...
...
@@ -28,15 +28,15 @@ namespace paddle {
namespace
framework
{
static
inline
std
::
unique_ptr
<
OperatorBase
>
CreateGradOp
(
const
OperatorBase
&
op
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
)
{
const
OperatorBase
&
op
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
)
{
OpDescBind
op_desc
;
op_desc
.
SetInputMap
(
op
.
Inputs
());
op_desc
.
SetOutputMap
(
op
.
Outputs
());
op_desc
.
SetType
(
op
.
Type
());
op_desc
.
SetAttrMap
(
op
.
Attrs
());
auto
&
info
=
OpInfoMap
::
Instance
().
Get
(
op
.
Type
());
auto
grad_descs
=
info
.
GradOpMaker
()(
op_desc
,
no_grad_set
);
auto
grad_descs
=
info
.
GradOpMaker
()(
op_desc
,
no_grad_set
,
grad_to_var
);
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>
grad_ops
;
grad_ops
.
reserve
(
grad_descs
.
size
());
std
::
transform
(
grad_descs
.
begin
(),
grad_descs
.
end
(),
...
...
@@ -99,7 +99,9 @@ static std::unique_ptr<OperatorBase> NOP() {
// See Backward.h for details
static
std
::
unique_ptr
<
OperatorBase
>
BackwardRecursive
(
const
OperatorBase
&
forwardOp
,
std
::
unordered_set
<
std
::
string
>&
no_grad_names
,
size_t
&
uniq_id
)
{
std
::
unordered_set
<
std
::
string
>&
no_grad_names
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
,
size_t
&
uniq_id
)
{
// If all input gradients of forwarding operator do not need to calculate,
// just return an NOP. Not return null ptr because NOP does not take
// too much time for calculation, but it is useful for simplifying logic.
...
...
@@ -137,7 +139,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
for
(
auto
it
=
forwardNet
.
ops_
.
rbegin
();
it
!=
forwardNet
.
ops_
.
rend
();
++
it
,
++
local_op_id
)
{
auto
&
fwd
=
*
it
;
auto
bwd
=
BackwardRecursive
(
*
fwd
,
no_grad_names
,
uniq_id
);
auto
bwd
=
BackwardRecursive
(
*
fwd
,
no_grad_names
,
grad_to_var
,
uniq_id
);
ForEachVarName
(
bwd
->
Outputs
(),
[
&
dup_output_ops
,
local_op_id
](
const
std
::
string
&
out
)
{
dup_output_ops
[
out
].
emplace_back
(
local_op_id
);
...
...
@@ -189,7 +191,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
}
}
else
{
std
::
unique_ptr
<
OperatorBase
>
grad_op
(
CreateGradOp
(
forwardOp
,
no_grad_names
));
CreateGradOp
(
forwardOp
,
no_grad_names
,
grad_to_var
));
ForEachVarName
(
grad_op
->
Inputs
(),
[
&
no_grad_names
,
&
net
,
&
grad_op
](
const
std
::
string
&
grad_input
)
{
...
...
@@ -228,7 +230,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
*
static_cast
<
const
OperatorBase
*>
(
&
rnnop
.
stepnet
());
// create stepnet's gradient op
rnn_grad_op
->
set_stepnet
(
BackwardRecursive
(
stepnet_op
,
no_grad_names
,
uniq_id
));
BackwardRecursive
(
stepnet_op
,
no_grad_names
,
grad_to_var
,
uniq_id
));
}
if
(
net
->
ops_
.
empty
())
{
// Current no aux op is added to network
...
...
@@ -255,7 +257,8 @@ std::unique_ptr<OperatorBase> Backward(
no_grad_names
.
insert
(
name
+
kGradVarSuffix
);
}
size_t
uid
=
0
;
return
BackwardRecursive
(
forwardOp
,
no_grad_names
,
uid
);
std
::
unordered_map
<
std
::
string
,
std
::
string
>
grad_to_var
;
return
BackwardRecursive
(
forwardOp
,
no_grad_names
,
&
grad_to_var
,
uid
);
}
// ==================================== //
...
...
@@ -272,30 +275,31 @@ static bool AllGradInSet(const std::vector<std::string>& names,
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
MakeOpGrad
(
const
std
::
unique_ptr
<
OpDescBind
>&
op_desc
,
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
)
{
std
::
unordered_set
<
std
::
string
>*
no_grad_vars
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
)
{
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
grad_op_descs
;
// All input gradients of forwarding operator do not need to calculate.
const
std
::
vector
<
std
::
string
>&
inputs
=
op_desc
->
InputArgumentNames
();
if
(
AllGradInSet
(
inputs
,
no_grad_vars
))
{
if
(
AllGradInSet
(
inputs
,
*
no_grad_vars
))
{
return
grad_op_descs
;
// empty vector
}
// All output gradients of forwarding operator do not need to calculate.
const
std
::
vector
<
std
::
string
>&
outputs
=
op_desc
->
OutputArgumentNames
();
if
(
AllGradInSet
(
outputs
,
no_grad_vars
))
{
if
(
AllGradInSet
(
outputs
,
*
no_grad_vars
))
{
for
(
const
std
::
string
&
name
:
inputs
)
{
no_grad_vars
.
insert
(
GradVarName
(
name
));
no_grad_vars
->
insert
(
GradVarName
(
name
));
}
return
grad_op_descs
;
// empty vector
}
grad_op_descs
=
OpInfoMap
::
Instance
()
.
Get
(
op_desc
->
Type
())
.
GradOpMaker
()(
*
op_desc
,
no_grad_vars
);
.
GradOpMaker
()(
*
op_desc
,
*
no_grad_vars
,
grad_to_var
);
std
::
list
<
std
::
unique_ptr
<
OpDescBind
>>
pending_fill_zeros_ops
;
for
(
auto
&
desc
:
grad_op_descs
)
{
for
(
const
std
::
string
&
in_name
:
desc
->
InputArgumentNames
())
{
if
(
no_grad_vars
.
count
(
in_name
))
{
if
(
no_grad_vars
->
count
(
in_name
))
{
std
::
string
prefix
=
in_name
.
substr
(
0
,
in_name
.
size
()
-
sizeof
(
kGradVarSuffix
)
/
sizeof
(
char
)
+
1
);
std
::
string
new_name
=
prefix
+
kZeroVarSuffix
;
...
...
@@ -315,7 +319,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
MakeBlockBackward
(
ProgramDescBind
&
program_desc
,
int
block_idx
,
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
)
{
std
::
unordered_set
<
std
::
string
>*
no_grad_vars
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
)
{
BlockDescBind
*
cur_block
=
program_desc
.
Block
(
block_idx
);
std
::
deque
<
std
::
unique_ptr
<
OpDescBind
>>&
op_descs
=
cur_block
->
ops_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
size_t
>>
dup_out_ops
;
...
...
@@ -323,15 +328,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
backward_descs
;
for
(
auto
it
=
op_descs
.
rbegin
();
it
!=
op_descs
.
rend
();
++
it
)
{
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
op_grads
=
MakeOpGrad
(
*
it
,
no_grad_vars
);
MakeOpGrad
(
*
it
,
no_grad_vars
,
grad_to_var
);
if
((
*
it
)
->
Type
()
==
"recurrent"
)
{
PADDLE_ENFORCE_EQ
(
op_grads
.
size
(),
size_t
(
1
),
"rnn_op's gradient process should contain only one op."
);
int
step_block_idx
=
(
*
it
)
->
GetBlockAttr
(
"stop_block"
);
auto
backward_block_op_descs
=
MakeBlockBackward
(
program_desc
,
step_block_idx
,
no_grad_vars
);
auto
backward_block_op_descs
=
MakeBlockBackward
(
program_desc
,
step_block_idx
,
no_grad_vars
,
grad_to_var
);
BlockDescBind
*
backward_block
=
program_desc
.
AppendBlock
(
*
cur_block
);
for
(
auto
&
ptr
:
backward_block_op_descs
)
{
backward_block
->
ops_
.
push_back
(
std
::
move
(
ptr
));
...
...
@@ -387,8 +392,9 @@ void AppendBackward(ProgramDescBind& program_desc,
no_grad_var_names
.
insert
(
GradVarName
(
name
));
}
const
int
root_block_idx
=
0
;
auto
backward_op_descs
=
MakeBlockBackward
(
program_desc
,
root_block_idx
,
no_grad_var_names
);
std
::
unordered_map
<
std
::
string
,
std
::
string
>
grad_to_var
;
auto
backward_op_descs
=
MakeBlockBackward
(
program_desc
,
root_block_idx
,
&
no_grad_var_names
,
&
grad_to_var
);
auto
&
forw_op_descs
=
program_desc
.
Block
(
root_block_idx
)
->
ops_
;
for
(
auto
&
ptr
:
backward_op_descs
)
{
forw_op_descs
.
push_back
(
std
::
move
(
ptr
));
...
...
paddle/framework/block_desc.cc
浏览文件 @
4130e5fa
...
...
@@ -66,7 +66,7 @@ std::vector<OpDescBind *> BlockDescBind::AllOps() const {
return
res
;
}
void
BlockDescBind
::
Sync
()
{
void
BlockDescBind
::
Flush
()
{
if
(
need_update_
)
{
auto
&
op_field
=
*
this
->
desc_
->
mutable_ops
();
op_field
.
Clear
();
...
...
@@ -91,5 +91,10 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
return
prog_
->
Block
(
static_cast
<
size_t
>
(
this
->
desc_
->
parent_idx
()));
}
BlockDesc
*
BlockDescBind
::
Proto
()
{
Flush
();
return
desc_
;
}
}
// namespace framework
}
// namespace paddle
paddle/framework/block_desc.h
浏览文件 @
4130e5fa
...
...
@@ -35,7 +35,8 @@ class BlockDescBind {
public:
friend
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
MakeBlockBackward
(
ProgramDescBind
&
program_desc
,
int
block_idx
,
std
::
unordered_set
<
std
::
string
>
&
no_grad_vars
);
std
::
unordered_set
<
std
::
string
>
*
no_grad_vars
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>
*
grad_to_var
);
friend
void
AppendBackward
(
ProgramDescBind
&
program_desc
,
...
...
@@ -64,9 +65,9 @@ class BlockDescBind {
std
::
vector
<
OpDescBind
*>
AllOps
()
const
;
void
Sync
();
void
Flush
();
BlockDesc
*
RawPtr
()
{
return
desc_
;
}
BlockDesc
*
Proto
();
private:
ProgramDescBind
*
prog_
;
// not_own
...
...
paddle/framework/details/op_registry.h
浏览文件 @
4130e5fa
...
...
@@ -99,8 +99,9 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
void
operator
()(
const
char
*
op_type
,
OpInfo
*
info
)
const
{
info
->
grad_op_maker_
=
[](
const
OpDescBind
&
fwd_op
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
)
{
T
maker
(
fwd_op
,
no_grad_set
);
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
)
{
T
maker
(
fwd_op
,
no_grad_set
,
grad_to_var
);
return
maker
();
};
}
...
...
paddle/framework/grad_op_desc_maker.h
浏览文件 @
4130e5fa
...
...
@@ -25,8 +25,9 @@ class GradOpDescMakerBase {
public:
explicit
GradOpDescMakerBase
(
const
OpDescBind
&
fwd_op
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
)
:
fwd_op_
(
fwd_op
),
no_grad_set_
(
no_grad_set
)
{}
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
)
:
fwd_op_
(
fwd_op
),
no_grad_set_
(
no_grad_set
),
grad_to_var_
(
grad_to_var
)
{}
virtual
~
GradOpDescMakerBase
()
=
default
;
virtual
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
operator
()()
const
=
0
;
...
...
@@ -37,12 +38,17 @@ class GradOpDescMakerBase {
std
::
vector
<
std
::
string
>
ret_val
;
auto
var_names
=
this
->
Input
(
name
);
ret_val
.
reserve
(
var_names
.
size
());
std
::
transform
(
var_names
.
begin
(),
var_names
.
end
(),
std
::
back_inserter
(
ret_val
),
[
this
](
const
std
::
string
&
fwd_var_name
)
->
std
::
string
{
auto
g_name
=
GradVarName
(
fwd_var_name
);
return
no_grad_set_
.
count
(
g_name
)
==
0
?
g_name
:
kEmptyVarName
;
});
std
::
transform
(
var_names
.
begin
(),
var_names
.
end
(),
std
::
back_inserter
(
ret_val
),
[
this
](
const
std
::
string
&
fwd_var_name
)
->
std
::
string
{
auto
g_name
=
GradVarName
(
fwd_var_name
);
if
(
no_grad_set_
.
count
(
g_name
))
{
return
kEmptyVarName
;
}
else
{
(
*
this
->
grad_to_var_
)[
g_name
]
=
fwd_var_name
;
return
g_name
;
}
});
if
(
!
drop_empty_grad
)
{
return
ret_val
;
}
...
...
@@ -95,6 +101,7 @@ class GradOpDescMakerBase {
private:
const
OpDescBind
&
fwd_op_
;
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set_
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var_
;
};
class
SingleGradOpDescMaker
:
public
GradOpDescMakerBase
{
...
...
paddle/framework/op_desc.cc
浏览文件 @
4130e5fa
...
...
@@ -32,7 +32,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
}
OpDesc
*
OpDescBind
::
Proto
()
{
Sync
();
Flush
();
return
&
op_desc_
;
}
...
...
@@ -101,7 +101,7 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
}
void
OpDescBind
::
SetBlockAttr
(
const
std
::
string
&
name
,
BlockDescBind
&
block
)
{
BlockDesc
*
desc
=
block
.
RawPtr
();
BlockDesc
*
desc
=
block
.
Proto
();
this
->
attrs_
[
name
]
=
desc
;
need_update_
=
true
;
}
...
...
@@ -165,7 +165,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void
operator
()(
boost
::
blank
)
const
{
PADDLE_THROW
(
"Unexpected branch"
);
}
};
void
OpDescBind
::
Sync
()
{
void
OpDescBind
::
Flush
()
{
if
(
need_update_
)
{
this
->
op_desc_
.
mutable_inputs
()
->
Clear
();
for
(
auto
&
ipt
:
inputs_
)
{
...
...
paddle/framework/op_desc.h
浏览文件 @
4130e5fa
...
...
@@ -89,8 +89,6 @@ class OpDescBind {
this
->
need_update_
=
true
;
}
void
Sync
();
const
VariableNameMap
&
Inputs
()
const
{
return
inputs_
;
}
const
VariableNameMap
&
Outputs
()
const
{
return
outputs_
;
}
...
...
@@ -104,6 +102,8 @@ class OpDescBind {
void
InferShape
(
const
BlockDescBind
&
block
)
const
;
void
Flush
();
private:
template
<
typename
MapType
>
static
std
::
vector
<
typename
MapType
::
key_type
>
MapKeys
(
const
MapType
&
map
)
{
...
...
paddle/framework/program_desc.cc
浏览文件 @
4130e5fa
...
...
@@ -45,7 +45,7 @@ BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
ProgramDesc
*
ProgramDescBind
::
Proto
()
{
for
(
auto
&
block
:
blocks_
)
{
block
->
Sync
();
block
->
Flush
();
}
return
prog_
;
}
...
...
paddle/pybind/protobuf.cc
浏览文件 @
4130e5fa
...
...
@@ -123,7 +123,18 @@ void BindProgramDesc(py::module &m) {
AppendBackward
(
program_desc
,
no_grad_vars
);
})
.
def
(
"block"
,
&
ProgramDescBind
::
Block
,
py
::
return_value_policy
::
reference
)
.
def
(
"num_blocks"
,
&
ProgramDescBind
::
Size
);
.
def
(
"num_blocks"
,
&
ProgramDescBind
::
Size
)
.
def
(
"serialize_to_string"
,
[](
ProgramDescBind
&
program_desc
)
->
py
::
bytes
{
const
ProgramDesc
*
desc
=
program_desc
.
Proto
();
PADDLE_ENFORCE
(
desc
->
IsInitialized
(),
"ProgramDesc has not been initialized."
);
std
::
string
res
;
PADDLE_ENFORCE
(
desc
->
SerializeToString
(
&
res
),
"Serialize ProgramDesc Error. This could be a bug of Paddle."
);
return
res
;
});
}
void
BindBlockDesc
(
py
::
module
&
m
)
{
...
...
@@ -149,7 +160,17 @@ void BindBlockDesc(py::module &m) {
.
def
(
"all_vars"
,
&
BlockDescBind
::
AllVars
,
py
::
return_value_policy
::
reference
)
.
def
(
"all_ops"
,
&
BlockDescBind
::
AllOps
,
py
::
return_value_policy
::
reference
);
py
::
return_value_policy
::
reference
)
.
def
(
"serialize_to_string"
,
[](
BlockDescBind
&
block_desc
)
->
py
::
bytes
{
const
BlockDesc
*
desc
=
block_desc
.
Proto
();
PADDLE_ENFORCE
(
desc
->
IsInitialized
(),
"BlockDesc has not been initialized."
);
std
::
string
res
;
PADDLE_ENFORCE
(
desc
->
SerializeToString
(
&
res
),
"Serialize BlockDesc Error. This could be a bug of Paddle."
);
return
res
;
});
}
void
BindVarDsec
(
py
::
module
&
m
)
{
...
...
@@ -177,7 +198,17 @@ void BindVarDsec(py::module &m) {
.
def
(
"lod_level"
,
&
VarDescBind
::
GetLodLevel
)
.
def
(
"set_lod_level"
,
&
VarDescBind
::
SetLoDLevel
)
.
def
(
"type"
,
&
VarDescBind
::
GetType
)
.
def
(
"set_type"
,
&
VarDescBind
::
SetType
);
.
def
(
"set_type"
,
&
VarDescBind
::
SetType
)
.
def
(
"serialize_to_string"
,
[](
VarDescBind
&
var_desc
)
->
py
::
bytes
{
const
VarDesc
*
desc
=
var_desc
.
Proto
();
PADDLE_ENFORCE
(
desc
->
IsInitialized
(),
"VarDesc has not been initialized."
);
std
::
string
res
;
PADDLE_ENFORCE
(
desc
->
SerializeToString
(
&
res
),
"Serialize VarDesc Error. This could be a bug of Paddle."
);
return
res
;
});
py
::
enum_
<
VarDesc
::
VarType
>
(
var_desc
,
"VarType"
,
""
)
.
value
(
"LOD_TENSOR"
,
VarDesc
::
LOD_TENSOR
)
...
...
@@ -213,7 +244,17 @@ void BindOpDesc(py::module &m) {
.
def
(
"set_block_attr"
,
&
OpDescBind
::
SetBlockAttr
)
.
def
(
"block_attr"
,
&
OpDescBind
::
GetBlockAttr
)
.
def
(
"check_attrs"
,
&
OpDescBind
::
CheckAttrs
)
.
def
(
"infer_shape"
,
&
OpDescBind
::
InferShape
);
.
def
(
"infer_shape"
,
&
OpDescBind
::
InferShape
)
.
def
(
"serialize_to_string"
,
[](
OpDescBind
&
op_desc
)
->
py
::
bytes
{
const
OpDesc
*
desc
=
op_desc
.
Proto
();
PADDLE_ENFORCE
(
desc
->
IsInitialized
(),
"OpDesc has not been initialized."
);
std
::
string
res
;
PADDLE_ENFORCE
(
desc
->
SerializeToString
(
&
res
),
"Serialize OpDesc Error. This could be a bug of Paddle."
);
return
res
;
});
}
}
// namespace pybind
...
...
python/paddle/v2/framework/framework.py
浏览文件 @
4130e5fa
...
...
@@ -73,6 +73,13 @@ class Variable(object):
self
.
block
.
vars
[
name
]
=
self
self
.
op
=
None
def
__str__
(
self
):
protostr
=
self
.
desc
.
serialize_to_string
()
proto
=
framework_pb2
.
VarDesc
.
FromString
(
str
(
protostr
))
return
proto
.
__str__
()
__repr__
=
__str__
@
property
def
name
(
self
):
return
self
.
desc
.
name
()
...
...
@@ -169,6 +176,18 @@ class Operator(object):
proto
=
OpProtoHolder
.
instance
().
get_op_proto
(
type
)
if
inputs
is
not
None
:
given
=
set
()
need
=
set
()
for
n
in
inputs
:
given
.
add
(
n
)
for
m
in
proto
.
inputs
:
need
.
add
(
m
.
name
)
if
not
given
==
need
:
raise
ValueError
(
"Incorrect setting for input(s) of operator
\"
%s
\"
. Need: [%s] Given: [%s]"
%
(
type
,
", "
.
join
(
str
(
e
)
for
e
in
need
),
", "
.
join
(
str
(
e
)
for
e
in
given
)))
for
in_proto
in
proto
.
inputs
:
in_argus
=
inputs
[
in_proto
.
name
]
if
not
isinstance
(
in_argus
,
list
):
...
...
@@ -183,6 +202,18 @@ class Operator(object):
self
.
desc
.
set_input
(
in_proto
.
name
,
in_argu_names
)
if
outputs
is
not
None
:
given
=
set
()
need
=
set
()
for
n
in
outputs
:
given
.
add
(
n
)
for
m
in
proto
.
outputs
:
need
.
add
(
m
.
name
)
if
not
given
==
need
:
raise
ValueError
(
"Incorrect setting for output(s) of operator
\"
%s
\"
. Need: [%s] Given: [%s]"
%
(
type
,
", "
.
join
(
str
(
e
)
for
e
in
need
),
", "
.
join
(
str
(
e
)
for
e
in
given
)))
for
out_proto
in
proto
.
outputs
:
out_argus
=
outputs
[
out_proto
.
name
]
if
not
isinstance
(
out_argus
,
list
):
...
...
@@ -210,6 +241,13 @@ class Operator(object):
self
.
desc
.
check_attrs
()
self
.
desc
.
infer_shape
(
self
.
block
.
desc
)
def
__str__
(
self
):
protostr
=
self
.
desc
.
serialize_to_string
()
proto
=
framework_pb2
.
OpDesc
.
FromString
(
str
(
protostr
))
return
proto
.
__str__
()
__repr__
=
__str__
@
property
def
type
(
self
):
return
self
.
desc
.
type
()
...
...
@@ -252,6 +290,13 @@ class Block(object):
self
.
ops
=
collections
.
deque
()
# operator list
self
.
program
=
program
def
__str__
(
self
):
protostr
=
self
.
desc
.
serialize_to_string
()
proto
=
framework_pb2
.
BlockDesc
.
FromString
(
str
(
protostr
))
return
proto
.
__str__
()
__repr__
=
__str__
@
property
def
parent_idx
(
self
):
return
self
.
desc
.
parent
...
...
@@ -296,6 +341,13 @@ class Program(object):
self
.
blocks
=
[
Block
(
self
,
0
)]
self
.
current_block_idx
=
0
def
__str__
(
self
):
protostr
=
self
.
desc
.
serialize_to_string
()
proto
=
framework_pb2
.
ProgramDesc
.
FromString
(
str
(
protostr
))
return
proto
.
__str__
()
__repr__
=
__str__
def
global_block
(
self
):
return
self
.
blocks
[
0
]
...
...
python/paddle/v2/framework/tests/test_operator_desc.py
浏览文件 @
4130e5fa
...
...
@@ -34,6 +34,8 @@ class TestOperator(unittest.TestCase):
"Y"
:
mul_y
},
outputs
=
{
"Out"
:
[
mul_out
]},
attrs
=
{
"x_num_col_dims"
:
1
})
self
.
assertNotEqual
(
str
(
mul_op
),
""
)
self
.
assertEqual
(
mul_op
.
type
,
"mul"
)
self
.
assertEqual
(
mul_op
.
input_names
,
[
"X"
,
"Y"
])
self
.
assertEqual
(
mul_op
.
input
(
"X"
),
[
"mul.x"
])
...
...
python/paddle/v2/framework/tests/test_variable.py
浏览文件 @
4130e5fa
...
...
@@ -21,6 +21,7 @@ class TestVariable(unittest.TestCase):
b
=
g_program
.
current_block
()
w
=
b
.
create_var
(
dtype
=
"float64"
,
shape
=
[
784
,
100
],
lod_level
=
0
,
name
=
"fc.w"
)
self
.
assertNotEqual
(
str
(
w
),
""
)
self
.
assertEqual
(
core
.
DataType
.
FP64
,
w
.
data_type
)
self
.
assertEqual
((
784
,
100
),
w
.
shape
)
self
.
assertEqual
(
"fc.w"
,
w
.
name
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录