Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
09189732
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
699
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
09189732
编写于
12月 21, 2017
作者:
Y
Yu Yang
提交者:
GitHub
12月 21, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename XXDescBind --> XXDesc (#6797)
* Rename XXDescBind --> XXDesc * Fix Compile
上级
0295b000
变更
48
显示空白变更内容
内联
并排
Showing
48 changed file
with
447 addition
and
472 deletion
+447
-472
paddle/framework/backward.cc
paddle/framework/backward.cc
+42
-46
paddle/framework/backward.h
paddle/framework/backward.h
+1
-1
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+63
-63
paddle/framework/block_desc.cc
paddle/framework/block_desc.cc
+29
-30
paddle/framework/block_desc.h
paddle/framework/block_desc.h
+21
-22
paddle/framework/details/op_registry.h
paddle/framework/details/op_registry.h
+3
-3
paddle/framework/executor.cc
paddle/framework/executor.cc
+1
-1
paddle/framework/executor.h
paddle/framework/executor.h
+1
-1
paddle/framework/grad_op_desc_maker.h
paddle/framework/grad_op_desc_maker.h
+12
-14
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+36
-42
paddle/framework/op_desc.h
paddle/framework/op_desc.h
+10
-10
paddle/framework/op_registry.cc
paddle/framework/op_registry.cc
+2
-2
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+1
-1
paddle/framework/program_desc.cc
paddle/framework/program_desc.cc
+11
-11
paddle/framework/program_desc.h
paddle/framework/program_desc.h
+10
-10
paddle/framework/program_desc_test.cc
paddle/framework/program_desc_test.cc
+6
-6
paddle/framework/prune_test.cc
paddle/framework/prune_test.cc
+11
-11
paddle/framework/type_defs.h
paddle/framework/type_defs.h
+8
-10
paddle/framework/var_desc.cc
paddle/framework/var_desc.cc
+10
-12
paddle/framework/var_desc.h
paddle/framework/var_desc.h
+3
-3
paddle/framework/var_type_inference.h
paddle/framework/var_type_inference.h
+1
-2
paddle/framework/var_type_inference_test.cc
paddle/framework/var_type_inference_test.cc
+3
-4
paddle/operators/array_to_lod_tensor_op.cc
paddle/operators/array_to_lod_tensor_op.cc
+3
-3
paddle/operators/assign_op.cc
paddle/operators/assign_op.cc
+3
-3
paddle/operators/beam_search_decode_op.cc
paddle/operators/beam_search_decode_op.cc
+2
-2
paddle/operators/cast_op.cc
paddle/operators/cast_op.cc
+3
-3
paddle/operators/conditional_block_op.cc
paddle/operators/conditional_block_op.cc
+6
-6
paddle/operators/increment_op.cc
paddle/operators/increment_op.cc
+3
-3
paddle/operators/lod_rank_table_op.cc
paddle/operators/lod_rank_table_op.cc
+2
-2
paddle/operators/lod_tensor_to_array_op.cc
paddle/operators/lod_tensor_to_array_op.cc
+5
-5
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+2
-2
paddle/operators/mean_op.cc
paddle/operators/mean_op.cc
+3
-3
paddle/operators/merge_lod_tensor_op.cc
paddle/operators/merge_lod_tensor_op.cc
+3
-3
paddle/operators/minus_op.cc
paddle/operators/minus_op.cc
+4
-5
paddle/operators/nccl_op_test.cu.cc
paddle/operators/nccl_op_test.cu.cc
+7
-8
paddle/operators/pad_op.cc
paddle/operators/pad_op.cc
+3
-3
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+6
-7
paddle/operators/scale_op.cc
paddle/operators/scale_op.cc
+3
-3
paddle/operators/shrink_rnn_memory_op.cc
paddle/operators/shrink_rnn_memory_op.cc
+3
-3
paddle/operators/sign_op.cc
paddle/operators/sign_op.cc
+3
-3
paddle/operators/softmax_with_cross_entropy_op.cc
paddle/operators/softmax_with_cross_entropy_op.cc
+3
-3
paddle/operators/split_lod_tensor_op.cc
paddle/operators/split_lod_tensor_op.cc
+3
-3
paddle/operators/split_op.cc
paddle/operators/split_op.cc
+3
-3
paddle/operators/sum_op.cc
paddle/operators/sum_op.cc
+6
-7
paddle/operators/tensor_array_read_write_op.cc
paddle/operators/tensor_array_read_write_op.cc
+8
-8
paddle/operators/while_op.cc
paddle/operators/while_op.cc
+9
-9
paddle/pybind/protobuf.cc
paddle/pybind/protobuf.cc
+56
-57
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+10
-10
未找到文件。
paddle/framework/backward.cc
浏览文件 @
09189732
...
@@ -42,7 +42,7 @@ static std::unordered_set<std::string>& CtrlFlowOps() {
...
@@ -42,7 +42,7 @@ static std::unordered_set<std::string>& CtrlFlowOps() {
static
inline
std
::
unique_ptr
<
OperatorBase
>
CreateGradOp
(
static
inline
std
::
unique_ptr
<
OperatorBase
>
CreateGradOp
(
const
OperatorBase
&
op
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
,
const
OperatorBase
&
op
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
)
{
OpDesc
Bind
op_desc
;
OpDesc
op_desc
;
op_desc
.
SetInputMap
(
op
.
Inputs
());
op_desc
.
SetInputMap
(
op
.
Inputs
());
op_desc
.
SetOutputMap
(
op
.
Outputs
());
op_desc
.
SetOutputMap
(
op
.
Outputs
());
op_desc
.
SetType
(
op
.
Type
());
op_desc
.
SetType
(
op
.
Type
());
...
@@ -53,7 +53,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
...
@@ -53,7 +53,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
grad_ops
.
reserve
(
grad_descs
.
size
());
grad_ops
.
reserve
(
grad_descs
.
size
());
std
::
transform
(
grad_descs
.
begin
(),
grad_descs
.
end
(),
std
::
transform
(
grad_descs
.
begin
(),
grad_descs
.
end
(),
std
::
back_inserter
(
grad_ops
),
std
::
back_inserter
(
grad_ops
),
[](
const
std
::
unique_ptr
<
OpDesc
Bind
>&
grad_desc
)
{
[](
const
std
::
unique_ptr
<
OpDesc
>&
grad_desc
)
{
return
OpRegistry
::
CreateOp
(
*
grad_desc
);
return
OpRegistry
::
CreateOp
(
*
grad_desc
);
});
});
PADDLE_ENFORCE
(
!
grad_ops
.
empty
());
PADDLE_ENFORCE
(
!
grad_ops
.
empty
());
...
@@ -296,7 +296,7 @@ static std::string FwdName(const std::string& grad_name) {
...
@@ -296,7 +296,7 @@ static std::string FwdName(const std::string& grad_name) {
static
void
CreateGradVarInBlock
(
static
void
CreateGradVarInBlock
(
size_t
grad_op_start_index
,
size_t
grad_op_start_index
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
param_name_map
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
param_name_map
,
BlockDesc
Bind
*
block_desc
,
BlockDesc
*
block_desc
,
std
::
unordered_map
<
std
::
string
,
GradVarInfo
>*
grad_var_record
)
{
std
::
unordered_map
<
std
::
string
,
GradVarInfo
>*
grad_var_record
)
{
auto
ops
=
block_desc
->
AllOps
();
auto
ops
=
block_desc
->
AllOps
();
for
(
size_t
op_index
=
grad_op_start_index
;
op_index
<
ops
.
size
();
for
(
size_t
op_index
=
grad_op_start_index
;
op_index
<
ops
.
size
();
...
@@ -350,12 +350,11 @@ static void CreateGradVarInBlock(
...
@@ -350,12 +350,11 @@ static void CreateGradVarInBlock(
}
}
}
}
std
::
vector
<
std
::
unique_ptr
<
OpDesc
Bind
>>
MakeOpGrad
(
std
::
vector
<
std
::
unique_ptr
<
OpDesc
>>
MakeOpGrad
(
const
OpDesc
Bind
*
op_desc
,
std
::
unordered_set
<
std
::
string
>*
no_grad_vars
,
const
OpDesc
*
op_desc
,
std
::
unordered_set
<
std
::
string
>*
no_grad_vars
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
,
const
std
::
vector
<
BlockDescBind
*>&
grad_block
=
const
std
::
vector
<
BlockDesc
*>&
grad_block
=
std
::
vector
<
BlockDesc
*>
())
{
std
::
vector
<
BlockDescBind
*>
())
{
std
::
vector
<
std
::
unique_ptr
<
OpDesc
>>
grad_op_descs
;
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
grad_op_descs
;
// All input gradients of forwarding operator do not need to calculate.
// All input gradients of forwarding operator do not need to calculate.
const
std
::
vector
<
std
::
string
>&
inputs
=
op_desc
->
InputArgumentNames
();
const
std
::
vector
<
std
::
string
>&
inputs
=
op_desc
->
InputArgumentNames
();
if
(
AllGradInSet
(
inputs
,
*
no_grad_vars
))
{
if
(
AllGradInSet
(
inputs
,
*
no_grad_vars
))
{
...
@@ -386,7 +385,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
...
@@ -386,7 +385,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
.
Get
(
op_desc
->
Type
())
.
Get
(
op_desc
->
Type
())
.
GradOpMaker
()(
*
op_desc
,
*
no_grad_vars
,
grad_to_var
,
grad_block
);
.
GradOpMaker
()(
*
op_desc
,
*
no_grad_vars
,
grad_to_var
,
grad_block
);
std
::
list
<
std
::
unique_ptr
<
OpDesc
Bind
>>
pending_fill_zeros_ops
;
std
::
list
<
std
::
unique_ptr
<
OpDesc
>>
pending_fill_zeros_ops
;
for
(
auto
&
desc
:
grad_op_descs
)
{
for
(
auto
&
desc
:
grad_op_descs
)
{
for
(
const
std
::
string
&
in_name
:
desc
->
InputArgumentNames
())
{
for
(
const
std
::
string
&
in_name
:
desc
->
InputArgumentNames
())
{
if
(
no_grad_vars
->
count
(
in_name
))
{
if
(
no_grad_vars
->
count
(
in_name
))
{
...
@@ -394,8 +393,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
...
@@ -394,8 +393,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
0
,
in_name
.
size
()
-
sizeof
(
kGradVarSuffix
)
/
sizeof
(
char
)
+
1
);
0
,
in_name
.
size
()
-
sizeof
(
kGradVarSuffix
)
/
sizeof
(
char
)
+
1
);
std
::
string
new_name
=
prefix
+
kZeroVarSuffix
;
std
::
string
new_name
=
prefix
+
kZeroVarSuffix
;
desc
->
Rename
(
in_name
,
new_name
);
desc
->
Rename
(
in_name
,
new_name
);
std
::
unique_ptr
<
OpDesc
Bind
>
fill_zeros_op
(
std
::
unique_ptr
<
OpDesc
>
fill_zeros_op
(
new
OpDesc
Bind
(
"fill_zeros_like"
,
{{
"X"
,
{
prefix
}}},
new
OpDesc
(
"fill_zeros_like"
,
{{
"X"
,
{
prefix
}}},
{{
"Y"
,
{
new_name
}}},
AttributeMap
{}));
{{
"Y"
,
{
new_name
}}},
AttributeMap
{}));
pending_fill_zeros_ops
.
push_back
(
std
::
move
(
fill_zeros_op
));
pending_fill_zeros_ops
.
push_back
(
std
::
move
(
fill_zeros_op
));
}
}
...
@@ -408,34 +407,33 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
...
@@ -408,34 +407,33 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return
grad_op_descs
;
return
grad_op_descs
;
}
}
static
BlockDescBind
*
CreateStepBlock
(
static
BlockDesc
*
CreateStepBlock
(
ProgramDescBind
&
program_desc
,
ProgramDesc
&
program_desc
,
std
::
unordered_set
<
std
::
string
>*
no_grad_vars
,
std
::
unordered_set
<
std
::
string
>*
no_grad_vars
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
,
int
step_block_idx
);
int
step_block_idx
);
std
::
vector
<
std
::
unique_ptr
<
OpDesc
Bind
>>
MakeBlockBackward
(
std
::
vector
<
std
::
unique_ptr
<
OpDesc
>>
MakeBlockBackward
(
ProgramDesc
Bind
&
program_desc
,
int
block_idx
,
ProgramDesc
&
program_desc
,
int
block_idx
,
std
::
unordered_set
<
std
::
string
>*
no_grad_vars
,
std
::
unordered_set
<
std
::
string
>*
no_grad_vars
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
)
{
VLOG
(
5
)
<<
"MakeBlockBackward"
;
VLOG
(
5
)
<<
"MakeBlockBackward"
;
BlockDesc
Bind
*
cur_block
=
program_desc
.
MutableBlock
(
block_idx
);
BlockDesc
*
cur_block
=
program_desc
.
MutableBlock
(
block_idx
);
std
::
vector
<
OpDesc
Bind
*>
op_descs
=
cur_block
->
AllOps
();
std
::
vector
<
OpDesc
*>
op_descs
=
cur_block
->
AllOps
();
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
size_t
>>
dup_out_ops
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
size_t
>>
dup_out_ops
;
size_t
grad_desc_idx
=
0
;
size_t
grad_desc_idx
=
0
;
std
::
vector
<
std
::
unique_ptr
<
OpDesc
Bind
>>
backward_descs
;
std
::
vector
<
std
::
unique_ptr
<
OpDesc
>>
backward_descs
;
for
(
auto
it
=
op_descs
.
rbegin
();
it
!=
op_descs
.
rend
();
++
it
)
{
for
(
auto
it
=
op_descs
.
rbegin
();
it
!=
op_descs
.
rend
();
++
it
)
{
VLOG
(
5
)
<<
"Making backward "
<<
(
*
it
)
->
Type
()
<<
" op"
;
VLOG
(
5
)
<<
"Making backward "
<<
(
*
it
)
->
Type
()
<<
" op"
;
std
::
vector
<
std
::
unique_ptr
<
OpDesc
Bind
>>
op_grads
;
std
::
vector
<
std
::
unique_ptr
<
OpDesc
>>
op_grads
;
if
((
*
it
)
->
Type
()
==
"recurrent"
||
(
*
it
)
->
Type
()
==
"while"
)
{
if
((
*
it
)
->
Type
()
==
"recurrent"
||
(
*
it
)
->
Type
()
==
"while"
)
{
int
step_block_idx
=
(
*
it
)
->
GetBlockAttr
(
"sub_block"
);
int
step_block_idx
=
(
*
it
)
->
GetBlockAttr
(
"sub_block"
);
BlockDesc
Bind
*
backward_block
=
CreateStepBlock
(
BlockDesc
*
backward_block
=
CreateStepBlock
(
program_desc
,
no_grad_vars
,
program_desc
,
no_grad_vars
,
grad_to_var
,
step_block_idx
);
grad_to_var
,
step_block_idx
);
op_grads
=
MakeOpGrad
(
*
it
,
no_grad_vars
,
grad_to_var
,
{
backward_block
});
op_grads
=
MakeOpGrad
(
*
it
,
no_grad_vars
,
grad_to_var
,
{
backward_block
});
}
else
if
((
*
it
)
->
Type
()
==
"conditional_block"
)
{
}
else
if
((
*
it
)
->
Type
()
==
"conditional_block"
)
{
BlockDesc
Bind
*
backward_block
=
BlockDesc
*
backward_block
=
CreateStepBlock
(
program_desc
,
no_grad_vars
,
grad_to_var
,
CreateStepBlock
(
program_desc
,
no_grad_vars
,
grad_to_var
,
(
*
it
)
->
GetBlockAttr
(
"sub_block"
));
(
*
it
)
->
GetBlockAttr
(
"sub_block"
));
op_grads
=
MakeOpGrad
(
*
it
,
no_grad_vars
,
grad_to_var
,
{
backward_block
});
op_grads
=
MakeOpGrad
(
*
it
,
no_grad_vars
,
grad_to_var
,
{
backward_block
});
...
@@ -463,14 +461,14 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
...
@@ -463,14 +461,14 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
}
}
++
grad_desc_idx
;
++
grad_desc_idx
;
}
}
std
::
transform
(
std
::
transform
(
op_grads
.
begin
(),
op_grads
.
end
(),
op_grads
.
begin
(),
op_grads
.
end
(),
std
::
back_inserter
(
backward_descs
),
std
::
back_inserter
(
backward_descs
),
[](
std
::
unique_ptr
<
OpDescBind
>&
ptr
)
{
return
std
::
move
(
ptr
);
});
[](
std
::
unique_ptr
<
OpDesc
>&
ptr
)
{
return
std
::
move
(
ptr
);
});
}
}
VLOG
(
5
)
<<
"Appending Sums"
;
VLOG
(
5
)
<<
"Appending Sums"
;
// Check whether some variables are written more than once
// Check whether some variables are written more than once
std
::
list
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
OpDesc
Bind
>>>
pending_sum_ops
;
std
::
list
<
std
::
pair
<
size_t
,
std
::
unique_ptr
<
OpDesc
>>>
pending_sum_ops
;
for
(
const
auto
&
dup
:
dup_out_ops
)
{
for
(
const
auto
&
dup
:
dup_out_ops
)
{
const
std
::
string
&
out_name
=
dup
.
first
;
const
std
::
string
&
out_name
=
dup
.
first
;
const
std
::
vector
<
size_t
>
dup_op
=
dup
.
second
;
const
std
::
vector
<
size_t
>
dup_op
=
dup
.
second
;
...
@@ -486,16 +484,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
...
@@ -486,16 +484,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
sum_op_inputs
.
emplace_back
(
new_name
);
sum_op_inputs
.
emplace_back
(
new_name
);
next_g_name
=
sum_op_inputs
.
back
();
next_g_name
=
sum_op_inputs
.
back
();
}
}
std
::
unique_ptr
<
OpDesc
Bind
>
sum_op
(
std
::
unique_ptr
<
OpDesc
>
sum_op
(
new
OpDesc
(
"sum"
,
{{
"X"
,
sum_op_inputs
}},
new
OpDescBind
(
"sum"
,
{{
"X"
,
sum_op_inputs
}},
{{
"Out"
,
{
out_name
}}},
{{
"Out"
,
{
out_name
}}},
AttributeMap
{}));
AttributeMap
{}));
pending_sum_ops
.
push_back
({
dup_op
.
back
(),
std
::
move
(
sum_op
)});
pending_sum_ops
.
push_back
({
dup_op
.
back
(),
std
::
move
(
sum_op
)});
}
}
}
}
pending_sum_ops
.
sort
(
pending_sum_ops
.
sort
([](
const
std
::
pair
<
size_t
,
std
::
unique_ptr
<
OpDesc
>>&
a
,
[](
const
std
::
pair
<
size_t
,
std
::
unique_ptr
<
OpDescBind
>>&
a
,
const
std
::
pair
<
size_t
,
std
::
unique_ptr
<
OpDesc
>>&
b
)
{
const
std
::
pair
<
size_t
,
std
::
unique_ptr
<
OpDescBind
>>&
b
)
{
return
a
.
first
>
b
.
first
;
return
a
.
first
>
b
.
first
;
});
});
for
(
auto
&
p
:
pending_sum_ops
)
{
for
(
auto
&
p
:
pending_sum_ops
)
{
...
@@ -508,14 +505,13 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
...
@@ -508,14 +505,13 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
return
backward_descs
;
return
backward_descs
;
}
}
static
BlockDescBind
*
CreateStepBlock
(
static
BlockDesc
*
CreateStepBlock
(
ProgramDescBind
&
program_desc
,
ProgramDesc
&
program_desc
,
std
::
unordered_set
<
std
::
string
>*
no_grad_vars
,
std
::
unordered_set
<
std
::
string
>*
no_grad_vars
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
,
int
step_block_idx
)
{
int
step_block_idx
)
{
auto
backward_block_op_descs
=
MakeBlockBackward
(
program_desc
,
step_block_idx
,
auto
backward_block_op_descs
=
MakeBlockBackward
(
program_desc
,
step_block_idx
,
no_grad_vars
,
grad_to_var
);
no_grad_vars
,
grad_to_var
);
BlockDesc
Bind
*
backward_block
=
BlockDesc
*
backward_block
=
program_desc
.
AppendBlock
(
*
program_desc
.
MutableBlock
(
step_block_idx
));
program_desc
.
AppendBlock
(
*
program_desc
.
MutableBlock
(
step_block_idx
));
for
(
auto
&
ptr
:
backward_block_op_descs
)
{
for
(
auto
&
ptr
:
backward_block_op_descs
)
{
backward_block
->
AppendAllocatedOp
(
move
(
ptr
));
backward_block
->
AppendAllocatedOp
(
move
(
ptr
));
...
@@ -524,7 +520,7 @@ static BlockDescBind* CreateStepBlock(
...
@@ -524,7 +520,7 @@ static BlockDescBind* CreateStepBlock(
}
}
ParamGradInfoMap
AppendBackward
(
ParamGradInfoMap
AppendBackward
(
ProgramDesc
Bind
&
program_desc
,
const
VarDescBind
&
target
,
ProgramDesc
&
program_desc
,
const
VarDesc
&
target
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
)
{
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
)
{
std
::
unordered_set
<
std
::
string
>
no_grad_var_names
;
std
::
unordered_set
<
std
::
string
>
no_grad_var_names
;
no_grad_var_names
.
reserve
(
no_grad_vars
.
size
()
+
1
);
no_grad_var_names
.
reserve
(
no_grad_vars
.
size
()
+
1
);
...
@@ -541,8 +537,8 @@ ParamGradInfoMap AppendBackward(
...
@@ -541,8 +537,8 @@ ParamGradInfoMap AppendBackward(
PADDLE_ENFORCE
(
is_scalar
,
"target should be scalar"
);
PADDLE_ENFORCE
(
is_scalar
,
"target should be scalar"
);
VLOG
(
3
)
<<
"backward from loss="
<<
target
.
Name
()
VLOG
(
3
)
<<
"backward from loss="
<<
target
.
Name
()
<<
" data_type="
<<
target
.
GetDataType
();
<<
" data_type="
<<
target
.
GetDataType
();
std
::
unique_ptr
<
OpDesc
Bind
>
fill_one_op
(
std
::
unique_ptr
<
OpDesc
>
fill_one_op
(
new
OpDesc
Bind
(
"fill_constant"
,
{},
{{
"Out"
,
{
fill_one_op_out
}}},
new
OpDesc
(
"fill_constant"
,
{},
{{
"Out"
,
{
fill_one_op_out
}}},
{{
"shape"
,
std
::
vector
<
int
>
{
1
}},
{{
"shape"
,
std
::
vector
<
int
>
{
1
}},
{
"value"
,
static_cast
<
float
>
(
1.0
)},
{
"value"
,
static_cast
<
float
>
(
1.0
)},
{
"dtype"
,
target
.
GetDataType
()}}));
{
"dtype"
,
target
.
GetDataType
()}}));
...
...
paddle/framework/backward.h
浏览文件 @
09189732
...
@@ -49,7 +49,7 @@ using ParamGradInfoMap = std::unordered_map<std::string /*fwd_var_name*/,
...
@@ -49,7 +49,7 @@ using ParamGradInfoMap = std::unordered_map<std::string /*fwd_var_name*/,
GradVarInfo
/*grad_var_info*/
>
;
GradVarInfo
/*grad_var_info*/
>
;
ParamGradInfoMap
AppendBackward
(
ParamGradInfoMap
AppendBackward
(
ProgramDesc
Bind
&
program_desc
,
const
VarDescBind
&
target
,
ProgramDesc
&
program_desc
,
const
VarDesc
&
target
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
);
const
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
);
}
// namespace framework
}
// namespace framework
...
...
paddle/framework/backward_test.cc
浏览文件 @
09189732
...
@@ -58,13 +58,13 @@ class RowWiseAddGradMaker : public SingleGradOpDescMaker {
...
@@ -58,13 +58,13 @@ class RowWiseAddGradMaker : public SingleGradOpDescMaker {
using
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
OpDesc
>
Apply
()
const
override
{
auto
grad_op
=
new
OpDesc
Bind
();
auto
grad_op
=
new
OpDesc
();
grad_op
->
SetInput
(
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
->
SetInput
(
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
GradVarName
(
"X"
),
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
GradVarName
(
"X"
),
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
GradVarName
(
"b"
),
InputGrad
(
"b"
));
grad_op
->
SetOutput
(
GradVarName
(
"b"
),
InputGrad
(
"b"
));
grad_op
->
SetType
(
"rowwise_add_grad"
);
grad_op
->
SetType
(
"rowwise_add_grad"
);
return
std
::
unique_ptr
<
OpDesc
Bind
>
(
grad_op
);
return
std
::
unique_ptr
<
OpDesc
>
(
grad_op
);
}
}
};
};
...
@@ -190,11 +190,11 @@ class MinusGradOpDescMaker : public GradOpDescMakerBase {
...
@@ -190,11 +190,11 @@ class MinusGradOpDescMaker : public GradOpDescMakerBase {
public:
public:
using
GradOpDescMakerBase
::
GradOpDescMakerBase
;
using
GradOpDescMakerBase
::
GradOpDescMakerBase
;
std
::
vector
<
std
::
unique_ptr
<
OpDesc
Bind
>>
operator
()()
const
override
{
std
::
vector
<
std
::
unique_ptr
<
OpDesc
>>
operator
()()
const
override
{
std
::
vector
<
std
::
unique_ptr
<
OpDesc
Bind
>>
retv
;
std
::
vector
<
std
::
unique_ptr
<
OpDesc
>>
retv
;
auto
x_g
=
InputGrad
(
"X"
);
auto
x_g
=
InputGrad
(
"X"
);
if
(
!
x_g
.
empty
())
{
if
(
!
x_g
.
empty
())
{
auto
*
op_desc
=
new
OpDesc
Bind
();
auto
*
op_desc
=
new
OpDesc
();
op_desc
->
SetType
(
"scale"
);
op_desc
->
SetType
(
"scale"
);
op_desc
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
op_desc
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
op_desc
->
SetOutput
(
"Out"
,
x_g
);
op_desc
->
SetOutput
(
"Out"
,
x_g
);
...
@@ -204,7 +204,7 @@ class MinusGradOpDescMaker : public GradOpDescMakerBase {
...
@@ -204,7 +204,7 @@ class MinusGradOpDescMaker : public GradOpDescMakerBase {
auto
y_g
=
InputGrad
(
"Y"
);
auto
y_g
=
InputGrad
(
"Y"
);
if
(
!
y_g
.
empty
())
{
if
(
!
y_g
.
empty
())
{
auto
*
op_desc
=
new
OpDesc
Bind
();
auto
*
op_desc
=
new
OpDesc
();
op_desc
->
SetType
(
"scale"
);
op_desc
->
SetType
(
"scale"
);
op_desc
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
op_desc
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
op_desc
->
SetOutput
(
"Out"
,
y_g
);
op_desc
->
SetOutput
(
"Out"
,
y_g
);
...
@@ -505,25 +505,25 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
...
@@ -505,25 +505,25 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
}
}
TEST
(
Backward
,
simple_single_op
)
{
TEST
(
Backward
,
simple_single_op
)
{
f
::
ProgramDesc
Bind
program
;
f
::
ProgramDesc
program
;
f
::
BlockDesc
Bind
*
block
=
program
.
MutableBlock
(
0
);
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
f
::
OpDesc
Bind
*
op
=
block
->
AppendOp
();
f
::
OpDesc
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"rowwise_add"
);
op
->
SetType
(
"rowwise_add"
);
op
->
SetInput
(
"X"
,
{
"x"
});
op
->
SetInput
(
"X"
,
{
"x"
});
op
->
SetInput
(
"b"
,
{
"b"
});
op
->
SetInput
(
"b"
,
{
"b"
});
op
->
SetOutput
(
"Out"
,
{
"out"
});
op
->
SetOutput
(
"Out"
,
{
"out"
});
auto
target
=
f
::
VarDesc
Bind
(
"out"
);
auto
target
=
f
::
VarDesc
(
"out"
);
target
.
SetShape
({
1
});
target
.
SetShape
({
1
});
auto
var_to_grad
=
auto
var_to_grad
=
AppendBackward
(
program
,
target
,
std
::
unordered_set
<
std
::
string
>
{});
AppendBackward
(
program
,
target
,
std
::
unordered_set
<
std
::
string
>
{});
ASSERT_EQ
(
block
->
AllOps
().
size
(),
3UL
);
ASSERT_EQ
(
block
->
AllOps
().
size
(),
3UL
);
f
::
OpDesc
Bind
*
fill_op
=
block
->
AllOps
()[
1
];
f
::
OpDesc
*
fill_op
=
block
->
AllOps
()[
1
];
EXPECT_EQ
(
fill_op
->
Type
(),
"fill_constant"
);
EXPECT_EQ
(
fill_op
->
Type
(),
"fill_constant"
);
f
::
OpDesc
Bind
*
grad_op
=
block
->
AllOps
()[
2
];
f
::
OpDesc
*
grad_op
=
block
->
AllOps
()[
2
];
EXPECT_EQ
(
grad_op
->
Type
(),
"rowwise_add_grad"
);
EXPECT_EQ
(
grad_op
->
Type
(),
"rowwise_add_grad"
);
ASSERT_EQ
(
grad_op
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
grad_op
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
grad_op
->
OutputNames
().
size
(),
2UL
);
ASSERT_EQ
(
grad_op
->
OutputNames
().
size
(),
2UL
);
...
@@ -543,16 +543,16 @@ TEST(Backward, simple_single_op) {
...
@@ -543,16 +543,16 @@ TEST(Backward, simple_single_op) {
}
}
TEST
(
Backward
,
default_attribute
)
{
TEST
(
Backward
,
default_attribute
)
{
f
::
ProgramDesc
Bind
program
;
f
::
ProgramDesc
program
;
f
::
BlockDesc
Bind
*
block
=
program
.
MutableBlock
(
0
);
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
f
::
OpDesc
Bind
*
op
=
block
->
AppendOp
();
f
::
OpDesc
*
op
=
block
->
AppendOp
();
op
->
SetType
(
"mul"
);
op
->
SetType
(
"mul"
);
op
->
SetInput
(
"X"
,
{
"x"
});
op
->
SetInput
(
"X"
,
{
"x"
});
op
->
SetInput
(
"Y"
,
{
"y"
});
op
->
SetInput
(
"Y"
,
{
"y"
});
op
->
SetOutput
(
"Out"
,
{
"out"
});
op
->
SetOutput
(
"Out"
,
{
"out"
});
op
->
CheckAttrs
();
op
->
CheckAttrs
();
auto
target
=
f
::
VarDesc
Bind
(
"out"
);
auto
target
=
f
::
VarDesc
(
"out"
);
target
.
SetShape
({
1
});
target
.
SetShape
({
1
});
AppendBackward
(
program
,
target
,
std
::
unordered_set
<
std
::
string
>
{});
AppendBackward
(
program
,
target
,
std
::
unordered_set
<
std
::
string
>
{});
...
@@ -560,47 +560,47 @@ TEST(Backward, default_attribute) {
...
@@ -560,47 +560,47 @@ TEST(Backward, default_attribute) {
EXPECT_EQ
(
boost
::
get
<
int
>
(
op
->
GetAttr
(
"x_num_col_dims"
)),
1
);
EXPECT_EQ
(
boost
::
get
<
int
>
(
op
->
GetAttr
(
"x_num_col_dims"
)),
1
);
EXPECT_EQ
(
boost
::
get
<
int
>
(
op
->
GetAttr
(
"y_num_col_dims"
)),
1
);
EXPECT_EQ
(
boost
::
get
<
int
>
(
op
->
GetAttr
(
"y_num_col_dims"
)),
1
);
f
::
OpDesc
Bind
*
fill_op
=
block
->
AllOps
()[
1
];
f
::
OpDesc
*
fill_op
=
block
->
AllOps
()[
1
];
EXPECT_EQ
(
fill_op
->
Type
(),
"fill_constant"
);
EXPECT_EQ
(
fill_op
->
Type
(),
"fill_constant"
);
f
::
OpDesc
Bind
*
grad_op
=
block
->
AllOps
()[
2
];
f
::
OpDesc
*
grad_op
=
block
->
AllOps
()[
2
];
ASSERT_EQ
(
grad_op
->
Type
(),
"mul_grad"
);
ASSERT_EQ
(
grad_op
->
Type
(),
"mul_grad"
);
EXPECT_EQ
(
boost
::
get
<
int
>
(
grad_op
->
GetAttr
(
"x_num_col_dims"
)),
1
);
EXPECT_EQ
(
boost
::
get
<
int
>
(
grad_op
->
GetAttr
(
"x_num_col_dims"
)),
1
);
EXPECT_EQ
(
boost
::
get
<
int
>
(
grad_op
->
GetAttr
(
"y_num_col_dims"
)),
1
);
EXPECT_EQ
(
boost
::
get
<
int
>
(
grad_op
->
GetAttr
(
"y_num_col_dims"
)),
1
);
}
}
TEST
(
Backward
,
simple_mult_op
)
{
TEST
(
Backward
,
simple_mult_op
)
{
f
::
ProgramDesc
Bind
program
;
f
::
ProgramDesc
program
;
f
::
BlockDesc
Bind
*
block
=
program
.
MutableBlock
(
0
);
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
f
::
OpDesc
Bind
*
op1
=
block
->
AppendOp
();
f
::
OpDesc
*
op1
=
block
->
AppendOp
();
op1
->
SetType
(
"rowwise_add"
);
op1
->
SetType
(
"rowwise_add"
);
op1
->
SetInput
(
"X"
,
{
"x1"
});
op1
->
SetInput
(
"X"
,
{
"x1"
});
op1
->
SetInput
(
"b"
,
{
"b1"
});
op1
->
SetInput
(
"b"
,
{
"b1"
});
op1
->
SetOutput
(
"Out"
,
{
"out1"
});
op1
->
SetOutput
(
"Out"
,
{
"out1"
});
f
::
OpDesc
Bind
*
op2
=
block
->
AppendOp
();
f
::
OpDesc
*
op2
=
block
->
AppendOp
();
op2
->
SetType
(
"mul"
);
op2
->
SetType
(
"mul"
);
op2
->
SetInput
(
"X"
,
{
"out1"
});
op2
->
SetInput
(
"X"
,
{
"out1"
});
op2
->
SetInput
(
"Y"
,
{
"y2"
});
op2
->
SetInput
(
"Y"
,
{
"y2"
});
op2
->
SetOutput
(
"Out"
,
{
"out2"
});
op2
->
SetOutput
(
"Out"
,
{
"out2"
});
f
::
OpDesc
Bind
*
op3
=
block
->
AppendOp
();
f
::
OpDesc
*
op3
=
block
->
AppendOp
();
op3
->
SetType
(
"rowwise_add"
);
op3
->
SetType
(
"rowwise_add"
);
op3
->
SetInput
(
"X"
,
{
"out2"
});
op3
->
SetInput
(
"X"
,
{
"out2"
});
op3
->
SetInput
(
"b"
,
{
"b3"
});
op3
->
SetInput
(
"b"
,
{
"b3"
});
op3
->
SetOutput
(
"Out"
,
{
"out3"
});
op3
->
SetOutput
(
"Out"
,
{
"out3"
});
auto
target
=
f
::
VarDesc
Bind
(
"out3"
);
auto
target
=
f
::
VarDesc
(
"out3"
);
target
.
SetShape
({
1
});
target
.
SetShape
({
1
});
size_t
forward_len
=
block
->
AllOps
().
size
();
size_t
forward_len
=
block
->
AllOps
().
size
();
auto
var_to_grad
=
auto
var_to_grad
=
AppendBackward
(
program
,
target
,
std
::
unordered_set
<
std
::
string
>
{});
AppendBackward
(
program
,
target
,
std
::
unordered_set
<
std
::
string
>
{});
ASSERT_EQ
(
block
->
AllOps
().
size
(),
6UL
+
1
);
ASSERT_EQ
(
block
->
AllOps
().
size
(),
6UL
+
1
);
f
::
OpDesc
Bind
*
fill_op
=
block
->
AllOps
()[
forward_len
];
f
::
OpDesc
*
fill_op
=
block
->
AllOps
()[
forward_len
];
EXPECT_EQ
(
fill_op
->
Type
(),
"fill_constant"
);
EXPECT_EQ
(
fill_op
->
Type
(),
"fill_constant"
);
f
::
OpDesc
Bind
*
grad_op1
=
block
->
AllOps
()[
6
];
f
::
OpDesc
*
grad_op1
=
block
->
AllOps
()[
6
];
EXPECT_EQ
(
grad_op1
->
Type
(),
"rowwise_add_grad"
);
EXPECT_EQ
(
grad_op1
->
Type
(),
"rowwise_add_grad"
);
ASSERT_EQ
(
grad_op1
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
grad_op1
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
grad_op1
->
OutputNames
().
size
(),
2UL
);
ASSERT_EQ
(
grad_op1
->
OutputNames
().
size
(),
2UL
);
...
@@ -611,7 +611,7 @@ TEST(Backward, simple_mult_op) {
...
@@ -611,7 +611,7 @@ TEST(Backward, simple_mult_op) {
EXPECT_EQ
(
grad_op1
->
Output
(
f
::
GradVarName
(
"b"
)),
EXPECT_EQ
(
grad_op1
->
Output
(
f
::
GradVarName
(
"b"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"b1"
)}));
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"b1"
)}));
f
::
OpDesc
Bind
*
grad_op2
=
block
->
AllOps
()[
5
];
f
::
OpDesc
*
grad_op2
=
block
->
AllOps
()[
5
];
EXPECT_EQ
(
grad_op2
->
Type
(),
"mul_grad"
);
EXPECT_EQ
(
grad_op2
->
Type
(),
"mul_grad"
);
ASSERT_EQ
(
grad_op2
->
InputNames
().
size
(),
4UL
);
ASSERT_EQ
(
grad_op2
->
InputNames
().
size
(),
4UL
);
ASSERT_EQ
(
grad_op2
->
OutputNames
().
size
(),
2UL
);
ASSERT_EQ
(
grad_op2
->
OutputNames
().
size
(),
2UL
);
...
@@ -625,7 +625,7 @@ TEST(Backward, simple_mult_op) {
...
@@ -625,7 +625,7 @@ TEST(Backward, simple_mult_op) {
EXPECT_EQ
(
grad_op2
->
Output
(
f
::
GradVarName
(
"Y"
)),
EXPECT_EQ
(
grad_op2
->
Output
(
f
::
GradVarName
(
"Y"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"y2"
)}));
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"y2"
)}));
f
::
OpDesc
Bind
*
grad_op3
=
block
->
AllOps
()[
4
];
f
::
OpDesc
*
grad_op3
=
block
->
AllOps
()[
4
];
EXPECT_EQ
(
grad_op3
->
Type
(),
"rowwise_add_grad"
);
EXPECT_EQ
(
grad_op3
->
Type
(),
"rowwise_add_grad"
);
ASSERT_EQ
(
grad_op3
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
grad_op3
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
grad_op3
->
OutputNames
().
size
(),
2UL
);
ASSERT_EQ
(
grad_op3
->
OutputNames
().
size
(),
2UL
);
...
@@ -655,42 +655,42 @@ TEST(Backward, simple_mult_op) {
...
@@ -655,42 +655,42 @@ TEST(Backward, simple_mult_op) {
}
}
TEST
(
Backward
,
intermedia_var_no_grad
)
{
TEST
(
Backward
,
intermedia_var_no_grad
)
{
f
::
ProgramDesc
Bind
program
;
f
::
ProgramDesc
program
;
f
::
BlockDesc
Bind
*
block
=
program
.
MutableBlock
(
0
);
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
f
::
OpDesc
Bind
*
op1
=
block
->
AppendOp
();
f
::
OpDesc
*
op1
=
block
->
AppendOp
();
op1
->
SetType
(
"rowwise_add"
);
op1
->
SetType
(
"rowwise_add"
);
op1
->
SetInput
(
"X"
,
{
"x1"
});
op1
->
SetInput
(
"X"
,
{
"x1"
});
op1
->
SetInput
(
"b"
,
{
"b1"
});
op1
->
SetInput
(
"b"
,
{
"b1"
});
op1
->
SetOutput
(
"Out"
,
{
"out1"
});
op1
->
SetOutput
(
"Out"
,
{
"out1"
});
f
::
OpDesc
Bind
*
op2
=
block
->
AppendOp
();
f
::
OpDesc
*
op2
=
block
->
AppendOp
();
op2
->
SetType
(
"mul"
);
op2
->
SetType
(
"mul"
);
op2
->
SetInput
(
"X"
,
{
"x2"
});
op2
->
SetInput
(
"X"
,
{
"x2"
});
op2
->
SetInput
(
"Y"
,
{
"y2"
});
op2
->
SetInput
(
"Y"
,
{
"y2"
});
op2
->
SetOutput
(
"Out"
,
{
"out2"
});
op2
->
SetOutput
(
"Out"
,
{
"out2"
});
f
::
OpDesc
Bind
*
op3
=
block
->
AppendOp
();
f
::
OpDesc
*
op3
=
block
->
AppendOp
();
op3
->
SetType
(
"rowwise_add"
);
op3
->
SetType
(
"rowwise_add"
);
op3
->
SetInput
(
"X"
,
{
"out2"
});
op3
->
SetInput
(
"X"
,
{
"out2"
});
op3
->
SetInput
(
"b"
,
{
"b3"
});
op3
->
SetInput
(
"b"
,
{
"b3"
});
op3
->
SetOutput
(
"Out"
,
{
"out3"
});
op3
->
SetOutput
(
"Out"
,
{
"out3"
});
f
::
OpDesc
Bind
*
op4
=
block
->
AppendOp
();
f
::
OpDesc
*
op4
=
block
->
AppendOp
();
op4
->
SetType
(
"mul"
);
op4
->
SetType
(
"mul"
);
op4
->
SetInput
(
"X"
,
{
"out1"
});
op4
->
SetInput
(
"X"
,
{
"out1"
});
op4
->
SetInput
(
"Y"
,
{
"out3"
});
op4
->
SetInput
(
"Y"
,
{
"out3"
});
op4
->
SetOutput
(
"Out"
,
{
"out4"
});
op4
->
SetOutput
(
"Out"
,
{
"out4"
});
auto
target
=
f
::
VarDesc
Bind
(
"out4"
);
auto
target
=
f
::
VarDesc
(
"out4"
);
target
.
SetShape
({
1
});
target
.
SetShape
({
1
});
size_t
forward_len
=
block
->
AllOps
().
size
();
size_t
forward_len
=
block
->
AllOps
().
size
();
auto
var_to_grad
=
AppendBackward
(
program
,
target
,
{
"out3"
});
auto
var_to_grad
=
AppendBackward
(
program
,
target
,
{
"out3"
});
ASSERT_EQ
(
block
->
AllOps
().
size
(),
7UL
);
ASSERT_EQ
(
block
->
AllOps
().
size
(),
7UL
);
f
::
OpDesc
Bind
*
fill_op
=
block
->
AllOps
()[
forward_len
];
f
::
OpDesc
*
fill_op
=
block
->
AllOps
()[
forward_len
];
EXPECT_EQ
(
fill_op
->
Type
(),
"fill_constant"
);
EXPECT_EQ
(
fill_op
->
Type
(),
"fill_constant"
);
f
::
OpDesc
Bind
*
grad_op1
=
block
->
AllOps
()[
6
];
f
::
OpDesc
*
grad_op1
=
block
->
AllOps
()[
6
];
EXPECT_EQ
(
grad_op1
->
Type
(),
"rowwise_add_grad"
);
EXPECT_EQ
(
grad_op1
->
Type
(),
"rowwise_add_grad"
);
ASSERT_EQ
(
grad_op1
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
grad_op1
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
grad_op1
->
OutputNames
().
size
(),
2UL
);
ASSERT_EQ
(
grad_op1
->
OutputNames
().
size
(),
2UL
);
...
@@ -701,7 +701,7 @@ TEST(Backward, intermedia_var_no_grad) {
...
@@ -701,7 +701,7 @@ TEST(Backward, intermedia_var_no_grad) {
EXPECT_EQ
(
grad_op1
->
Output
(
f
::
GradVarName
(
"b"
)),
EXPECT_EQ
(
grad_op1
->
Output
(
f
::
GradVarName
(
"b"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"b1"
)}));
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"b1"
)}));
f
::
OpDesc
Bind
*
grad_op4
=
block
->
AllOps
()[
5
];
f
::
OpDesc
*
grad_op4
=
block
->
AllOps
()[
5
];
EXPECT_EQ
(
grad_op4
->
Type
(),
"mul_grad"
);
EXPECT_EQ
(
grad_op4
->
Type
(),
"mul_grad"
);
ASSERT_EQ
(
grad_op4
->
InputNames
().
size
(),
4UL
);
ASSERT_EQ
(
grad_op4
->
InputNames
().
size
(),
4UL
);
ASSERT_EQ
(
grad_op4
->
OutputNames
().
size
(),
2UL
);
ASSERT_EQ
(
grad_op4
->
OutputNames
().
size
(),
2UL
);
...
@@ -726,32 +726,32 @@ TEST(Backward, intermedia_var_no_grad) {
...
@@ -726,32 +726,32 @@ TEST(Backward, intermedia_var_no_grad) {
}
}
TEST
(
Backward
,
var_no_grad
)
{
TEST
(
Backward
,
var_no_grad
)
{
f
::
ProgramDesc
Bind
program
;
f
::
ProgramDesc
program
;
f
::
BlockDesc
Bind
*
block
=
program
.
MutableBlock
(
0
);
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
f
::
OpDesc
Bind
*
op1
=
block
->
AppendOp
();
f
::
OpDesc
*
op1
=
block
->
AppendOp
();
op1
->
SetType
(
"mult_in_out"
);
op1
->
SetType
(
"mult_in_out"
);
op1
->
SetInput
(
"X"
,
{
"x1"
});
op1
->
SetInput
(
"X"
,
{
"x1"
});
op1
->
SetInput
(
"H"
,
{
"h1"
});
op1
->
SetInput
(
"H"
,
{
"h1"
});
op1
->
SetOutput
(
"Y"
,
{
"y1"
});
op1
->
SetOutput
(
"Y"
,
{
"y1"
});
op1
->
SetOutput
(
"Z"
,
{
"z1"
});
op1
->
SetOutput
(
"Z"
,
{
"z1"
});
f
::
OpDesc
Bind
*
op2
=
block
->
AppendOp
();
f
::
OpDesc
*
op2
=
block
->
AppendOp
();
op2
->
SetType
(
"mult_in_out"
);
op2
->
SetType
(
"mult_in_out"
);
op2
->
SetInput
(
"X"
,
{
"y1"
});
op2
->
SetInput
(
"X"
,
{
"y1"
});
op2
->
SetInput
(
"H"
,
{
"z1"
});
op2
->
SetInput
(
"H"
,
{
"z1"
});
op2
->
SetOutput
(
"Y"
,
{
"y2"
});
op2
->
SetOutput
(
"Y"
,
{
"y2"
});
op2
->
SetOutput
(
"Z"
,
{
"z2"
});
op2
->
SetOutput
(
"Z"
,
{
"z2"
});
auto
target
=
f
::
VarDesc
Bind
(
"z2"
);
auto
target
=
f
::
VarDesc
(
"z2"
);
target
.
SetShape
({
1
});
target
.
SetShape
({
1
});
size_t
forward_len
=
block
->
AllOps
().
size
();
size_t
forward_len
=
block
->
AllOps
().
size
();
auto
var_to_grad
=
AppendBackward
(
program
,
target
,
{
"z1"
});
auto
var_to_grad
=
AppendBackward
(
program
,
target
,
{
"z1"
});
ASSERT_EQ
(
block
->
AllOps
().
size
(),
6UL
);
ASSERT_EQ
(
block
->
AllOps
().
size
(),
6UL
);
f
::
OpDesc
Bind
*
fill_op
=
block
->
AllOps
()[
forward_len
];
f
::
OpDesc
*
fill_op
=
block
->
AllOps
()[
forward_len
];
EXPECT_EQ
(
fill_op
->
Type
(),
"fill_constant"
);
EXPECT_EQ
(
fill_op
->
Type
(),
"fill_constant"
);
f
::
OpDesc
Bind
*
grad_op2
=
block
->
AllOps
()[
3
];
f
::
OpDesc
*
grad_op2
=
block
->
AllOps
()[
3
];
ASSERT_EQ
(
grad_op2
->
Type
(),
"mult_in_out_grad"
);
ASSERT_EQ
(
grad_op2
->
Type
(),
"mult_in_out_grad"
);
ASSERT_EQ
(
grad_op2
->
InputNames
().
size
(),
6UL
);
ASSERT_EQ
(
grad_op2
->
InputNames
().
size
(),
6UL
);
ASSERT_EQ
(
grad_op2
->
OutputNames
().
size
(),
2UL
);
ASSERT_EQ
(
grad_op2
->
OutputNames
().
size
(),
2UL
);
...
@@ -767,7 +767,7 @@ TEST(Backward, var_no_grad) {
...
@@ -767,7 +767,7 @@ TEST(Backward, var_no_grad) {
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"y1"
)}));
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"y1"
)}));
EXPECT_EQ
(
grad_op2
->
Output
(
f
::
GradVarName
(
"H"
)),
std
::
vector
<
std
::
string
>
());
EXPECT_EQ
(
grad_op2
->
Output
(
f
::
GradVarName
(
"H"
)),
std
::
vector
<
std
::
string
>
());
f
::
OpDesc
Bind
*
fill_zero_op
=
block
->
AllOps
()[
4
];
f
::
OpDesc
*
fill_zero_op
=
block
->
AllOps
()[
4
];
ASSERT_EQ
(
fill_zero_op
->
Type
(),
"fill_zeros_like"
);
ASSERT_EQ
(
fill_zero_op
->
Type
(),
"fill_zeros_like"
);
ASSERT_EQ
(
fill_zero_op
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
fill_zero_op
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
fill_zero_op
->
OutputNames
().
size
(),
1UL
);
ASSERT_EQ
(
fill_zero_op
->
OutputNames
().
size
(),
1UL
);
...
@@ -775,7 +775,7 @@ TEST(Backward, var_no_grad) {
...
@@ -775,7 +775,7 @@ TEST(Backward, var_no_grad) {
EXPECT_EQ
(
fill_zero_op
->
Output
(
"Y"
),
EXPECT_EQ
(
fill_zero_op
->
Output
(
"Y"
),
std
::
vector
<
std
::
string
>
({
std
::
string
(
"z1"
)
+
f
::
kZeroVarSuffix
}));
std
::
vector
<
std
::
string
>
({
std
::
string
(
"z1"
)
+
f
::
kZeroVarSuffix
}));
f
::
OpDesc
Bind
*
grad_op1
=
block
->
AllOps
()[
5
];
f
::
OpDesc
*
grad_op1
=
block
->
AllOps
()[
5
];
ASSERT_EQ
(
grad_op1
->
Type
(),
"mult_in_out_grad"
);
ASSERT_EQ
(
grad_op1
->
Type
(),
"mult_in_out_grad"
);
ASSERT_EQ
(
grad_op1
->
InputNames
().
size
(),
6UL
);
ASSERT_EQ
(
grad_op1
->
InputNames
().
size
(),
6UL
);
ASSERT_EQ
(
grad_op1
->
OutputNames
().
size
(),
2UL
);
ASSERT_EQ
(
grad_op1
->
OutputNames
().
size
(),
2UL
);
...
@@ -803,37 +803,37 @@ TEST(Backward, var_no_grad) {
...
@@ -803,37 +803,37 @@ TEST(Backward, var_no_grad) {
}
}
TEST
(
Backward
,
shared_var
)
{
TEST
(
Backward
,
shared_var
)
{
f
::
ProgramDesc
Bind
program
;
f
::
ProgramDesc
program
;
f
::
BlockDesc
Bind
*
block
=
program
.
MutableBlock
(
0
);
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
f
::
OpDesc
Bind
*
op1
=
block
->
AppendOp
();
f
::
OpDesc
*
op1
=
block
->
AppendOp
();
op1
->
SetType
(
"rowwise_add"
);
op1
->
SetType
(
"rowwise_add"
);
op1
->
SetInput
(
"X"
,
{
"x1"
});
op1
->
SetInput
(
"X"
,
{
"x1"
});
op1
->
SetInput
(
"b"
,
{
"b1"
});
op1
->
SetInput
(
"b"
,
{
"b1"
});
op1
->
SetOutput
(
"Out"
,
{
"out1"
});
op1
->
SetOutput
(
"Out"
,
{
"out1"
});
f
::
OpDesc
Bind
*
op2
=
block
->
AppendOp
();
f
::
OpDesc
*
op2
=
block
->
AppendOp
();
op2
->
SetType
(
"mul"
);
op2
->
SetType
(
"mul"
);
op2
->
SetInput
(
"X"
,
{
"out1"
});
op2
->
SetInput
(
"X"
,
{
"out1"
});
op2
->
SetInput
(
"Y"
,
{
"y2"
});
op2
->
SetInput
(
"Y"
,
{
"y2"
});
op2
->
SetOutput
(
"Out"
,
{
"out2"
});
op2
->
SetOutput
(
"Out"
,
{
"out2"
});
f
::
OpDesc
Bind
*
op3
=
block
->
AppendOp
();
f
::
OpDesc
*
op3
=
block
->
AppendOp
();
op3
->
SetType
(
"rowwise_add"
);
op3
->
SetType
(
"rowwise_add"
);
op3
->
SetInput
(
"X"
,
{
"out1"
});
op3
->
SetInput
(
"X"
,
{
"out1"
});
op3
->
SetInput
(
"b"
,
{
"b3"
});
op3
->
SetInput
(
"b"
,
{
"b3"
});
op3
->
SetOutput
(
"Out"
,
{
"out3"
});
op3
->
SetOutput
(
"Out"
,
{
"out3"
});
auto
target
=
f
::
VarDesc
Bind
(
"out3"
);
auto
target
=
f
::
VarDesc
(
"out3"
);
target
.
SetShape
({
1
});
target
.
SetShape
({
1
});
size_t
forward_len
=
block
->
AllOps
().
size
();
size_t
forward_len
=
block
->
AllOps
().
size
();
auto
var_to_grad
=
auto
var_to_grad
=
AppendBackward
(
program
,
target
,
std
::
unordered_set
<
std
::
string
>
{});
AppendBackward
(
program
,
target
,
std
::
unordered_set
<
std
::
string
>
{});
ASSERT_EQ
(
block
->
AllOps
().
size
(),
8UL
);
ASSERT_EQ
(
block
->
AllOps
().
size
(),
8UL
);
f
::
OpDesc
Bind
*
fill_op
=
block
->
AllOps
()[
forward_len
];
f
::
OpDesc
*
fill_op
=
block
->
AllOps
()[
forward_len
];
EXPECT_EQ
(
fill_op
->
Type
(),
"fill_constant"
);
EXPECT_EQ
(
fill_op
->
Type
(),
"fill_constant"
);
f
::
OpDesc
Bind
*
grad_op3
=
block
->
AllOps
()[
4
];
f
::
OpDesc
*
grad_op3
=
block
->
AllOps
()[
4
];
ASSERT_EQ
(
grad_op3
->
Type
(),
"rowwise_add_grad"
);
ASSERT_EQ
(
grad_op3
->
Type
(),
"rowwise_add_grad"
);
ASSERT_EQ
(
grad_op3
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
grad_op3
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
grad_op3
->
OutputNames
().
size
(),
2UL
);
ASSERT_EQ
(
grad_op3
->
OutputNames
().
size
(),
2UL
);
...
@@ -844,7 +844,7 @@ TEST(Backward, shared_var) {
...
@@ -844,7 +844,7 @@ TEST(Backward, shared_var) {
EXPECT_EQ
(
grad_op3
->
Output
(
f
::
GradVarName
(
"b"
)),
EXPECT_EQ
(
grad_op3
->
Output
(
f
::
GradVarName
(
"b"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"b3"
)}));
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"b3"
)}));
f
::
OpDesc
Bind
*
grad_op4
=
block
->
AllOps
()[
5
];
f
::
OpDesc
*
grad_op4
=
block
->
AllOps
()[
5
];
ASSERT_EQ
(
grad_op4
->
Type
(),
"mul_grad"
);
ASSERT_EQ
(
grad_op4
->
Type
(),
"mul_grad"
);
ASSERT_EQ
(
grad_op4
->
InputNames
().
size
(),
4UL
);
ASSERT_EQ
(
grad_op4
->
InputNames
().
size
(),
4UL
);
ASSERT_EQ
(
grad_op4
->
OutputNames
().
size
(),
2UL
);
ASSERT_EQ
(
grad_op4
->
OutputNames
().
size
(),
2UL
);
...
@@ -858,7 +858,7 @@ TEST(Backward, shared_var) {
...
@@ -858,7 +858,7 @@ TEST(Backward, shared_var) {
EXPECT_EQ
(
grad_op4
->
Output
(
f
::
GradVarName
(
"Y"
)),
EXPECT_EQ
(
grad_op4
->
Output
(
f
::
GradVarName
(
"Y"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"y2"
)}));
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"y2"
)}));
f
::
OpDesc
Bind
*
sum_op
=
block
->
AllOps
()[
6
];
f
::
OpDesc
*
sum_op
=
block
->
AllOps
()[
6
];
ASSERT_EQ
(
sum_op
->
Type
(),
"sum"
);
ASSERT_EQ
(
sum_op
->
Type
(),
"sum"
);
ASSERT_EQ
(
sum_op
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
sum_op
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
sum_op
->
OutputNames
().
size
(),
1UL
);
ASSERT_EQ
(
sum_op
->
OutputNames
().
size
(),
1UL
);
...
@@ -868,7 +868,7 @@ TEST(Backward, shared_var) {
...
@@ -868,7 +868,7 @@ TEST(Backward, shared_var) {
EXPECT_EQ
(
sum_op
->
Output
(
"Out"
),
EXPECT_EQ
(
sum_op
->
Output
(
"Out"
),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"out1"
)}));
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"out1"
)}));
f
::
OpDesc
Bind
*
grad_op1
=
block
->
AllOps
()[
7
];
f
::
OpDesc
*
grad_op1
=
block
->
AllOps
()[
7
];
ASSERT_EQ
(
grad_op1
->
Type
(),
"rowwise_add_grad"
);
ASSERT_EQ
(
grad_op1
->
Type
(),
"rowwise_add_grad"
);
ASSERT_EQ
(
grad_op1
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
grad_op1
->
InputNames
().
size
(),
1UL
);
ASSERT_EQ
(
grad_op1
->
OutputNames
().
size
(),
2UL
);
ASSERT_EQ
(
grad_op1
->
OutputNames
().
size
(),
2UL
);
...
@@ -895,19 +895,19 @@ TEST(Backward, shared_var) {
...
@@ -895,19 +895,19 @@ TEST(Backward, shared_var) {
}
}
TEST
(
Backward
,
half_backward
)
{
TEST
(
Backward
,
half_backward
)
{
f
::
ProgramDesc
Bind
program
;
f
::
ProgramDesc
program
;
f
::
BlockDesc
Bind
*
block
=
program
.
MutableBlock
(
0
);
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
auto
*
op1
=
block
->
AppendOp
();
auto
*
op1
=
block
->
AppendOp
();
op1
->
SetType
(
"minus"
);
op1
->
SetType
(
"minus"
);
op1
->
SetInput
(
"X"
,
{
"a"
});
op1
->
SetInput
(
"X"
,
{
"a"
});
op1
->
SetInput
(
"Y"
,
{
"b"
});
op1
->
SetInput
(
"Y"
,
{
"b"
});
op1
->
SetOutput
(
"Out"
,
{
"out"
});
op1
->
SetOutput
(
"Out"
,
{
"out"
});
auto
target
=
f
::
VarDesc
Bind
(
"out"
);
auto
target
=
f
::
VarDesc
(
"out"
);
target
.
SetShape
({
1
});
target
.
SetShape
({
1
});
size_t
forward_len
=
block
->
AllOps
().
size
();
size_t
forward_len
=
block
->
AllOps
().
size
();
auto
var_to_grad
=
AppendBackward
(
program
,
target
,
{
"b"
});
auto
var_to_grad
=
AppendBackward
(
program
,
target
,
{
"b"
});
f
::
OpDesc
Bind
*
fill_op
=
block
->
AllOps
()[
forward_len
];
f
::
OpDesc
*
fill_op
=
block
->
AllOps
()[
forward_len
];
EXPECT_EQ
(
fill_op
->
Type
(),
"fill_constant"
);
EXPECT_EQ
(
fill_op
->
Type
(),
"fill_constant"
);
auto
ops
=
block
->
AllOps
();
auto
ops
=
block
->
AllOps
();
ASSERT_EQ
(
3UL
,
ops
.
size
());
ASSERT_EQ
(
3UL
,
ops
.
size
());
...
...
paddle/framework/block_desc.cc
浏览文件 @
09189732
...
@@ -19,18 +19,18 @@ limitations under the License. */
...
@@ -19,18 +19,18 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
VarDesc
Bind
*
BlockDescBind
::
Var
(
const
std
::
string
&
name
)
{
VarDesc
*
BlockDesc
::
Var
(
const
std
::
string
&
name
)
{
auto
it
=
vars_
.
find
(
name
);
auto
it
=
vars_
.
find
(
name
);
if
(
it
!=
vars_
.
end
())
{
if
(
it
!=
vars_
.
end
())
{
return
it
->
second
.
get
();
return
it
->
second
.
get
();
}
}
need_update_
=
true
;
need_update_
=
true
;
auto
*
var
=
new
VarDesc
Bind
(
name
);
auto
*
var
=
new
VarDesc
(
name
);
vars_
[
name
].
reset
(
var
);
vars_
[
name
].
reset
(
var
);
return
var
;
return
var
;
}
}
VarDesc
Bind
*
BlockDescBind
::
FindVar
(
const
std
::
string
&
name
)
const
{
VarDesc
*
BlockDesc
::
FindVar
(
const
std
::
string
&
name
)
const
{
auto
it
=
vars_
.
find
(
name
);
auto
it
=
vars_
.
find
(
name
);
if
(
it
==
vars_
.
end
())
{
if
(
it
==
vars_
.
end
())
{
return
nullptr
;
return
nullptr
;
...
@@ -38,11 +38,11 @@ VarDescBind *BlockDescBind::FindVar(const std::string &name) const {
...
@@ -38,11 +38,11 @@ VarDescBind *BlockDescBind::FindVar(const std::string &name) const {
return
it
->
second
.
get
();
return
it
->
second
.
get
();
}
}
bool
BlockDesc
Bind
::
HasVar
(
const
std
::
string
&
name
)
const
{
bool
BlockDesc
::
HasVar
(
const
std
::
string
&
name
)
const
{
return
vars_
.
find
(
name
)
!=
vars_
.
end
();
return
vars_
.
find
(
name
)
!=
vars_
.
end
();
}
}
VarDesc
Bind
*
BlockDescBind
::
FindVarRecursive
(
const
std
::
string
&
name
)
const
{
VarDesc
*
BlockDesc
::
FindVarRecursive
(
const
std
::
string
&
name
)
const
{
if
(
name
==
kEmptyVarName
)
return
nullptr
;
if
(
name
==
kEmptyVarName
)
return
nullptr
;
auto
it
=
vars_
.
find
(
name
);
auto
it
=
vars_
.
find
(
name
);
...
@@ -53,53 +53,52 @@ VarDescBind *BlockDescBind::FindVarRecursive(const std::string &name) const {
...
@@ -53,53 +53,52 @@ VarDescBind *BlockDescBind::FindVarRecursive(const std::string &name) const {
return
it
->
second
.
get
();
return
it
->
second
.
get
();
}
}
VarDescBind
*
BlockDescBind
::
FindRecursiveOrCreateVar
(
VarDesc
*
BlockDesc
::
FindRecursiveOrCreateVar
(
const
std
::
string
&
name_bytes
)
{
const
std
::
string
&
name_bytes
)
{
VarDesc
*
res
=
FindVarRecursive
(
name_bytes
);
VarDescBind
*
res
=
FindVarRecursive
(
name_bytes
);
if
(
res
==
nullptr
)
{
if
(
res
==
nullptr
)
{
res
=
Var
(
name_bytes
);
res
=
Var
(
name_bytes
);
}
}
return
res
;
return
res
;
}
}
bool
BlockDesc
Bind
::
HasVarRecursive
(
const
std
::
string
&
name
)
const
{
bool
BlockDesc
::
HasVarRecursive
(
const
std
::
string
&
name
)
const
{
return
FindVarRecursive
(
name
)
!=
nullptr
;
return
FindVarRecursive
(
name
)
!=
nullptr
;
}
}
std
::
vector
<
VarDesc
Bind
*>
BlockDescBind
::
AllVars
()
const
{
std
::
vector
<
VarDesc
*>
BlockDesc
::
AllVars
()
const
{
std
::
vector
<
VarDesc
Bind
*>
res
;
std
::
vector
<
VarDesc
*>
res
;
for
(
const
auto
&
p
:
vars_
)
{
for
(
const
auto
&
p
:
vars_
)
{
res
.
push_back
(
p
.
second
.
get
());
res
.
push_back
(
p
.
second
.
get
());
}
}
return
res
;
return
res
;
}
}
OpDesc
Bind
*
BlockDescBind
::
AppendOp
()
{
OpDesc
*
BlockDesc
::
AppendOp
()
{
need_update_
=
true
;
need_update_
=
true
;
ops_
.
emplace_back
(
new
OpDesc
Bind
());
ops_
.
emplace_back
(
new
OpDesc
());
return
ops_
.
back
().
get
();
return
ops_
.
back
().
get
();
}
}
void
BlockDesc
Bind
::
AppendAllocatedOp
(
std
::
unique_ptr
<
OpDescBind
>
&&
op_desc
)
{
void
BlockDesc
::
AppendAllocatedOp
(
std
::
unique_ptr
<
OpDesc
>
&&
op_desc
)
{
need_update_
=
true
;
need_update_
=
true
;
ops_
.
emplace_back
(
std
::
move
(
op_desc
));
ops_
.
emplace_back
(
std
::
move
(
op_desc
));
}
}
OpDesc
Bind
*
BlockDescBind
::
PrependOp
()
{
OpDesc
*
BlockDesc
::
PrependOp
()
{
need_update_
=
true
;
need_update_
=
true
;
ops_
.
emplace_front
(
new
OpDesc
Bind
());
ops_
.
emplace_front
(
new
OpDesc
());
return
ops_
.
front
().
get
();
return
ops_
.
front
().
get
();
}
}
std
::
vector
<
OpDesc
Bind
*>
BlockDescBind
::
AllOps
()
const
{
std
::
vector
<
OpDesc
*>
BlockDesc
::
AllOps
()
const
{
std
::
vector
<
OpDesc
Bind
*>
res
;
std
::
vector
<
OpDesc
*>
res
;
for
(
const
auto
&
op
:
ops_
)
{
for
(
const
auto
&
op
:
ops_
)
{
res
.
push_back
(
op
.
get
());
res
.
push_back
(
op
.
get
());
}
}
return
res
;
return
res
;
}
}
void
BlockDesc
Bind
::
Flush
()
{
void
BlockDesc
::
Flush
()
{
for
(
auto
&
op_desc
:
ops_
)
{
for
(
auto
&
op_desc
:
ops_
)
{
op_desc
->
Flush
();
op_desc
->
Flush
();
}
}
...
@@ -121,43 +120,43 @@ void BlockDescBind::Flush() {
...
@@ -121,43 +120,43 @@ void BlockDescBind::Flush() {
}
}
}
}
BlockDesc
Bind
*
BlockDescBind
::
ParentBlock
()
const
{
BlockDesc
*
BlockDesc
::
ParentBlock
()
const
{
if
(
this
->
desc_
->
parent_idx
()
==
kNoneBlockIndex
)
{
if
(
this
->
desc_
->
parent_idx
()
==
kNoneBlockIndex
)
{
return
nullptr
;
return
nullptr
;
}
}
return
prog_
->
MutableBlock
(
static_cast
<
size_t
>
(
this
->
desc_
->
parent_idx
()));
return
prog_
->
MutableBlock
(
static_cast
<
size_t
>
(
this
->
desc_
->
parent_idx
()));
}
}
proto
::
BlockDesc
*
BlockDesc
Bind
::
Proto
()
{
proto
::
BlockDesc
*
BlockDesc
::
Proto
()
{
Flush
();
Flush
();
return
desc_
;
return
desc_
;
}
}
BlockDesc
Bind
::
BlockDescBind
(
ProgramDescBind
*
prog
,
proto
::
BlockDesc
*
desc
)
BlockDesc
::
BlockDesc
(
ProgramDesc
*
prog
,
proto
::
BlockDesc
*
desc
)
:
prog_
(
prog
),
desc_
(
desc
),
need_update_
(
false
)
{
:
prog_
(
prog
),
desc_
(
desc
),
need_update_
(
false
)
{
for
(
const
proto
::
VarDesc
&
var_desc
:
desc_
->
vars
())
{
for
(
const
proto
::
VarDesc
&
var_desc
:
desc_
->
vars
())
{
vars_
[
var_desc
.
name
()].
reset
(
new
VarDesc
Bind
(
var_desc
));
vars_
[
var_desc
.
name
()].
reset
(
new
VarDesc
(
var_desc
));
}
}
for
(
const
proto
::
OpDesc
&
op_desc
:
desc_
->
ops
())
{
for
(
const
proto
::
OpDesc
&
op_desc
:
desc_
->
ops
())
{
ops_
.
emplace_back
(
new
OpDesc
Bind
(
op_desc
,
prog
));
ops_
.
emplace_back
(
new
OpDesc
(
op_desc
,
prog
));
}
}
}
}
BlockDesc
Bind
::
BlockDescBind
(
const
BlockDescBind
&
other
,
proto
::
BlockDesc
*
desc
,
BlockDesc
::
BlockDesc
(
const
BlockDesc
&
other
,
proto
::
BlockDesc
*
desc
,
ProgramDescBind
*
prog
)
ProgramDesc
*
prog
)
:
prog_
(
prog
),
desc_
(
desc
)
{
:
prog_
(
prog
),
desc_
(
desc
)
{
need_update_
=
true
;
need_update_
=
true
;
for
(
auto
&
op
:
other
.
ops_
)
{
for
(
auto
&
op
:
other
.
ops_
)
{
ops_
.
emplace_back
(
new
OpDesc
Bind
(
*
op
));
ops_
.
emplace_back
(
new
OpDesc
(
*
op
));
}
}
for
(
auto
&
it
:
other
.
vars_
)
{
for
(
auto
&
it
:
other
.
vars_
)
{
auto
*
var
=
new
VarDesc
Bind
(
*
it
.
second
);
auto
*
var
=
new
VarDesc
(
*
it
.
second
);
vars_
[
it
.
first
].
reset
(
var
);
vars_
[
it
.
first
].
reset
(
var
);
}
}
}
}
void
BlockDesc
Bind
::
ClearPBOps
()
{
void
BlockDesc
::
ClearPBOps
()
{
auto
ops
=
this
->
desc_
->
mutable_ops
();
auto
ops
=
this
->
desc_
->
mutable_ops
();
while
(
!
ops
->
empty
())
{
while
(
!
ops
->
empty
())
{
// we do not own the OpDesc, so release the ownership.
// we do not own the OpDesc, so release the ownership.
...
@@ -165,7 +164,7 @@ void BlockDescBind::ClearPBOps() {
...
@@ -165,7 +164,7 @@ void BlockDescBind::ClearPBOps() {
}
}
}
}
void
BlockDesc
Bind
::
ClearPBVars
()
{
void
BlockDesc
::
ClearPBVars
()
{
auto
vars
=
this
->
desc_
->
mutable_vars
();
auto
vars
=
this
->
desc_
->
mutable_vars
();
while
(
!
vars
->
empty
())
{
while
(
!
vars
->
empty
())
{
// we do not own the VarDesc, so release the ownership.
// we do not own the VarDesc, so release the ownership.
...
...
paddle/framework/block_desc.h
浏览文件 @
09189732
...
@@ -28,20 +28,19 @@ limitations under the License. */
...
@@ -28,20 +28,19 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
ProgramDesc
Bind
;
class
ProgramDesc
;
// Each Protobuf Message, we provide a XXXBind class. In that class, we optimize
// Each Protobuf Message, we provide a XXXBind class. In that class, we optimize
// read/write speed. Only when we want the protobuf message, the local changes
// read/write speed. Only when we want the protobuf message, the local changes
// will be synchronized (by `Sync` method).
// will be synchronized (by `Sync` method).
class
BlockDesc
Bind
{
class
BlockDesc
{
public:
public:
BlockDesc
Bind
(
ProgramDescBind
*
prog
,
proto
::
BlockDesc
*
desc
);
BlockDesc
(
ProgramDesc
*
prog
,
proto
::
BlockDesc
*
desc
);
BlockDescBind
(
const
BlockDescBind
&
other
,
proto
::
BlockDesc
*
desc
,
BlockDesc
(
const
BlockDesc
&
other
,
proto
::
BlockDesc
*
desc
,
ProgramDesc
*
prog
);
ProgramDescBind
*
prog
);
~
BlockDesc
Bind
()
{
~
BlockDesc
()
{
this
->
ClearPBVars
();
this
->
ClearPBVars
();
this
->
ClearPBOps
();
this
->
ClearPBOps
();
}
}
...
@@ -50,15 +49,15 @@ class BlockDescBind {
...
@@ -50,15 +49,15 @@ class BlockDescBind {
int32_t
Parent
()
const
{
return
desc_
->
parent_idx
();
}
int32_t
Parent
()
const
{
return
desc_
->
parent_idx
();
}
VarDesc
Bind
*
Var
(
const
std
::
string
&
name_bytes
);
VarDesc
*
Var
(
const
std
::
string
&
name_bytes
);
VarDesc
Bind
*
FindVar
(
const
std
::
string
&
name_bytes
)
const
;
VarDesc
*
FindVar
(
const
std
::
string
&
name_bytes
)
const
;
bool
HasVar
(
const
std
::
string
&
var_name
)
const
;
bool
HasVar
(
const
std
::
string
&
var_name
)
const
;
VarDesc
Bind
*
FindVarRecursive
(
const
std
::
string
&
name_bytes
)
const
;
VarDesc
*
FindVarRecursive
(
const
std
::
string
&
name_bytes
)
const
;
VarDesc
Bind
*
FindRecursiveOrCreateVar
(
const
std
::
string
&
name_bytes
);
VarDesc
*
FindRecursiveOrCreateVar
(
const
std
::
string
&
name_bytes
);
bool
HasVarRecursive
(
const
std
::
string
&
var_name
)
const
;
bool
HasVarRecursive
(
const
std
::
string
&
var_name
)
const
;
...
@@ -70,41 +69,41 @@ class BlockDescBind {
...
@@ -70,41 +69,41 @@ class BlockDescBind {
return
var_names
;
return
var_names
;
}
}
std
::
vector
<
VarDesc
Bind
*>
AllVars
()
const
;
std
::
vector
<
VarDesc
*>
AllVars
()
const
;
BlockDesc
Bind
*
ParentBlock
()
const
;
BlockDesc
*
ParentBlock
()
const
;
OpDesc
Bind
*
AppendOp
();
OpDesc
*
AppendOp
();
void
AppendAllocatedOp
(
std
::
unique_ptr
<
OpDesc
Bind
>
&&
op_desc
);
void
AppendAllocatedOp
(
std
::
unique_ptr
<
OpDesc
>
&&
op_desc
);
OpDesc
Bind
*
PrependOp
();
OpDesc
*
PrependOp
();
std
::
vector
<
OpDesc
Bind
*>
AllOps
()
const
;
std
::
vector
<
OpDesc
*>
AllOps
()
const
;
size_t
OpSize
()
const
{
return
ops_
.
size
();
}
size_t
OpSize
()
const
{
return
ops_
.
size
();
}
OpDesc
Bind
*
Op
(
int
idx
)
{
return
ops_
.
at
(
idx
).
get
();
}
OpDesc
*
Op
(
int
idx
)
{
return
ops_
.
at
(
idx
).
get
();
}
void
Flush
();
void
Flush
();
proto
::
BlockDesc
*
Proto
();
proto
::
BlockDesc
*
Proto
();
ProgramDesc
Bind
*
Program
()
{
return
this
->
prog_
;
}
ProgramDesc
*
Program
()
{
return
this
->
prog_
;
}
private:
private:
void
ClearPBOps
();
void
ClearPBOps
();
void
ClearPBVars
();
void
ClearPBVars
();
private:
private:
ProgramDesc
Bind
*
prog_
;
// not_own
ProgramDesc
*
prog_
;
// not_own
proto
::
BlockDesc
*
desc_
;
// not_own
proto
::
BlockDesc
*
desc_
;
// not_own
bool
need_update_
;
bool
need_update_
;
std
::
deque
<
std
::
unique_ptr
<
OpDesc
Bind
>>
ops_
;
std
::
deque
<
std
::
unique_ptr
<
OpDesc
>>
ops_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
VarDesc
Bind
>>
vars_
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
VarDesc
>>
vars_
;
DISABLE_COPY_AND_ASSIGN
(
BlockDesc
Bind
);
DISABLE_COPY_AND_ASSIGN
(
BlockDesc
);
};
};
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/details/op_registry.h
浏览文件 @
09189732
...
@@ -106,10 +106,10 @@ template <typename T>
...
@@ -106,10 +106,10 @@ template <typename T>
struct
OpInfoFiller
<
T
,
kGradOpDescMaker
>
{
struct
OpInfoFiller
<
T
,
kGradOpDescMaker
>
{
void
operator
()(
const
char
*
op_type
,
OpInfo
*
info
)
const
{
void
operator
()(
const
char
*
op_type
,
OpInfo
*
info
)
const
{
info
->
grad_op_maker_
=
[](
info
->
grad_op_maker_
=
[](
const
OpDesc
Bind
&
fwd_op
,
const
OpDesc
&
fwd_op
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
,
const
std
::
vector
<
BlockDesc
Bind
*>&
grad_block
)
{
const
std
::
vector
<
BlockDesc
*>&
grad_block
)
{
T
maker
(
fwd_op
,
no_grad_set
,
grad_to_var
,
grad_block
);
T
maker
(
fwd_op
,
no_grad_set
,
grad_to_var
,
grad_block
);
return
maker
();
return
maker
();
};
};
...
@@ -119,7 +119,7 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
...
@@ -119,7 +119,7 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
template
<
typename
T
>
template
<
typename
T
>
struct
OpInfoFiller
<
T
,
kVarTypeInference
>
{
struct
OpInfoFiller
<
T
,
kVarTypeInference
>
{
void
operator
()(
const
char
*
op_type
,
OpInfo
*
info
)
const
{
void
operator
()(
const
char
*
op_type
,
OpInfo
*
info
)
const
{
info
->
infer_var_type_
=
[](
const
OpDesc
Bind
&
fwd_op
,
BlockDescBind
*
block
)
{
info
->
infer_var_type_
=
[](
const
OpDesc
&
fwd_op
,
BlockDesc
*
block
)
{
T
inference
;
T
inference
;
inference
(
fwd_op
,
block
);
inference
(
fwd_op
,
block
);
};
};
...
...
paddle/framework/executor.cc
浏览文件 @
09189732
...
@@ -64,7 +64,7 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
...
@@ -64,7 +64,7 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
}
}
}
}
void
Executor
::
Run
(
const
ProgramDesc
Bind
&
pdesc
,
Scope
*
scope
,
int
block_id
,
void
Executor
::
Run
(
const
ProgramDesc
&
pdesc
,
Scope
*
scope
,
int
block_id
,
bool
create_local_scope
)
{
bool
create_local_scope
)
{
// TODO(tonyyang-svail):
// TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication)
// - only runs on the first device (i.e. no interdevice communication)
...
...
paddle/framework/executor.h
浏览文件 @
09189732
...
@@ -114,7 +114,7 @@ class Executor {
...
@@ -114,7 +114,7 @@ class Executor {
* ProgramDesc
* ProgramDesc
* Scope
* Scope
*/
*/
void
Run
(
const
ProgramDesc
Bind
&
,
Scope
*
,
int
,
bool
create_local_scope
=
true
);
void
Run
(
const
ProgramDesc
&
,
Scope
*
,
int
,
bool
create_local_scope
=
true
);
private:
private:
std
::
vector
<
const
platform
::
DeviceContext
*>
device_contexts_
;
std
::
vector
<
const
platform
::
DeviceContext
*>
device_contexts_
;
...
...
paddle/framework/grad_op_desc_maker.h
浏览文件 @
09189732
...
@@ -25,18 +25,16 @@ namespace framework {
...
@@ -25,18 +25,16 @@ namespace framework {
class
GradOpDescMakerBase
{
class
GradOpDescMakerBase
{
public:
public:
explicit
GradOpDescMakerBase
(
explicit
GradOpDescMakerBase
(
const
OpDescBind
&
fwd_op
,
const
OpDesc
&
fwd_op
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var
,
const
std
::
vector
<
BlockDescBind
*>&
grad_block
=
const
std
::
vector
<
BlockDesc
*>&
grad_block
=
std
::
vector
<
BlockDesc
*>
())
std
::
vector
<
BlockDescBind
*>
())
:
fwd_op_
(
fwd_op
),
:
fwd_op_
(
fwd_op
),
no_grad_set_
(
no_grad_set
),
no_grad_set_
(
no_grad_set
),
grad_to_var_
(
grad_to_var
),
grad_to_var_
(
grad_to_var
),
grad_block_
(
grad_block
)
{}
grad_block_
(
grad_block
)
{}
virtual
~
GradOpDescMakerBase
()
=
default
;
virtual
~
GradOpDescMakerBase
()
=
default
;
virtual
std
::
vector
<
std
::
unique_ptr
<
OpDesc
Bind
>>
operator
()()
const
=
0
;
virtual
std
::
vector
<
std
::
unique_ptr
<
OpDesc
>>
operator
()()
const
=
0
;
protected:
protected:
std
::
vector
<
std
::
string
>
InputGrad
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
string
>
InputGrad
(
const
std
::
string
&
name
,
...
@@ -105,26 +103,26 @@ class GradOpDescMakerBase {
...
@@ -105,26 +103,26 @@ class GradOpDescMakerBase {
std
::
string
ForwardOpType
()
const
{
return
this
->
fwd_op_
.
Type
();
}
std
::
string
ForwardOpType
()
const
{
return
this
->
fwd_op_
.
Type
();
}
private:
private:
const
OpDesc
Bind
&
fwd_op_
;
const
OpDesc
&
fwd_op_
;
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set_
;
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set_
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var_
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
grad_to_var_
;
protected:
protected:
std
::
vector
<
BlockDesc
Bind
*>
grad_block_
;
std
::
vector
<
BlockDesc
*>
grad_block_
;
};
};
class
SingleGradOpDescMaker
:
public
GradOpDescMakerBase
{
class
SingleGradOpDescMaker
:
public
GradOpDescMakerBase
{
public:
public:
using
GradOpDescMakerBase
::
GradOpDescMakerBase
;
using
GradOpDescMakerBase
::
GradOpDescMakerBase
;
std
::
vector
<
std
::
unique_ptr
<
OpDesc
Bind
>>
operator
()()
const
{
std
::
vector
<
std
::
unique_ptr
<
OpDesc
>>
operator
()()
const
{
std
::
vector
<
std
::
unique_ptr
<
OpDesc
Bind
>>
retv
;
std
::
vector
<
std
::
unique_ptr
<
OpDesc
>>
retv
;
retv
.
emplace_back
(
this
->
Apply
());
retv
.
emplace_back
(
this
->
Apply
());
return
retv
;
return
retv
;
}
}
protected:
protected:
virtual
std
::
unique_ptr
<
OpDesc
Bind
>
Apply
()
const
=
0
;
virtual
std
::
unique_ptr
<
OpDesc
>
Apply
()
const
=
0
;
};
};
template
<
bool
DropEmptyIG
=
true
>
template
<
bool
DropEmptyIG
=
true
>
...
@@ -133,8 +131,8 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
...
@@ -133,8 +131,8 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
using
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
virtual
std
::
unique_ptr
<
OpDesc
Bind
>
Apply
()
const
{
virtual
std
::
unique_ptr
<
OpDesc
>
Apply
()
const
{
auto
*
grad
=
new
OpDesc
Bind
();
auto
*
grad
=
new
OpDesc
();
grad
->
SetType
(
this
->
GradOpType
());
grad
->
SetType
(
this
->
GradOpType
());
for
(
auto
&
input_param
:
this
->
InputNames
())
{
for
(
auto
&
input_param
:
this
->
InputNames
())
{
...
@@ -150,7 +148,7 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
...
@@ -150,7 +148,7 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
grad
->
SetAttrMap
(
this
->
Attrs
());
grad
->
SetAttrMap
(
this
->
Attrs
());
return
std
::
unique_ptr
<
OpDesc
Bind
>
(
grad
);
return
std
::
unique_ptr
<
OpDesc
>
(
grad
);
}
}
virtual
std
::
string
GradOpType
()
const
{
virtual
std
::
string
GradOpType
()
const
{
...
@@ -161,7 +159,7 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
...
@@ -161,7 +159,7 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
class
EmptyGradOpMaker
:
public
GradOpDescMakerBase
{
class
EmptyGradOpMaker
:
public
GradOpDescMakerBase
{
public:
public:
using
GradOpDescMakerBase
::
GradOpDescMakerBase
;
using
GradOpDescMakerBase
::
GradOpDescMakerBase
;
std
::
vector
<
std
::
unique_ptr
<
OpDesc
Bind
>>
operator
()()
const
override
{
std
::
vector
<
std
::
unique_ptr
<
OpDesc
>>
operator
()()
const
override
{
return
{};
return
{};
}
}
};
};
...
...
paddle/framework/op_desc.cc
浏览文件 @
09189732
...
@@ -25,12 +25,11 @@ limitations under the License. */
...
@@ -25,12 +25,11 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
OpDesc
Bind
;
class
OpDesc
;
class
BlockDesc
Bind
;
class
BlockDesc
;
class
CompileTimeInferShapeContext
:
public
InferShapeContext
{
class
CompileTimeInferShapeContext
:
public
InferShapeContext
{
public:
public:
CompileTimeInferShapeContext
(
const
OpDescBind
&
op
,
CompileTimeInferShapeContext
(
const
OpDesc
&
op
,
const
BlockDesc
&
block
);
const
BlockDescBind
&
block
);
bool
HasInput
(
const
std
::
string
&
name
)
const
override
;
bool
HasInput
(
const
std
::
string
&
name
)
const
override
;
...
@@ -76,13 +75,12 @@ class CompileTimeInferShapeContext : public InferShapeContext {
...
@@ -76,13 +75,12 @@ class CompileTimeInferShapeContext : public InferShapeContext {
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
;
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
;
const
OpDesc
Bind
&
op_
;
const
OpDesc
&
op_
;
const
BlockDesc
Bind
&
block_
;
const
BlockDesc
&
block_
;
};
};
OpDescBind
::
OpDescBind
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
OpDesc
::
OpDesc
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
{
const
AttributeMap
&
attrs
)
{
desc_
.
set_type
(
type
);
desc_
.
set_type
(
type
);
inputs_
=
inputs
;
inputs_
=
inputs
;
outputs_
=
outputs
;
outputs_
=
outputs
;
...
@@ -90,7 +88,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
...
@@ -90,7 +88,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
need_update_
=
true
;
need_update_
=
true
;
}
}
OpDesc
Bind
::
OpDescBind
(
const
proto
::
OpDesc
&
desc
,
ProgramDescBind
*
prog
)
OpDesc
::
OpDesc
(
const
proto
::
OpDesc
&
desc
,
ProgramDesc
*
prog
)
:
desc_
(
desc
),
need_update_
(
false
)
{
:
desc_
(
desc
),
need_update_
(
false
)
{
// restore inputs_
// restore inputs_
int
input_size
=
desc_
.
inputs_size
();
int
input_size
=
desc_
.
inputs_size
();
...
@@ -126,20 +124,19 @@ OpDescBind::OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog)
...
@@ -126,20 +124,19 @@ OpDescBind::OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog)
}
}
}
}
proto
::
OpDesc
*
OpDesc
Bind
::
Proto
()
{
proto
::
OpDesc
*
OpDesc
::
Proto
()
{
Flush
();
Flush
();
return
&
desc_
;
return
&
desc_
;
}
}
const
std
::
vector
<
std
::
string
>
&
OpDescBind
::
Input
(
const
std
::
vector
<
std
::
string
>
&
OpDesc
::
Input
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
auto
it
=
inputs_
.
find
(
name
);
auto
it
=
inputs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
inputs_
.
end
(),
"Input %s cannot be found in Op %s"
,
name
,
PADDLE_ENFORCE
(
it
!=
inputs_
.
end
(),
"Input %s cannot be found in Op %s"
,
name
,
Type
());
Type
());
return
it
->
second
;
return
it
->
second
;
}
}
std
::
vector
<
std
::
string
>
OpDesc
Bind
::
InputArgumentNames
()
const
{
std
::
vector
<
std
::
string
>
OpDesc
::
InputArgumentNames
()
const
{
std
::
vector
<
std
::
string
>
retv
;
std
::
vector
<
std
::
string
>
retv
;
for
(
auto
&
ipt
:
this
->
inputs_
)
{
for
(
auto
&
ipt
:
this
->
inputs_
)
{
retv
.
insert
(
retv
.
end
(),
ipt
.
second
.
begin
(),
ipt
.
second
.
end
());
retv
.
insert
(
retv
.
end
(),
ipt
.
second
.
begin
(),
ipt
.
second
.
end
());
...
@@ -147,21 +144,20 @@ std::vector<std::string> OpDescBind::InputArgumentNames() const {
...
@@ -147,21 +144,20 @@ std::vector<std::string> OpDescBind::InputArgumentNames() const {
return
retv
;
return
retv
;
}
}
void
OpDesc
Bind
::
SetInput
(
const
std
::
string
&
param_name
,
void
OpDesc
::
SetInput
(
const
std
::
string
&
param_name
,
const
std
::
vector
<
std
::
string
>
&
args
)
{
const
std
::
vector
<
std
::
string
>
&
args
)
{
need_update_
=
true
;
need_update_
=
true
;
inputs_
[
param_name
]
=
args
;
inputs_
[
param_name
]
=
args
;
}
}
const
std
::
vector
<
std
::
string
>
&
OpDescBind
::
Output
(
const
std
::
vector
<
std
::
string
>
&
OpDesc
::
Output
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
auto
it
=
outputs_
.
find
(
name
);
auto
it
=
outputs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
outputs_
.
end
(),
"Output %s cannot be found in Op %s"
,
PADDLE_ENFORCE
(
it
!=
outputs_
.
end
(),
"Output %s cannot be found in Op %s"
,
name
,
Type
());
name
,
Type
());
return
it
->
second
;
return
it
->
second
;
}
}
std
::
vector
<
std
::
string
>
OpDesc
Bind
::
OutputArgumentNames
()
const
{
std
::
vector
<
std
::
string
>
OpDesc
::
OutputArgumentNames
()
const
{
std
::
vector
<
std
::
string
>
retv
;
std
::
vector
<
std
::
string
>
retv
;
for
(
auto
&
ipt
:
this
->
outputs_
)
{
for
(
auto
&
ipt
:
this
->
outputs_
)
{
retv
.
insert
(
retv
.
end
(),
ipt
.
second
.
begin
(),
ipt
.
second
.
end
());
retv
.
insert
(
retv
.
end
(),
ipt
.
second
.
begin
(),
ipt
.
second
.
end
());
...
@@ -169,19 +165,19 @@ std::vector<std::string> OpDescBind::OutputArgumentNames() const {
...
@@ -169,19 +165,19 @@ std::vector<std::string> OpDescBind::OutputArgumentNames() const {
return
retv
;
return
retv
;
}
}
void
OpDesc
Bind
::
SetOutput
(
const
std
::
string
&
param_name
,
void
OpDesc
::
SetOutput
(
const
std
::
string
&
param_name
,
const
std
::
vector
<
std
::
string
>
&
args
)
{
const
std
::
vector
<
std
::
string
>
&
args
)
{
need_update_
=
true
;
need_update_
=
true
;
this
->
outputs_
[
param_name
]
=
args
;
this
->
outputs_
[
param_name
]
=
args
;
}
}
proto
::
AttrType
OpDesc
Bind
::
GetAttrType
(
const
std
::
string
&
name
)
const
{
proto
::
AttrType
OpDesc
::
GetAttrType
(
const
std
::
string
&
name
)
const
{
auto
it
=
attrs_
.
find
(
name
);
auto
it
=
attrs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
attrs_
.
end
(),
"Attribute %s is not found"
,
name
);
PADDLE_ENFORCE
(
it
!=
attrs_
.
end
(),
"Attribute %s is not found"
,
name
);
return
static_cast
<
proto
::
AttrType
>
(
it
->
second
.
which
()
-
1
);
return
static_cast
<
proto
::
AttrType
>
(
it
->
second
.
which
()
-
1
);
}
}
std
::
vector
<
std
::
string
>
OpDesc
Bind
::
AttrNames
()
const
{
std
::
vector
<
std
::
string
>
OpDesc
::
AttrNames
()
const
{
std
::
vector
<
std
::
string
>
retv
;
std
::
vector
<
std
::
string
>
retv
;
retv
.
reserve
(
attrs_
.
size
());
retv
.
reserve
(
attrs_
.
size
());
for
(
auto
&
attr
:
attrs_
)
{
for
(
auto
&
attr
:
attrs_
)
{
...
@@ -190,41 +186,39 @@ std::vector<std::string> OpDescBind::AttrNames() const {
...
@@ -190,41 +186,39 @@ std::vector<std::string> OpDescBind::AttrNames() const {
return
retv
;
return
retv
;
}
}
void
OpDesc
Bind
::
SetAttr
(
const
std
::
string
&
name
,
const
Attribute
&
v
)
{
void
OpDesc
::
SetAttr
(
const
std
::
string
&
name
,
const
Attribute
&
v
)
{
this
->
attrs_
[
name
]
=
v
;
this
->
attrs_
[
name
]
=
v
;
need_update_
=
true
;
need_update_
=
true
;
}
}
void
OpDesc
Bind
::
SetBlockAttr
(
const
std
::
string
&
name
,
BlockDescBind
&
block
)
{
void
OpDesc
::
SetBlockAttr
(
const
std
::
string
&
name
,
BlockDesc
&
block
)
{
this
->
attrs_
[
name
]
=
&
block
;
this
->
attrs_
[
name
]
=
&
block
;
need_update_
=
true
;
need_update_
=
true
;
}
}
void
OpDesc
Bind
::
SetAttrMap
(
void
OpDesc
::
SetAttrMap
(
const
std
::
unordered_map
<
std
::
string
,
Attribute
>
&
attr_map
)
{
const
std
::
unordered_map
<
std
::
string
,
Attribute
>
&
attr_map
)
{
attrs_
=
attr_map
;
attrs_
=
attr_map
;
need_update_
=
true
;
need_update_
=
true
;
}
}
Attribute
OpDesc
Bind
::
GetAttr
(
const
std
::
string
&
name
)
const
{
Attribute
OpDesc
::
GetAttr
(
const
std
::
string
&
name
)
const
{
auto
it
=
attrs_
.
find
(
name
);
auto
it
=
attrs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
attrs_
.
end
(),
"Attribute %s is not found"
,
name
);
PADDLE_ENFORCE
(
it
!=
attrs_
.
end
(),
"Attribute %s is not found"
,
name
);
return
it
->
second
;
return
it
->
second
;
}
}
int
OpDesc
Bind
::
GetBlockAttr
(
const
std
::
string
&
name
)
const
{
int
OpDesc
::
GetBlockAttr
(
const
std
::
string
&
name
)
const
{
auto
it
=
attrs_
.
find
(
name
);
auto
it
=
attrs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
attrs_
.
end
(),
"Attribute %s is not found"
,
name
);
PADDLE_ENFORCE
(
it
!=
attrs_
.
end
(),
"Attribute %s is not found"
,
name
);
return
boost
::
get
<
BlockDesc
Bind
*>
(
it
->
second
)
->
ID
();
return
boost
::
get
<
BlockDesc
*>
(
it
->
second
)
->
ID
();
}
}
const
std
::
unordered_map
<
std
::
string
,
Attribute
>
&
OpDescBind
::
GetAttrMap
()
const
std
::
unordered_map
<
std
::
string
,
Attribute
>
&
OpDesc
::
GetAttrMap
()
const
{
const
{
return
attrs_
;
return
attrs_
;
}
}
void
OpDescBind
::
Rename
(
const
std
::
string
&
old_name
,
void
OpDesc
::
Rename
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
)
{
const
std
::
string
&
new_name
)
{
for
(
auto
&
input
:
inputs_
)
{
for
(
auto
&
input
:
inputs_
)
{
std
::
replace
(
input
.
second
.
begin
(),
input
.
second
.
end
(),
old_name
,
new_name
);
std
::
replace
(
input
.
second
.
begin
(),
input
.
second
.
end
(),
old_name
,
new_name
);
}
}
...
@@ -235,7 +229,7 @@ void OpDescBind::Rename(const std::string &old_name,
...
@@ -235,7 +229,7 @@ void OpDescBind::Rename(const std::string &old_name,
need_update_
=
true
;
need_update_
=
true
;
}
}
void
OpDesc
Bind
::
RenameOutput
(
const
std
::
string
&
old_name
,
void
OpDesc
::
RenameOutput
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
)
{
const
std
::
string
&
new_name
)
{
for
(
auto
&
output
:
outputs_
)
{
for
(
auto
&
output
:
outputs_
)
{
std
::
replace
(
output
.
second
.
begin
(),
output
.
second
.
end
(),
old_name
,
std
::
replace
(
output
.
second
.
begin
(),
output
.
second
.
end
(),
old_name
,
...
@@ -244,7 +238,7 @@ void OpDescBind::RenameOutput(const std::string &old_name,
...
@@ -244,7 +238,7 @@ void OpDescBind::RenameOutput(const std::string &old_name,
need_update_
=
true
;
need_update_
=
true
;
}
}
void
OpDesc
Bind
::
RenameInput
(
const
std
::
string
&
old_name
,
void
OpDesc
::
RenameInput
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
)
{
const
std
::
string
&
new_name
)
{
for
(
auto
&
input
:
inputs_
)
{
for
(
auto
&
input
:
inputs_
)
{
std
::
replace
(
input
.
second
.
begin
(),
input
.
second
.
end
(),
old_name
,
new_name
);
std
::
replace
(
input
.
second
.
begin
(),
input
.
second
.
end
(),
old_name
,
new_name
);
...
@@ -278,7 +272,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
...
@@ -278,7 +272,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void
operator
()(
boost
::
blank
)
const
{
PADDLE_THROW
(
"Unexpected branch"
);
}
void
operator
()(
boost
::
blank
)
const
{
PADDLE_THROW
(
"Unexpected branch"
);
}
};
};
void
OpDesc
Bind
::
Flush
()
{
void
OpDesc
::
Flush
()
{
if
(
need_update_
)
{
if
(
need_update_
)
{
this
->
desc_
.
mutable_inputs
()
->
Clear
();
this
->
desc_
.
mutable_inputs
()
->
Clear
();
for
(
auto
&
ipt
:
inputs_
)
{
for
(
auto
&
ipt
:
inputs_
)
{
...
@@ -330,7 +324,7 @@ static void InitInferShapeFuncs() {
...
@@ -330,7 +324,7 @@ static void InitInferShapeFuncs() {
});
});
}
}
void
OpDesc
Bind
::
CheckAttrs
()
{
void
OpDesc
::
CheckAttrs
()
{
PADDLE_ENFORCE
(
!
Type
().
empty
(),
PADDLE_ENFORCE
(
!
Type
().
empty
(),
"CheckAttr() can not be called before type is setted."
);
"CheckAttr() can not be called before type is setted."
);
auto
*
checker
=
OpInfoMap
::
Instance
().
Get
(
Type
()).
Checker
();
auto
*
checker
=
OpInfoMap
::
Instance
().
Get
(
Type
()).
Checker
();
...
@@ -342,7 +336,7 @@ void OpDescBind::CheckAttrs() {
...
@@ -342,7 +336,7 @@ void OpDescBind::CheckAttrs() {
checker
->
Check
(
attrs_
);
checker
->
Check
(
attrs_
);
}
}
void
OpDesc
Bind
::
InferShape
(
const
BlockDescBind
&
block
)
const
{
void
OpDesc
::
InferShape
(
const
BlockDesc
&
block
)
const
{
VLOG
(
3
)
<<
"CompileTime infer shape on "
<<
Type
();
VLOG
(
3
)
<<
"CompileTime infer shape on "
<<
Type
();
InitInferShapeFuncs
();
InitInferShapeFuncs
();
auto
&
infer_shape
=
OpInfoMap
::
Instance
().
Get
(
this
->
Type
()).
infer_shape_
;
auto
&
infer_shape
=
OpInfoMap
::
Instance
().
Get
(
this
->
Type
()).
infer_shape_
;
...
@@ -365,7 +359,7 @@ void OpDescBind::InferShape(const BlockDescBind &block) const {
...
@@ -365,7 +359,7 @@ void OpDescBind::InferShape(const BlockDescBind &block) const {
infer_shape
(
&
ctx
);
infer_shape
(
&
ctx
);
}
}
void
OpDesc
Bind
::
InferVarType
(
BlockDescBind
*
block
)
const
{
void
OpDesc
::
InferVarType
(
BlockDesc
*
block
)
const
{
auto
&
info
=
OpInfoMap
::
Instance
().
Get
(
this
->
Type
());
auto
&
info
=
OpInfoMap
::
Instance
().
Get
(
this
->
Type
());
if
(
info
.
infer_var_type_
)
{
if
(
info
.
infer_var_type_
)
{
info
.
infer_var_type_
(
*
this
,
block
);
info
.
infer_var_type_
(
*
this
,
block
);
...
@@ -384,7 +378,7 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
...
@@ -384,7 +378,7 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
}
}
CompileTimeInferShapeContext
::
CompileTimeInferShapeContext
(
CompileTimeInferShapeContext
::
CompileTimeInferShapeContext
(
const
OpDesc
Bind
&
op
,
const
BlockDescBind
&
block
)
const
OpDesc
&
op
,
const
BlockDesc
&
block
)
:
op_
(
op
),
block_
(
block
)
{}
:
op_
(
op
),
block_
(
block
)
{}
bool
CompileTimeInferShapeContext
::
HasInput
(
const
std
::
string
&
name
)
const
{
bool
CompileTimeInferShapeContext
::
HasInput
(
const
std
::
string
&
name
)
const
{
...
...
paddle/framework/op_desc.h
浏览文件 @
09189732
...
@@ -23,17 +23,17 @@ limitations under the License. */
...
@@ -23,17 +23,17 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
BlockDesc
Bind
;
class
BlockDesc
;
class
ProgramDesc
Bind
;
class
ProgramDesc
;
class
OpDesc
Bind
{
class
OpDesc
{
public:
public:
OpDesc
Bind
()
{}
OpDesc
()
{}
OpDesc
Bind
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
OpDesc
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
);
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
);
OpDesc
Bind
(
const
proto
::
OpDesc
&
desc
,
ProgramDescBind
*
prog
);
OpDesc
(
const
proto
::
OpDesc
&
desc
,
ProgramDesc
*
prog
);
proto
::
OpDesc
*
Proto
();
proto
::
OpDesc
*
Proto
();
...
@@ -65,7 +65,7 @@ class OpDescBind {
...
@@ -65,7 +65,7 @@ class OpDescBind {
void
SetAttr
(
const
std
::
string
&
name
,
const
Attribute
&
v
);
void
SetAttr
(
const
std
::
string
&
name
,
const
Attribute
&
v
);
void
SetBlockAttr
(
const
std
::
string
&
name
,
BlockDesc
Bind
&
block
);
void
SetBlockAttr
(
const
std
::
string
&
name
,
BlockDesc
&
block
);
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
;
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
;
...
@@ -107,9 +107,9 @@ class OpDescBind {
...
@@ -107,9 +107,9 @@ class OpDescBind {
void
CheckAttrs
();
void
CheckAttrs
();
void
InferShape
(
const
BlockDesc
Bind
&
block
)
const
;
void
InferShape
(
const
BlockDesc
&
block
)
const
;
void
InferVarType
(
BlockDesc
Bind
*
block
)
const
;
void
InferVarType
(
BlockDesc
*
block
)
const
;
void
MarkAsTarget
()
{
desc_
.
set_is_target
(
true
);
}
void
MarkAsTarget
()
{
desc_
.
set_is_target
(
true
);
}
...
...
paddle/framework/op_registry.cc
浏览文件 @
09189732
...
@@ -47,7 +47,7 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
...
@@ -47,7 +47,7 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateOp
(
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateOp
(
const
proto
::
OpDesc
&
op_desc
)
{
const
proto
::
OpDesc
&
op_desc
)
{
VLOG
(
1
)
<<
"CreateOp directly from OpDesc is deprecated. It should only be"
VLOG
(
1
)
<<
"CreateOp directly from OpDesc is deprecated. It should only be"
"used in unit tests. Use CreateOp(const OpDesc
Bind
& op_desc) "
"used in unit tests. Use CreateOp(const OpDesc& op_desc) "
"instead."
;
"instead."
;
VariableNameMap
inputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
inputs
());
VariableNameMap
inputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
inputs
());
VariableNameMap
outputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
outputs
());
VariableNameMap
outputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
outputs
());
...
@@ -59,7 +59,7 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
...
@@ -59,7 +59,7 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
return
CreateOp
(
op_desc
.
type
(),
inputs
,
outputs
,
attrs
);
return
CreateOp
(
op_desc
.
type
(),
inputs
,
outputs
,
attrs
);
}
}
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateOp
(
const
OpDesc
Bind
&
op_desc
)
{
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateOp
(
const
OpDesc
&
op_desc
)
{
return
CreateOp
(
op_desc
.
Type
(),
op_desc
.
Inputs
(),
op_desc
.
Outputs
(),
return
CreateOp
(
op_desc
.
Type
(),
op_desc
.
Inputs
(),
op_desc
.
Outputs
(),
op_desc
.
GetAttrMap
());
op_desc
.
GetAttrMap
());
}
}
...
...
paddle/framework/op_registry.h
浏览文件 @
09189732
...
@@ -79,7 +79,7 @@ class OpRegistry {
...
@@ -79,7 +79,7 @@ class OpRegistry {
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
proto
::
OpDesc
&
op_desc
);
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
proto
::
OpDesc
&
op_desc
);
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
Bind
&
op_desc
);
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
);
};
};
template
<
typename
PlaceType
,
bool
at_end
,
size_t
I
,
typename
...
KernelType
>
template
<
typename
PlaceType
,
bool
at_end
,
size_t
I
,
typename
...
KernelType
>
...
...
paddle/framework/program_desc.cc
浏览文件 @
09189732
...
@@ -18,49 +18,49 @@ limitations under the License. */
...
@@ -18,49 +18,49 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
BlockDesc
Bind
*
ProgramDescBind
::
AppendBlock
(
const
BlockDescBind
&
parent
)
{
BlockDesc
*
ProgramDesc
::
AppendBlock
(
const
BlockDesc
&
parent
)
{
auto
*
b
=
desc_
.
add_blocks
();
auto
*
b
=
desc_
.
add_blocks
();
b
->
set_parent_idx
(
parent
.
ID
());
b
->
set_parent_idx
(
parent
.
ID
());
b
->
set_idx
(
desc_
.
blocks_size
()
-
1
);
b
->
set_idx
(
desc_
.
blocks_size
()
-
1
);
blocks_
.
emplace_back
(
new
BlockDesc
Bind
(
this
,
b
));
blocks_
.
emplace_back
(
new
BlockDesc
(
this
,
b
));
return
blocks_
.
back
().
get
();
return
blocks_
.
back
().
get
();
}
}
proto
::
ProgramDesc
*
ProgramDesc
Bind
::
Proto
()
{
proto
::
ProgramDesc
*
ProgramDesc
::
Proto
()
{
for
(
auto
&
block
:
blocks_
)
{
for
(
auto
&
block
:
blocks_
)
{
block
->
Flush
();
block
->
Flush
();
}
}
return
&
desc_
;
return
&
desc_
;
}
}
ProgramDesc
Bind
::
ProgramDescBind
()
{
ProgramDesc
::
ProgramDesc
()
{
auto
*
block
=
desc_
.
mutable_blocks
()
->
Add
();
auto
*
block
=
desc_
.
mutable_blocks
()
->
Add
();
block
->
set_idx
(
kRootBlockIndex
);
block
->
set_idx
(
kRootBlockIndex
);
block
->
set_parent_idx
(
kNoneBlockIndex
);
block
->
set_parent_idx
(
kNoneBlockIndex
);
blocks_
.
emplace_back
(
new
BlockDesc
Bind
(
this
,
block
));
blocks_
.
emplace_back
(
new
BlockDesc
(
this
,
block
));
}
}
ProgramDesc
Bind
::
ProgramDescBind
(
const
ProgramDescBind
&
o
)
{
ProgramDesc
::
ProgramDesc
(
const
ProgramDesc
&
o
)
{
desc_
=
o
.
desc_
;
desc_
=
o
.
desc_
;
for
(
int
i
=
0
;
i
<
desc_
.
blocks_size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
desc_
.
blocks_size
();
++
i
)
{
auto
*
block
=
desc_
.
mutable_blocks
(
i
);
auto
*
block
=
desc_
.
mutable_blocks
(
i
);
blocks_
.
emplace_back
(
new
BlockDesc
Bind
(
*
o
.
blocks_
[
i
],
block
,
this
));
blocks_
.
emplace_back
(
new
BlockDesc
(
*
o
.
blocks_
[
i
],
block
,
this
));
}
}
}
}
ProgramDesc
Bind
::
ProgramDescBind
(
const
proto
::
ProgramDesc
&
desc
)
{
ProgramDesc
::
ProgramDesc
(
const
proto
::
ProgramDesc
&
desc
)
{
desc_
=
desc
;
desc_
=
desc
;
for
(
auto
&
block_desc
:
*
desc_
.
mutable_blocks
())
{
for
(
auto
&
block_desc
:
*
desc_
.
mutable_blocks
())
{
blocks_
.
emplace_back
(
new
BlockDesc
Bind
(
this
,
&
block_desc
));
blocks_
.
emplace_back
(
new
BlockDesc
(
this
,
&
block_desc
));
}
}
}
}
ProgramDesc
Bind
::
ProgramDescBind
(
const
std
::
string
&
binary_str
)
{
ProgramDesc
::
ProgramDesc
(
const
std
::
string
&
binary_str
)
{
PADDLE_ENFORCE
(
desc_
.
ParseFromString
(
binary_str
),
PADDLE_ENFORCE
(
desc_
.
ParseFromString
(
binary_str
),
"Fail to parse program_desc from binary string."
);
"Fail to parse program_desc from binary string."
);
for
(
auto
&
block_desc
:
*
desc_
.
mutable_blocks
())
{
for
(
auto
&
block_desc
:
*
desc_
.
mutable_blocks
())
{
blocks_
.
emplace_back
(
new
BlockDesc
Bind
(
this
,
&
block_desc
));
blocks_
.
emplace_back
(
new
BlockDesc
(
this
,
&
block_desc
));
}
}
}
}
...
...
paddle/framework/program_desc.h
浏览文件 @
09189732
...
@@ -23,23 +23,23 @@ limitations under the License. */
...
@@ -23,23 +23,23 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
BlockDesc
Bind
;
class
BlockDesc
;
class
ProgramDesc
Bind
{
class
ProgramDesc
{
public:
public:
ProgramDesc
Bind
();
ProgramDesc
();
explicit
ProgramDesc
Bind
(
const
proto
::
ProgramDesc
&
desc
);
explicit
ProgramDesc
(
const
proto
::
ProgramDesc
&
desc
);
ProgramDesc
Bind
(
const
ProgramDescBind
&
o
);
ProgramDesc
(
const
ProgramDesc
&
o
);
explicit
ProgramDesc
Bind
(
const
std
::
string
&
binary_str
);
explicit
ProgramDesc
(
const
std
::
string
&
binary_str
);
BlockDesc
Bind
*
AppendBlock
(
const
BlockDescBind
&
parent
);
BlockDesc
*
AppendBlock
(
const
BlockDesc
&
parent
);
BlockDesc
Bind
*
MutableBlock
(
size_t
idx
)
{
return
blocks_
[
idx
].
get
();
}
BlockDesc
*
MutableBlock
(
size_t
idx
)
{
return
blocks_
[
idx
].
get
();
}
const
BlockDesc
Bind
&
Block
(
size_t
idx
)
const
{
return
*
blocks_
[
idx
];
}
const
BlockDesc
&
Block
(
size_t
idx
)
const
{
return
*
blocks_
[
idx
];
}
size_t
Size
()
const
{
return
blocks_
.
size
();
}
size_t
Size
()
const
{
return
blocks_
.
size
();
}
...
@@ -48,7 +48,7 @@ class ProgramDescBind {
...
@@ -48,7 +48,7 @@ class ProgramDescBind {
private:
private:
proto
::
ProgramDesc
desc_
;
proto
::
ProgramDesc
desc_
;
std
::
vector
<
std
::
unique_ptr
<
BlockDesc
Bind
>>
blocks_
;
std
::
vector
<
std
::
unique_ptr
<
BlockDesc
>>
blocks_
;
};
};
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/program_desc_test.cc
浏览文件 @
09189732
...
@@ -19,7 +19,7 @@
...
@@ -19,7 +19,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
TEST
(
ProgramDesc
,
copy_ctor
)
{
TEST
(
ProgramDesc
,
copy_ctor
)
{
ProgramDesc
Bind
program
;
ProgramDesc
program
;
auto
*
global_block
=
program
.
MutableBlock
(
0
);
auto
*
global_block
=
program
.
MutableBlock
(
0
);
auto
*
x
=
global_block
->
Var
(
"X"
);
auto
*
x
=
global_block
->
Var
(
"X"
);
x
->
SetType
(
proto
::
VarDesc_VarType_LOD_TENSOR
);
x
->
SetType
(
proto
::
VarDesc_VarType_LOD_TENSOR
);
...
@@ -42,12 +42,12 @@ TEST(ProgramDesc, copy_ctor) {
...
@@ -42,12 +42,12 @@ TEST(ProgramDesc, copy_ctor) {
out
->
SetType
(
proto
::
VarDesc_VarType_LOD_TENSOR
);
out
->
SetType
(
proto
::
VarDesc_VarType_LOD_TENSOR
);
op
->
SetOutput
(
"Y"
,
{
out
->
Name
()});
op
->
SetOutput
(
"Y"
,
{
out
->
Name
()});
ProgramDesc
Bind
program_copy
(
program
);
ProgramDesc
program_copy
(
program
);
auto
*
global_block_copy
=
program_copy
.
MutableBlock
(
0
);
auto
*
global_block_copy
=
program_copy
.
MutableBlock
(
0
);
ASSERT_NE
(
global_block
,
global_block_copy
);
ASSERT_NE
(
global_block
,
global_block_copy
);
auto
assert_same_var
=
[
&
](
const
std
::
string
&
name
,
VarDesc
Bind
*
var_before
)
{
auto
assert_same_var
=
[
&
](
const
std
::
string
&
name
,
VarDesc
*
var_before
)
{
ASSERT_TRUE
(
global_block_copy
->
HasVar
(
name
));
ASSERT_TRUE
(
global_block_copy
->
HasVar
(
name
));
auto
*
copy
=
global_block_copy
->
Var
(
name
);
auto
*
copy
=
global_block_copy
->
Var
(
name
);
ASSERT_NE
(
copy
,
var_before
);
ASSERT_NE
(
copy
,
var_before
);
...
@@ -81,7 +81,7 @@ TEST(ProgramDesc, copy_ctor) {
...
@@ -81,7 +81,7 @@ TEST(ProgramDesc, copy_ctor) {
}
}
TEST
(
ProgramDescBind
,
serialize_and_deserialize
)
{
TEST
(
ProgramDescBind
,
serialize_and_deserialize
)
{
ProgramDesc
Bind
program_origin
;
ProgramDesc
program_origin
;
auto
*
global_block
=
program_origin
.
MutableBlock
(
0
);
auto
*
global_block
=
program_origin
.
MutableBlock
(
0
);
auto
*
x
=
global_block
->
Var
(
"X"
);
auto
*
x
=
global_block
->
Var
(
"X"
);
x
->
SetType
(
proto
::
VarDesc_VarType_LOD_TENSOR
);
x
->
SetType
(
proto
::
VarDesc_VarType_LOD_TENSOR
);
...
@@ -107,11 +107,11 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
...
@@ -107,11 +107,11 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
std
::
string
binary_str
;
std
::
string
binary_str
;
program_origin
.
Proto
()
->
SerializeToString
(
&
binary_str
);
program_origin
.
Proto
()
->
SerializeToString
(
&
binary_str
);
ProgramDesc
Bind
program_restored
(
binary_str
);
ProgramDesc
program_restored
(
binary_str
);
auto
*
global_block_restored
=
program_restored
.
MutableBlock
(
0
);
auto
*
global_block_restored
=
program_restored
.
MutableBlock
(
0
);
ASSERT_NE
(
global_block
,
global_block_restored
);
ASSERT_NE
(
global_block
,
global_block_restored
);
auto
assert_same_var
=
[
&
](
const
std
::
string
&
name
,
VarDesc
Bind
*
var_before
)
{
auto
assert_same_var
=
[
&
](
const
std
::
string
&
name
,
VarDesc
*
var_before
)
{
ASSERT_TRUE
(
global_block_restored
->
HasVar
(
name
));
ASSERT_TRUE
(
global_block_restored
->
HasVar
(
name
));
auto
*
restored
=
global_block_restored
->
Var
(
name
);
auto
*
restored
=
global_block_restored
->
Var
(
name
);
ASSERT_NE
(
restored
,
var_before
);
ASSERT_NE
(
restored
,
var_before
);
...
...
paddle/framework/prune_test.cc
浏览文件 @
09189732
...
@@ -29,7 +29,7 @@ namespace ops = paddle::operators;
...
@@ -29,7 +29,7 @@ namespace ops = paddle::operators;
void
AddOp
(
const
std
::
string
&
type
,
const
f
::
VariableNameMap
&
inputs
,
void
AddOp
(
const
std
::
string
&
type
,
const
f
::
VariableNameMap
&
inputs
,
const
f
::
VariableNameMap
&
outputs
,
f
::
AttributeMap
attrs
,
const
f
::
VariableNameMap
&
outputs
,
f
::
AttributeMap
attrs
,
paddle
::
framework
::
BlockDesc
Bind
*
block
)
{
paddle
::
framework
::
BlockDesc
*
block
)
{
// insert output
// insert output
for
(
auto
kv
:
outputs
)
{
for
(
auto
kv
:
outputs
)
{
for
(
auto
v
:
kv
.
second
)
{
for
(
auto
v
:
kv
.
second
)
{
...
@@ -51,8 +51,8 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
...
@@ -51,8 +51,8 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
}
}
TEST
(
Prune
,
one_operator
)
{
TEST
(
Prune
,
one_operator
)
{
f
::
ProgramDesc
Bind
program
;
f
::
ProgramDesc
program
;
f
::
BlockDesc
Bind
*
block
=
program
.
MutableBlock
(
0
);
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
AddOp
(
"one_one"
,
{{
"input"
,
{
"a"
}}},
{{
"output"
,
{
"b"
}}},
f
::
AttributeMap
{},
AddOp
(
"one_one"
,
{{
"input"
,
{
"a"
}}},
{{
"output"
,
{
"b"
}}},
f
::
AttributeMap
{},
block
);
block
);
...
@@ -69,8 +69,8 @@ TEST(Prune, one_operator) {
...
@@ -69,8 +69,8 @@ TEST(Prune, one_operator) {
}
}
TEST
(
Prune
,
forward
)
{
TEST
(
Prune
,
forward
)
{
f
::
ProgramDesc
Bind
program
;
f
::
ProgramDesc
program
;
f
::
BlockDesc
Bind
*
block
=
program
.
MutableBlock
(
0
);
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
AddOp
(
"one_one"
,
{{
"input"
,
{
"a"
}}},
{{
"output"
,
{
"b"
}}},
f
::
AttributeMap
{},
AddOp
(
"one_one"
,
{{
"input"
,
{
"a"
}}},
{{
"output"
,
{
"b"
}}},
f
::
AttributeMap
{},
block
);
block
);
...
@@ -92,8 +92,8 @@ TEST(Prune, forward) {
...
@@ -92,8 +92,8 @@ TEST(Prune, forward) {
}
}
TEST
(
Prune
,
multi_input_op
)
{
TEST
(
Prune
,
multi_input_op
)
{
f
::
ProgramDesc
Bind
program
;
f
::
ProgramDesc
program
;
f
::
BlockDesc
Bind
*
block
=
program
.
MutableBlock
(
0
);
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
AddOp
(
"one_one"
,
{{
"input"
,
{
"a0"
}}},
{{
"output"
,
{
"b0"
}}},
f
::
AttributeMap
{},
AddOp
(
"one_one"
,
{{
"input"
,
{
"a0"
}}},
{{
"output"
,
{
"b0"
}}},
f
::
AttributeMap
{},
block
);
block
);
...
@@ -113,8 +113,8 @@ TEST(Prune, multi_input_op) {
...
@@ -113,8 +113,8 @@ TEST(Prune, multi_input_op) {
}
}
TEST
(
Prune
,
multi_output_op
)
{
TEST
(
Prune
,
multi_output_op
)
{
f
::
ProgramDesc
Bind
program
;
f
::
ProgramDesc
program
;
f
::
BlockDesc
Bind
*
block
=
program
.
MutableBlock
(
0
);
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
AddOp
(
"one_two"
,
{{
"input"
,
{
"a"
}}},
{{
"output"
,
{
"b"
,
"c"
}}},
AddOp
(
"one_two"
,
{{
"input"
,
{
"a"
}}},
{{
"output"
,
{
"b"
,
"c"
}}},
f
::
AttributeMap
{},
block
);
f
::
AttributeMap
{},
block
);
...
@@ -132,8 +132,8 @@ TEST(Prune, multi_output_op) {
...
@@ -132,8 +132,8 @@ TEST(Prune, multi_output_op) {
}
}
TEST
(
Prune
,
multi_target
)
{
TEST
(
Prune
,
multi_target
)
{
f
::
ProgramDesc
Bind
program
;
f
::
ProgramDesc
program
;
f
::
BlockDesc
Bind
*
block
=
program
.
MutableBlock
(
0
);
f
::
BlockDesc
*
block
=
program
.
MutableBlock
(
0
);
AddOp
(
"one_two"
,
{{
"input"
,
{
"a"
}}},
{{
"output"
,
{
"b"
,
"c"
}}},
AddOp
(
"one_two"
,
{{
"input"
,
{
"a"
}}},
{{
"output"
,
{
"b"
,
"c"
}}},
f
::
AttributeMap
{},
block
);
f
::
AttributeMap
{},
block
);
...
...
paddle/framework/type_defs.h
浏览文件 @
09189732
...
@@ -25,11 +25,9 @@
...
@@ -25,11 +25,9 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
OperatorBase
;
class
OperatorBase
;
class
OpDescBind
;
class
OpDesc
;
class
BlockDescBind
;
class
BlockDesc
;
class
InferShapeContext
;
class
InferShapeContext
;
class
BlockDesc
Bind
;
class
BlockDesc
;
using
VariableNameMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
using
VariableNameMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
...
@@ -37,7 +35,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
...
@@ -37,7 +35,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using
Attribute
=
using
Attribute
=
boost
::
variant
<
boost
::
blank
,
int
,
float
,
std
::
string
,
std
::
vector
<
int
>
,
boost
::
variant
<
boost
::
blank
,
int
,
float
,
std
::
string
,
std
::
vector
<
int
>
,
std
::
vector
<
float
>
,
std
::
vector
<
std
::
string
>
,
bool
,
std
::
vector
<
float
>
,
std
::
vector
<
std
::
string
>
,
bool
,
std
::
vector
<
bool
>
,
BlockDesc
Bind
*>
;
std
::
vector
<
bool
>
,
BlockDesc
*>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
...
@@ -45,13 +43,13 @@ using OpCreator = std::function<OperatorBase*(
...
@@ -45,13 +43,13 @@ using OpCreator = std::function<OperatorBase*(
const
std
::
string
&
/*type*/
,
const
VariableNameMap
&
/*inputs*/
,
const
std
::
string
&
/*type*/
,
const
VariableNameMap
&
/*inputs*/
,
const
VariableNameMap
&
/*outputs*/
,
const
AttributeMap
&
/*attrs*/
)
>
;
const
VariableNameMap
&
/*outputs*/
,
const
AttributeMap
&
/*attrs*/
)
>
;
using
GradOpMakerFN
=
std
::
function
<
std
::
vector
<
std
::
unique_ptr
<
OpDesc
Bind
>>
(
using
GradOpMakerFN
=
std
::
function
<
std
::
vector
<
std
::
unique_ptr
<
OpDesc
>>
(
const
OpDesc
Bind
&
,
const
std
::
unordered_set
<
std
::
string
>&
/*no_grad_set*/
,
const
OpDesc
&
,
const
std
::
unordered_set
<
std
::
string
>&
/*no_grad_set*/
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
/*grad_to_var*/
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
/*grad_to_var*/
,
const
std
::
vector
<
BlockDesc
Bind
*>&
grad_block
)
>
;
const
std
::
vector
<
BlockDesc
*>&
grad_block
)
>
;
using
InferVarTypeFN
=
std
::
function
<
void
(
const
OpDescBind
&
/*op_desc*/
,
using
InferVarTypeFN
=
BlockDescBind
*
/*block*/
)
>
;
std
::
function
<
void
(
const
OpDesc
&
/*op_desc*/
,
BlockDesc
*
/*block*/
)
>
;
using
InferShapeFN
=
std
::
function
<
void
(
InferShapeContext
*
)
>
;
using
InferShapeFN
=
std
::
function
<
void
(
InferShapeContext
*
)
>
;
...
...
paddle/framework/var_desc.cc
浏览文件 @
09189732
...
@@ -18,29 +18,27 @@ limitations under the License. */
...
@@ -18,29 +18,27 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
proto
::
VarDesc
::
VarType
VarDesc
Bind
::
GetType
()
const
{
return
desc_
.
type
();
}
proto
::
VarDesc
::
VarType
VarDesc
::
GetType
()
const
{
return
desc_
.
type
();
}
void
VarDescBind
::
SetType
(
proto
::
VarDesc
::
VarType
type
)
{
void
VarDesc
::
SetType
(
proto
::
VarDesc
::
VarType
type
)
{
desc_
.
set_type
(
type
);
}
desc_
.
set_type
(
type
);
}
void
VarDesc
Bind
::
SetShape
(
const
std
::
vector
<
int64_t
>
&
dims
)
{
void
VarDesc
::
SetShape
(
const
std
::
vector
<
int64_t
>
&
dims
)
{
VectorToRepeated
(
dims
,
mutable_tensor_desc
()
->
mutable_dims
());
VectorToRepeated
(
dims
,
mutable_tensor_desc
()
->
mutable_dims
());
}
}
void
VarDesc
Bind
::
SetDataType
(
proto
::
DataType
data_type
)
{
void
VarDesc
::
SetDataType
(
proto
::
DataType
data_type
)
{
mutable_tensor_desc
()
->
set_data_type
(
data_type
);
mutable_tensor_desc
()
->
set_data_type
(
data_type
);
}
}
std
::
vector
<
int64_t
>
VarDesc
Bind
::
Shape
()
const
{
std
::
vector
<
int64_t
>
VarDesc
::
Shape
()
const
{
return
RepeatedToVector
(
tensor_desc
().
dims
());
return
RepeatedToVector
(
tensor_desc
().
dims
());
}
}
proto
::
DataType
VarDesc
Bind
::
GetDataType
()
const
{
proto
::
DataType
VarDesc
::
GetDataType
()
const
{
return
tensor_desc
().
data_type
();
return
tensor_desc
().
data_type
();
}
}
void
VarDesc
Bind
::
SetLoDLevel
(
int32_t
lod_level
)
{
void
VarDesc
::
SetLoDLevel
(
int32_t
lod_level
)
{
switch
(
desc_
.
type
())
{
switch
(
desc_
.
type
())
{
case
proto
::
VarDesc
::
LOD_TENSOR
:
case
proto
::
VarDesc
::
LOD_TENSOR
:
desc_
.
mutable_lod_tensor
()
->
set_lod_level
(
lod_level
);
desc_
.
mutable_lod_tensor
()
->
set_lod_level
(
lod_level
);
...
@@ -54,7 +52,7 @@ void VarDescBind::SetLoDLevel(int32_t lod_level) {
...
@@ -54,7 +52,7 @@ void VarDescBind::SetLoDLevel(int32_t lod_level) {
}
}
}
}
int32_t
VarDesc
Bind
::
GetLodLevel
()
const
{
int32_t
VarDesc
::
GetLodLevel
()
const
{
switch
(
desc_
.
type
())
{
switch
(
desc_
.
type
())
{
case
proto
::
VarDesc
::
LOD_TENSOR
:
case
proto
::
VarDesc
::
LOD_TENSOR
:
return
desc_
.
lod_tensor
().
lod_level
();
return
desc_
.
lod_tensor
().
lod_level
();
...
@@ -66,7 +64,7 @@ int32_t VarDescBind::GetLodLevel() const {
...
@@ -66,7 +64,7 @@ int32_t VarDescBind::GetLodLevel() const {
}
}
}
}
const
proto
::
TensorDesc
&
VarDesc
Bind
::
tensor_desc
()
const
{
const
proto
::
TensorDesc
&
VarDesc
::
tensor_desc
()
const
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"invoke TensorDesc must after set type"
);
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"invoke TensorDesc must after set type"
);
switch
(
desc_
.
type
())
{
switch
(
desc_
.
type
())
{
case
proto
::
VarDesc
::
SELECTED_ROWS
:
case
proto
::
VarDesc
::
SELECTED_ROWS
:
...
@@ -80,7 +78,7 @@ const proto::TensorDesc &VarDescBind::tensor_desc() const {
...
@@ -80,7 +78,7 @@ const proto::TensorDesc &VarDescBind::tensor_desc() const {
}
}
}
}
proto
::
TensorDesc
*
VarDesc
Bind
::
mutable_tensor_desc
()
{
proto
::
TensorDesc
*
VarDesc
::
mutable_tensor_desc
()
{
PADDLE_ENFORCE
(
desc_
.
has_type
(),
PADDLE_ENFORCE
(
desc_
.
has_type
(),
"invoke MutableTensorDesc must after set type"
);
"invoke MutableTensorDesc must after set type"
);
switch
(
desc_
.
type
())
{
switch
(
desc_
.
type
())
{
...
...
paddle/framework/var_desc.h
浏览文件 @
09189732
...
@@ -53,14 +53,14 @@ inline void VectorToRepeated(const std::vector<bool> &vec,
...
@@ -53,14 +53,14 @@ inline void VectorToRepeated(const std::vector<bool> &vec,
}
}
}
}
class
VarDesc
Bind
{
class
VarDesc
{
public:
public:
explicit
VarDesc
Bind
(
const
std
::
string
&
name
)
{
explicit
VarDesc
(
const
std
::
string
&
name
)
{
desc_
.
set_name
(
name
);
desc_
.
set_name
(
name
);
desc_
.
set_type
(
proto
::
VarDesc
::
LOD_TENSOR
);
desc_
.
set_type
(
proto
::
VarDesc
::
LOD_TENSOR
);
}
}
explicit
VarDesc
Bind
(
const
proto
::
VarDesc
&
desc
)
:
desc_
(
desc
)
{}
explicit
VarDesc
(
const
proto
::
VarDesc
&
desc
)
:
desc_
(
desc
)
{}
proto
::
VarDesc
*
Proto
()
{
return
&
desc_
;
}
proto
::
VarDesc
*
Proto
()
{
return
&
desc_
;
}
...
...
paddle/framework/var_type_inference.h
浏览文件 @
09189732
...
@@ -21,8 +21,7 @@ namespace framework {
...
@@ -21,8 +21,7 @@ namespace framework {
class
VarTypeInference
{
class
VarTypeInference
{
public:
public:
virtual
~
VarTypeInference
()
{}
virtual
~
VarTypeInference
()
{}
virtual
void
operator
()(
const
OpDescBind
&
op_desc
,
virtual
void
operator
()(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
=
0
;
BlockDescBind
*
block
)
const
=
0
;
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/framework/var_type_inference_test.cc
浏览文件 @
09189732
...
@@ -33,8 +33,7 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
...
@@ -33,8 +33,7 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class
SumOpVarTypeInference
:
public
VarTypeInference
{
class
SumOpVarTypeInference
:
public
VarTypeInference
{
public:
public:
void
operator
()(
const
OpDescBind
&
op_desc
,
void
operator
()(
const
OpDesc
&
op_desc
,
BlockDesc
*
block
)
const
override
{
BlockDescBind
*
block
)
const
override
{
auto
&
inputs
=
op_desc
.
Input
(
"X"
);
auto
&
inputs
=
op_desc
.
Input
(
"X"
);
auto
default_var_type
=
proto
::
VarDesc
::
SELECTED_ROWS
;
auto
default_var_type
=
proto
::
VarDesc
::
SELECTED_ROWS
;
...
@@ -62,7 +61,7 @@ namespace paddle {
...
@@ -62,7 +61,7 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
TEST
(
InferVarType
,
sum_op
)
{
TEST
(
InferVarType
,
sum_op
)
{
ProgramDesc
Bind
prog
;
ProgramDesc
prog
;
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"sum"
);
op
->
SetType
(
"sum"
);
op
->
SetInput
(
"X"
,
{
"test_a"
,
"test_b"
,
"test_c"
});
op
->
SetInput
(
"X"
,
{
"test_a"
,
"test_b"
,
"test_c"
});
...
@@ -85,7 +84,7 @@ TEST(InferVarType, sum_op) {
...
@@ -85,7 +84,7 @@ TEST(InferVarType, sum_op) {
}
}
TEST
(
InferVarType
,
sum_op_without_infer_var_type
)
{
TEST
(
InferVarType
,
sum_op_without_infer_var_type
)
{
ProgramDesc
Bind
prog
;
ProgramDesc
prog
;
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
auto
*
op
=
prog
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"sum_without_infer_var_type"
);
op
->
SetType
(
"sum_without_infer_var_type"
);
op
->
SetInput
(
"X"
,
{
"test2_a"
,
"test2_b"
,
"test2_c"
});
op
->
SetInput
(
"X"
,
{
"test2_a"
,
"test2_b"
,
"test2_c"
});
...
...
paddle/operators/array_to_lod_tensor_op.cc
浏览文件 @
09189732
...
@@ -149,14 +149,14 @@ class ArrayToLoDTensorGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -149,14 +149,14 @@ class ArrayToLoDTensorGradMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
Bind
();
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"lod_tensor_to_array"
);
grad_op
->
SetType
(
"lod_tensor_to_array"
);
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetInput
(
"RankTable"
,
Input
(
"RankTable"
));
grad_op
->
SetInput
(
"RankTable"
,
Input
(
"RankTable"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetAttrMap
(
Attrs
());
grad_op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad_op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
}
};
};
...
...
paddle/operators/assign_op.cc
浏览文件 @
09189732
...
@@ -121,12 +121,12 @@ class AssignGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -121,12 +121,12 @@ class AssignGradMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
op
=
new
framework
::
OpDesc
Bind
();
auto
*
op
=
new
framework
::
OpDesc
();
op
->
SetType
(
"assign"
);
op
->
SetType
(
"assign"
);
op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
op
);
}
}
};
};
...
...
paddle/operators/beam_search_decode_op.cc
浏览文件 @
09189732
...
@@ -119,8 +119,8 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
...
@@ -119,8 +119,8 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
class
BeamSearchDecodeInferVarType
:
public
framework
::
VarTypeInference
{
class
BeamSearchDecodeInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
const
framework
::
OpDesc
Bind
&
op_desc
,
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
Bind
*
block
)
const
override
{
framework
::
BlockDesc
*
block
)
const
override
{
for
(
auto
&
o
:
op_desc
.
Output
(
"SentenceIds"
))
{
for
(
auto
&
o
:
op_desc
.
Output
(
"SentenceIds"
))
{
block
->
Var
(
o
)
->
SetType
(
framework
::
proto
::
VarDesc
::
LOD_TENSOR
);
block
->
Var
(
o
)
->
SetType
(
framework
::
proto
::
VarDesc
::
LOD_TENSOR
);
}
}
...
...
paddle/operators/cast_op.cc
浏览文件 @
09189732
...
@@ -52,14 +52,14 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -52,14 +52,14 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
grad
=
new
framework
::
OpDesc
Bind
();
auto
grad
=
new
framework
::
OpDesc
();
grad
->
SetType
(
"cast"
);
grad
->
SetType
(
"cast"
);
grad
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad
->
SetAttr
(
"out_dtype"
,
GetAttr
(
"in_dtype"
));
grad
->
SetAttr
(
"out_dtype"
,
GetAttr
(
"in_dtype"
));
grad
->
SetAttr
(
"in_dtype"
,
GetAttr
(
"out_dtype"
));
grad
->
SetAttr
(
"in_dtype"
,
GetAttr
(
"out_dtype"
));
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad
);
}
}
};
};
...
...
paddle/operators/conditional_block_op.cc
浏览文件 @
09189732
...
@@ -65,7 +65,7 @@ class ConditionalBlockOp : public ConditionalOp {
...
@@ -65,7 +65,7 @@ class ConditionalBlockOp : public ConditionalOp {
scopes
->
front
()
=
&
scope
.
NewScope
();
scopes
->
front
()
=
&
scope
.
NewScope
();
auto
&
cur_scope
=
*
scopes
->
front
();
auto
&
cur_scope
=
*
scopes
->
front
();
auto
*
block
=
Attr
<
framework
::
BlockDesc
Bind
*>
(
"sub_block"
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
"sub_block"
);
framework
::
Executor
exec
(
dev_ctx
);
framework
::
Executor
exec
(
dev_ctx
);
exec
.
Run
(
*
block
->
Program
(),
&
cur_scope
,
block
->
ID
(),
false
);
exec
.
Run
(
*
block
->
Program
(),
&
cur_scope
,
block
->
ID
(),
false
);
}
}
...
@@ -86,7 +86,7 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -86,7 +86,7 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"(std::vector<Scope*>) The step scope of conditional block. To "
"(std::vector<Scope*>) The step scope of conditional block. To "
"unify the conditional block, rnn and while op, the type of "
"unify the conditional block, rnn and while op, the type of "
"scope is std::vector<Scope*>"
);
"scope is std::vector<Scope*>"
);
AddAttr
<
framework
::
BlockDesc
Bind
*>
(
AddAttr
<
framework
::
BlockDesc
*>
(
"sub_block"
,
"The step block of conditional block operator"
);
"sub_block"
,
"The step block of conditional block operator"
);
AddComment
(
R"DOC(Conditional block operator
AddComment
(
R"DOC(Conditional block operator
...
@@ -116,7 +116,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
...
@@ -116,7 +116,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
auto
&
scopes
=
scope_var
->
Get
<
std
::
vector
<
framework
::
Scope
*>>
();
auto
&
scopes
=
scope_var
->
Get
<
std
::
vector
<
framework
::
Scope
*>>
();
framework
::
Scope
&
cur_scope
=
*
scopes
[
0
];
framework
::
Scope
&
cur_scope
=
*
scopes
[
0
];
auto
*
block
=
Attr
<
framework
::
BlockDesc
Bind
*>
(
"sub_block"
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
"sub_block"
);
framework
::
Executor
exec
(
dev_ctx
);
framework
::
Executor
exec
(
dev_ctx
);
exec
.
Run
(
*
block
->
Program
(),
&
cur_scope
,
block
->
ID
(),
false
);
exec
.
Run
(
*
block
->
Program
(),
&
cur_scope
,
block
->
ID
(),
false
);
...
@@ -170,8 +170,8 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -170,8 +170,8 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
grad_op
=
new
framework
::
OpDesc
Bind
();
auto
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"conditional_block_grad"
);
grad_op
->
SetType
(
"conditional_block_grad"
);
grad_op
->
SetInput
(
"X"
,
Input
(
"X"
));
grad_op
->
SetInput
(
"X"
,
Input
(
"X"
));
grad_op
->
SetInput
(
"Params"
,
Input
(
"Params"
));
grad_op
->
SetInput
(
"Params"
,
Input
(
"Params"
));
...
@@ -181,7 +181,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -181,7 +181,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"Params"
),
InputGrad
(
"Params"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"Params"
),
InputGrad
(
"Params"
));
grad_op
->
SetBlockAttr
(
"sub_block"
,
*
this
->
grad_block_
[
0
]);
grad_op
->
SetBlockAttr
(
"sub_block"
,
*
this
->
grad_block_
[
0
]);
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad_op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
}
};
};
...
...
paddle/operators/increment_op.cc
浏览文件 @
09189732
...
@@ -93,13 +93,13 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
...
@@ -93,13 +93,13 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
public:
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
Bind
();
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"increment"
);
grad_op
->
SetType
(
"increment"
);
grad_op
->
SetInput
(
"X"
,
Output
(
"Out"
));
grad_op
->
SetInput
(
"X"
,
Output
(
"Out"
));
grad_op
->
SetOutput
(
"Out"
,
Input
(
"X"
));
grad_op
->
SetOutput
(
"Out"
,
Input
(
"X"
));
grad_op
->
SetAttr
(
"step"
,
-
boost
::
get
<
float
>
(
GetAttr
(
"step"
)));
grad_op
->
SetAttr
(
"step"
,
-
boost
::
get
<
float
>
(
GetAttr
(
"step"
)));
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad_op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
}
};
};
...
...
paddle/operators/lod_rank_table_op.cc
浏览文件 @
09189732
...
@@ -63,8 +63,8 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
...
@@ -63,8 +63,8 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
class
LoDRankTableInferVarType
:
public
framework
::
VarTypeInference
{
class
LoDRankTableInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
const
framework
::
OpDesc
Bind
&
op_desc
,
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
Bind
*
block
)
const
override
{
framework
::
BlockDesc
*
block
)
const
override
{
for
(
auto
&
o
:
op_desc
.
Output
(
"Out"
))
{
for
(
auto
&
o
:
op_desc
.
Output
(
"Out"
))
{
block
->
FindRecursiveOrCreateVar
(
o
)
->
SetType
(
block
->
FindRecursiveOrCreateVar
(
o
)
->
SetType
(
framework
::
proto
::
VarDesc
::
LOD_RANK_TABLE
);
framework
::
proto
::
VarDesc
::
LOD_RANK_TABLE
);
...
...
paddle/operators/lod_tensor_to_array_op.cc
浏览文件 @
09189732
...
@@ -127,8 +127,8 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
...
@@ -127,8 +127,8 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
class
LoDTensorToArrayInferVarType
:
public
framework
::
VarTypeInference
{
class
LoDTensorToArrayInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
const
framework
::
OpDesc
Bind
&
op_desc
,
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
Bind
*
block
)
const
override
{
framework
::
BlockDesc
*
block
)
const
override
{
for
(
auto
&
out_var
:
op_desc
.
Output
(
"Out"
))
{
for
(
auto
&
out_var
:
op_desc
.
Output
(
"Out"
))
{
block
->
Var
(
out_var
)
->
SetType
(
framework
::
proto
::
VarDesc
::
LOD_TENSOR_ARRAY
);
block
->
Var
(
out_var
)
->
SetType
(
framework
::
proto
::
VarDesc
::
LOD_TENSOR_ARRAY
);
}
}
...
@@ -140,14 +140,14 @@ class LoDTensorToArrayGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -140,14 +140,14 @@ class LoDTensorToArrayGradMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
Bind
();
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"array_to_lod_tensor"
);
grad_op
->
SetType
(
"array_to_lod_tensor"
);
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetInput
(
"RankTable"
,
Input
(
"RankTable"
));
grad_op
->
SetInput
(
"RankTable"
,
Input
(
"RankTable"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetAttrMap
(
Attrs
());
grad_op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad_op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
}
};
};
...
...
paddle/operators/lookup_table_op.cc
浏览文件 @
09189732
...
@@ -108,8 +108,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
...
@@ -108,8 +108,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
class
LookupTableOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
class
LookupTableOpGradVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
const
framework
::
OpDesc
Bind
&
op_desc
,
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
Bind
*
block
)
const
override
{
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
framework
::
GradVarName
(
"W"
)).
front
();
auto
out_var_name
=
op_desc
.
Output
(
framework
::
GradVarName
(
"W"
)).
front
();
auto
attr
=
op_desc
.
GetAttr
(
"is_sparse"
);
auto
attr
=
op_desc
.
GetAttr
(
"is_sparse"
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
bool
is_sparse
=
boost
::
get
<
bool
>
(
attr
);
...
...
paddle/operators/mean_op.cc
浏览文件 @
09189732
...
@@ -60,13 +60,13 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -60,13 +60,13 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
Bind
();
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"mean_grad"
);
grad_op
->
SetType
(
"mean_grad"
);
grad_op
->
SetInput
(
"X"
,
Input
(
"X"
));
grad_op
->
SetInput
(
"X"
,
Input
(
"X"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad_op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
}
};
};
...
...
paddle/operators/merge_lod_tensor_op.cc
浏览文件 @
09189732
...
@@ -161,15 +161,15 @@ class MergeLoDTensorGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -161,15 +161,15 @@ class MergeLoDTensorGradMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
Bind
();
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"split_lod_tensor"
);
grad_op
->
SetType
(
"split_lod_tensor"
);
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetInput
(
"Mask"
,
Input
(
"Mask"
));
grad_op
->
SetInput
(
"Mask"
,
Input
(
"Mask"
));
grad_op
->
SetOutput
(
"OutTrue"
,
InputGrad
(
"InTrue"
));
grad_op
->
SetOutput
(
"OutTrue"
,
InputGrad
(
"InTrue"
));
grad_op
->
SetOutput
(
"OutFalse"
,
InputGrad
(
"InFalse"
));
grad_op
->
SetOutput
(
"OutFalse"
,
InputGrad
(
"InFalse"
));
grad_op
->
SetAttrMap
(
Attrs
());
grad_op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad_op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
}
};
};
...
...
paddle/operators/minus_op.cc
浏览文件 @
09189732
...
@@ -70,12 +70,11 @@ class MinusGradMaker : public framework::GradOpDescMakerBase {
...
@@ -70,12 +70,11 @@ class MinusGradMaker : public framework::GradOpDescMakerBase {
public:
public:
using
framework
::
GradOpDescMakerBase
::
GradOpDescMakerBase
;
using
framework
::
GradOpDescMakerBase
::
GradOpDescMakerBase
;
std
::
vector
<
std
::
unique_ptr
<
framework
::
OpDescBind
>>
operator
()()
std
::
vector
<
std
::
unique_ptr
<
framework
::
OpDesc
>>
operator
()()
const
override
{
const
override
{
std
::
vector
<
std
::
unique_ptr
<
framework
::
OpDesc
>>
ops
;
std
::
vector
<
std
::
unique_ptr
<
framework
::
OpDescBind
>>
ops
;
auto
x_g
=
InputGrad
(
"X"
);
auto
x_g
=
InputGrad
(
"X"
);
if
(
!
x_g
.
empty
())
{
if
(
!
x_g
.
empty
())
{
auto
*
x_g_op
=
new
framework
::
OpDesc
Bind
();
auto
*
x_g_op
=
new
framework
::
OpDesc
();
x_g_op
->
SetType
(
"scale"
);
x_g_op
->
SetType
(
"scale"
);
x_g_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
x_g_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
x_g_op
->
SetOutput
(
"Out"
,
x_g
);
x_g_op
->
SetOutput
(
"Out"
,
x_g
);
...
@@ -85,7 +84,7 @@ class MinusGradMaker : public framework::GradOpDescMakerBase {
...
@@ -85,7 +84,7 @@ class MinusGradMaker : public framework::GradOpDescMakerBase {
auto
y_g
=
InputGrad
(
"Y"
);
auto
y_g
=
InputGrad
(
"Y"
);
if
(
!
y_g
.
empty
())
{
if
(
!
y_g
.
empty
())
{
auto
*
y_g_op
=
new
framework
::
OpDesc
Bind
();
auto
*
y_g_op
=
new
framework
::
OpDesc
();
y_g_op
->
SetType
(
"scale"
);
y_g_op
->
SetType
(
"scale"
);
y_g_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
y_g_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
y_g_op
->
SetOutput
(
"Out"
,
y_g
);
y_g_op
->
SetOutput
(
"Out"
,
y_g
);
...
...
paddle/operators/nccl_op_test.cu.cc
浏览文件 @
09189732
...
@@ -65,7 +65,7 @@ class NCCLTester : public ::testing::Test {
...
@@ -65,7 +65,7 @@ class NCCLTester : public ::testing::Test {
}
}
void
NCCLInitOp
()
{
void
NCCLInitOp
()
{
std
::
unique_ptr
<
f
::
OpDesc
Bind
>
op1
(
new
f
::
OpDescBind
);
std
::
unique_ptr
<
f
::
OpDesc
>
op1
(
new
f
::
OpDesc
);
op1
->
SetType
(
"ncclInit"
);
op1
->
SetType
(
"ncclInit"
);
op1
->
SetOutput
(
"Communicator"
,
{
"comm"
});
op1
->
SetOutput
(
"Communicator"
,
{
"comm"
});
...
@@ -81,10 +81,9 @@ class NCCLTester : public ::testing::Test {
...
@@ -81,10 +81,9 @@ class NCCLTester : public ::testing::Test {
}
}
template
<
class
T
>
template
<
class
T
>
void
PerThreadProgram
(
int
gpu_id
,
const
f
::
OpDescBind
&
op_desc
,
void
PerThreadProgram
(
int
gpu_id
,
const
f
::
OpDesc
&
op_desc
,
f
::
Scope
*
scope
)
{
f
::
Scope
*
scope
)
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
mu
);
std
::
unique_lock
<
std
::
mutex
>
lk
(
mu
);
const
f
::
OpDesc
Bind
*
op1
=
&
op_desc
;
const
f
::
OpDesc
*
op1
=
&
op_desc
;
p
::
GPUPlace
place
(
gpu_id
);
p
::
GPUPlace
place
(
gpu_id
);
auto
&
ctx
=
dev_ctxs
.
at
(
gpu_id
);
auto
&
ctx
=
dev_ctxs
.
at
(
gpu_id
);
...
@@ -125,7 +124,7 @@ class NCCLTester : public ::testing::Test {
...
@@ -125,7 +124,7 @@ class NCCLTester : public ::testing::Test {
// ncclInitOp with desc
// ncclInitOp with desc
TEST
(
NCCL
,
ncclInitOp
)
{
TEST
(
NCCL
,
ncclInitOp
)
{
std
::
unique_ptr
<
f
::
OpDesc
Bind
>
op_desc
(
new
f
::
OpDescBind
);
std
::
unique_ptr
<
f
::
OpDesc
>
op_desc
(
new
f
::
OpDesc
);
op_desc
->
SetType
(
"ncclInit"
);
op_desc
->
SetType
(
"ncclInit"
);
op_desc
->
SetOutput
(
"Communicator"
,
{
"x1"
});
op_desc
->
SetOutput
(
"Communicator"
,
{
"x1"
});
...
@@ -145,7 +144,7 @@ TEST(NCCL, ncclInitOp) {
...
@@ -145,7 +144,7 @@ TEST(NCCL, ncclInitOp) {
// ncclAllReduceOp with desc
// ncclAllReduceOp with desc
TEST_F
(
NCCLTester
,
ncclAllReduceOp
)
{
TEST_F
(
NCCLTester
,
ncclAllReduceOp
)
{
std
::
unique_ptr
<
f
::
OpDesc
Bind
>
op2
(
new
f
::
OpDescBind
);
std
::
unique_ptr
<
f
::
OpDesc
>
op2
(
new
f
::
OpDesc
);
op2
->
SetType
(
"ncclAllReduce"
);
op2
->
SetType
(
"ncclAllReduce"
);
op2
->
SetInput
(
"X"
,
{
"st"
});
op2
->
SetInput
(
"X"
,
{
"st"
});
op2
->
SetInput
(
"Communicator"
,
{
"comm"
});
op2
->
SetInput
(
"Communicator"
,
{
"comm"
});
...
@@ -192,7 +191,7 @@ TEST_F(NCCLTester, ncclAllReduceOp) {
...
@@ -192,7 +191,7 @@ TEST_F(NCCLTester, ncclAllReduceOp) {
// ncclReduceOp with desc
// ncclReduceOp with desc
TEST_F
(
NCCLTester
,
ncclReduceOp
)
{
TEST_F
(
NCCLTester
,
ncclReduceOp
)
{
std
::
unique_ptr
<
f
::
OpDesc
Bind
>
op2
(
new
f
::
OpDescBind
);
std
::
unique_ptr
<
f
::
OpDesc
>
op2
(
new
f
::
OpDesc
);
const
int
kRoot
=
0
;
const
int
kRoot
=
0
;
op2
->
SetType
(
"ncclReduce"
);
op2
->
SetType
(
"ncclReduce"
);
op2
->
SetInput
(
"X"
,
{
"st"
});
op2
->
SetInput
(
"X"
,
{
"st"
});
...
@@ -240,7 +239,7 @@ TEST_F(NCCLTester, ncclReduceOp) {
...
@@ -240,7 +239,7 @@ TEST_F(NCCLTester, ncclReduceOp) {
// ncclBcastOp with desc
// ncclBcastOp with desc
TEST_F
(
NCCLTester
,
ncclBcastOp
)
{
TEST_F
(
NCCLTester
,
ncclBcastOp
)
{
std
::
unique_ptr
<
f
::
OpDesc
Bind
>
op2
(
new
f
::
OpDescBind
);
std
::
unique_ptr
<
f
::
OpDesc
>
op2
(
new
f
::
OpDesc
);
const
int
kRoot
=
5
;
const
int
kRoot
=
5
;
op2
->
SetType
(
"ncclBcast"
);
op2
->
SetType
(
"ncclBcast"
);
op2
->
SetInput
(
"X"
,
{
"st"
});
op2
->
SetInput
(
"X"
,
{
"st"
});
...
...
paddle/operators/pad_op.cc
浏览文件 @
09189732
...
@@ -116,14 +116,14 @@ class PadOpGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -116,14 +116,14 @@ class PadOpGradMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
bind
=
new
framework
::
OpDesc
Bind
();
auto
*
bind
=
new
framework
::
OpDesc
();
bind
->
SetInput
(
"X"
,
Input
(
"X"
));
bind
->
SetInput
(
"X"
,
Input
(
"X"
));
bind
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
bind
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
bind
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
bind
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
bind
->
SetAttrMap
(
Attrs
());
bind
->
SetAttrMap
(
Attrs
());
bind
->
SetType
(
"pad_grad"
);
bind
->
SetType
(
"pad_grad"
);
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
bind
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
bind
);
}
}
};
};
...
...
paddle/operators/recurrent_op.cc
浏览文件 @
09189732
...
@@ -234,7 +234,7 @@ class RecurrentOp : public RecurrentBase {
...
@@ -234,7 +234,7 @@ class RecurrentOp : public RecurrentBase {
auto
reverse
=
Attr
<
bool
>
(
kReverse
);
auto
reverse
=
Attr
<
bool
>
(
kReverse
);
framework
::
Executor
executor
(
dev_ctx
);
framework
::
Executor
executor
(
dev_ctx
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
Bind
*>
(
kStepBlock
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kStepBlock
);
auto
*
program
=
block
->
Program
();
auto
*
program
=
block
->
Program
();
for
(
size_t
i
=
0
;
i
<
seq_len
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
seq_len
;
++
i
)
{
...
@@ -317,7 +317,7 @@ class RecurrentGradOp : public RecurrentBase {
...
@@ -317,7 +317,7 @@ class RecurrentGradOp : public RecurrentBase {
auto
reverse
=
Attr
<
bool
>
(
kReverse
);
auto
reverse
=
Attr
<
bool
>
(
kReverse
);
framework
::
Executor
executor
(
dev_ctx
);
framework
::
Executor
executor
(
dev_ctx
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
Bind
*>
(
kStepBlock
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kStepBlock
);
auto
*
program
=
block
->
Program
();
auto
*
program
=
block
->
Program
();
for
(
size_t
step_id
=
0
;
step_id
<
seq_len
;
++
step_id
)
{
for
(
size_t
step_id
=
0
;
step_id
<
seq_len
;
++
step_id
)
{
...
@@ -522,8 +522,7 @@ The ex-state means the state value in the ex-timestep or the previous time step
...
@@ -522,8 +522,7 @@ The ex-state means the state value in the ex-timestep or the previous time step
string
::
Sprintf
(
string
::
Sprintf
(
"The state variable names. [%s, %s, %s] must be the same order"
,
"The state variable names. [%s, %s, %s] must be the same order"
,
kExStates
,
kStates
,
kInitStateGrads
));
kExStates
,
kStates
,
kInitStateGrads
));
AddAttr
<
framework
::
BlockDescBind
*>
(
kStepBlock
,
AddAttr
<
framework
::
BlockDesc
*>
(
kStepBlock
,
"The step block inside RNN"
);
"The step block inside RNN"
);
AddAttr
<
bool
>
(
kReverse
,
R"DOC(Calculate RNN reversely or not.
AddAttr
<
bool
>
(
kReverse
,
R"DOC(Calculate RNN reversely or not.
By default reverse=False
By default reverse=False
...
@@ -565,8 +564,8 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
...
@@ -565,8 +564,8 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
virtual
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
{
virtual
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
{
auto
*
grad
=
new
framework
::
OpDesc
Bind
();
auto
*
grad
=
new
framework
::
OpDesc
();
grad
->
SetType
(
"recurrent_grad"
);
grad
->
SetType
(
"recurrent_grad"
);
for
(
auto
&
input_param
:
this
->
InputNames
())
{
for
(
auto
&
input_param
:
this
->
InputNames
())
{
grad
->
SetInput
(
input_param
,
this
->
Input
(
input_param
));
grad
->
SetInput
(
input_param
,
this
->
Input
(
input_param
));
...
@@ -588,7 +587,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
...
@@ -588,7 +587,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
grad
->
SetAttrMap
(
this
->
Attrs
());
grad
->
SetAttrMap
(
this
->
Attrs
());
grad
->
SetBlockAttr
(
kStepBlock
,
*
grad_block_
[
0
]);
grad
->
SetBlockAttr
(
kStepBlock
,
*
grad_block_
[
0
]);
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad
);
}
}
};
};
...
...
paddle/operators/scale_op.cc
浏览文件 @
09189732
...
@@ -58,13 +58,13 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -58,13 +58,13 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
public:
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
Bind
();
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"scale"
);
grad_op
->
SetType
(
"scale"
);
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetAttr
(
"scale"
,
GetAttr
(
"scale"
));
grad_op
->
SetAttr
(
"scale"
,
GetAttr
(
"scale"
));
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad_op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
}
};
};
...
...
paddle/operators/shrink_rnn_memory_op.cc
浏览文件 @
09189732
...
@@ -136,14 +136,14 @@ class ShrinkRNNGradOpMaker : public framework::SingleGradOpDescMaker {
...
@@ -136,14 +136,14 @@ class ShrinkRNNGradOpMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
op
=
new
framework
::
OpDesc
Bind
();
auto
*
op
=
new
framework
::
OpDesc
();
op
->
SetType
(
"shrink_rnn_memory_grad"
);
op
->
SetType
(
"shrink_rnn_memory_grad"
);
op
->
SetInput
(
"X"
,
Input
(
"X"
));
op
->
SetInput
(
"X"
,
Input
(
"X"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
OutputGrad
(
"Out"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
InputGrad
(
"X"
));
op
->
SetAttrMap
(
Attrs
());
op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
op
);
}
}
};
};
...
...
paddle/operators/sign_op.cc
浏览文件 @
09189732
...
@@ -50,13 +50,13 @@ class SignGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -50,13 +50,13 @@ class SignGradMaker : public framework::SingleGradOpDescMaker {
public:
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
Bind
();
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"scale"
);
grad_op
->
SetType
(
"scale"
);
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetAttr
(
"scale"
,
0.0
f
);
grad_op
->
SetAttr
(
"scale"
,
0.0
f
);
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad_op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
}
};
};
...
...
paddle/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
09189732
...
@@ -173,8 +173,8 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -173,8 +173,8 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
Bind
();
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"softmax_with_cross_entropy_grad"
);
grad_op
->
SetType
(
"softmax_with_cross_entropy_grad"
);
grad_op
->
SetInput
(
"Label"
,
Input
(
"Label"
));
grad_op
->
SetInput
(
"Label"
,
Input
(
"Label"
));
grad_op
->
SetInput
(
"Softmax"
,
Output
(
"Softmax"
));
grad_op
->
SetInput
(
"Softmax"
,
Output
(
"Softmax"
));
...
@@ -183,7 +183,7 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -183,7 +183,7 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Loss"
),
OutputGrad
(
"Loss"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Loss"
),
OutputGrad
(
"Loss"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"Logits"
),
InputGrad
(
"Logits"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"Logits"
),
InputGrad
(
"Logits"
));
grad_op
->
SetAttrMap
(
Attrs
());
grad_op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad_op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
}
};
};
...
...
paddle/operators/split_lod_tensor_op.cc
浏览文件 @
09189732
...
@@ -163,8 +163,8 @@ class SplitLoDTensorArrayGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -163,8 +163,8 @@ class SplitLoDTensorArrayGradMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
Bind
();
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"merge_lod_tensor"
);
grad_op
->
SetType
(
"merge_lod_tensor"
);
grad_op
->
SetInput
(
"InTrue"
,
OutputGrad
(
"OutTrue"
));
grad_op
->
SetInput
(
"InTrue"
,
OutputGrad
(
"OutTrue"
));
grad_op
->
SetInput
(
"InFalse"
,
OutputGrad
(
"OutFalse"
));
grad_op
->
SetInput
(
"InFalse"
,
OutputGrad
(
"OutFalse"
));
...
@@ -172,7 +172,7 @@ class SplitLoDTensorArrayGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -172,7 +172,7 @@ class SplitLoDTensorArrayGradMaker : public framework::SingleGradOpDescMaker {
grad_op
->
SetInput
(
"X"
,
Input
(
"X"
));
grad_op
->
SetInput
(
"X"
,
Input
(
"X"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetAttrMap
(
Attrs
());
grad_op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad_op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
}
};
};
...
...
paddle/operators/split_op.cc
浏览文件 @
09189732
...
@@ -108,13 +108,13 @@ class SplitGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -108,13 +108,13 @@ class SplitGradMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
op
=
new
framework
::
OpDesc
Bind
();
auto
op
=
new
framework
::
OpDesc
();
op
->
SetType
(
"concat"
);
op
->
SetType
(
"concat"
);
op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
op
->
SetAttrMap
(
Attrs
());
op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
op
);
}
}
};
};
...
...
paddle/operators/sum_op.cc
浏览文件 @
09189732
...
@@ -115,8 +115,8 @@ the LoD information with the first input.
...
@@ -115,8 +115,8 @@ the LoD information with the first input.
class
SumOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
SumOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
const
framework
::
OpDesc
Bind
&
op_desc
,
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
Bind
*
block
)
const
override
{
framework
::
BlockDesc
*
block
)
const
override
{
auto
&
inputs
=
op_desc
.
Input
(
"X"
);
auto
&
inputs
=
op_desc
.
Input
(
"X"
);
auto
var_type
=
framework
::
proto
::
VarDesc
::
SELECTED_ROWS
;
auto
var_type
=
framework
::
proto
::
VarDesc
::
SELECTED_ROWS
;
...
@@ -169,20 +169,19 @@ class SumGradMaker : public framework::GradOpDescMakerBase {
...
@@ -169,20 +169,19 @@ class SumGradMaker : public framework::GradOpDescMakerBase {
public:
public:
using
framework
::
GradOpDescMakerBase
::
GradOpDescMakerBase
;
using
framework
::
GradOpDescMakerBase
::
GradOpDescMakerBase
;
std
::
vector
<
std
::
unique_ptr
<
framework
::
OpDescBind
>>
operator
()()
std
::
vector
<
std
::
unique_ptr
<
framework
::
OpDesc
>>
operator
()()
const
override
{
const
override
{
auto
x_grads
=
InputGrad
(
"X"
);
auto
x_grads
=
InputGrad
(
"X"
);
std
::
vector
<
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>>
grad_ops
;
std
::
vector
<
std
::
unique_ptr
<
framework
::
OpDesc
>>
grad_ops
;
grad_ops
.
reserve
(
x_grads
.
size
());
grad_ops
.
reserve
(
x_grads
.
size
());
auto
og
=
OutputGrad
(
"Out"
);
auto
og
=
OutputGrad
(
"Out"
);
std
::
transform
(
x_grads
.
begin
(),
x_grads
.
end
(),
std
::
back_inserter
(
grad_ops
),
std
::
transform
(
x_grads
.
begin
(),
x_grads
.
end
(),
std
::
back_inserter
(
grad_ops
),
[
&
og
](
const
std
::
string
&
x_grad
)
{
[
&
og
](
const
std
::
string
&
x_grad
)
{
auto
*
grad_op
=
new
framework
::
OpDesc
Bind
();
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"scale"
);
grad_op
->
SetType
(
"scale"
);
grad_op
->
SetInput
(
"X"
,
og
);
grad_op
->
SetInput
(
"X"
,
og
);
grad_op
->
SetOutput
(
"Out"
,
{
x_grad
});
grad_op
->
SetOutput
(
"Out"
,
{
x_grad
});
grad_op
->
SetAttr
(
"scale"
,
1.0
f
);
grad_op
->
SetAttr
(
"scale"
,
1.0
f
);
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad_op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
});
});
return
grad_ops
;
return
grad_ops
;
}
}
...
...
paddle/operators/tensor_array_read_write_op.cc
浏览文件 @
09189732
...
@@ -96,8 +96,8 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
...
@@ -96,8 +96,8 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
class
WriteToArrayInferVarType
:
public
framework
::
VarTypeInference
{
class
WriteToArrayInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
const
framework
::
OpDesc
Bind
&
op_desc
,
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
Bind
*
block
)
const
override
{
framework
::
BlockDesc
*
block
)
const
override
{
auto
x_name
=
op_desc
.
Input
(
"X"
)[
0
];
auto
x_name
=
op_desc
.
Input
(
"X"
)[
0
];
auto
out_name
=
op_desc
.
Output
(
"Out"
)[
0
];
auto
out_name
=
op_desc
.
Output
(
"Out"
)[
0
];
VLOG
(
10
)
<<
"Set Variable "
<<
out_name
<<
" as LOD_TENSOR_ARRAY"
;
VLOG
(
10
)
<<
"Set Variable "
<<
out_name
<<
" as LOD_TENSOR_ARRAY"
;
...
@@ -175,14 +175,14 @@ class WriteToArrayGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -175,14 +175,14 @@ class WriteToArrayGradMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
Bind
();
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"read_from_array"
);
grad_op
->
SetType
(
"read_from_array"
);
grad_op
->
SetInput
(
"I"
,
Input
(
"I"
));
grad_op
->
SetInput
(
"I"
,
Input
(
"I"
));
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetAttrMap
(
Attrs
());
grad_op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad_op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
}
};
};
...
@@ -191,14 +191,14 @@ class ReadFromArrayGradMaker : public framework::SingleGradOpDescMaker {
...
@@ -191,14 +191,14 @@ class ReadFromArrayGradMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad_op
=
new
framework
::
OpDesc
Bind
();
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"write_to_array"
);
grad_op
->
SetType
(
"write_to_array"
);
grad_op
->
SetInput
(
"I"
,
Input
(
"I"
));
grad_op
->
SetInput
(
"I"
,
Input
(
"I"
));
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetOutput
(
"Out"
,
InputGrad
(
"X"
));
grad_op
->
SetAttrMap
(
Attrs
());
grad_op
->
SetAttrMap
(
Attrs
());
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad_op
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad_op
);
}
}
};
};
...
...
paddle/operators/while_op.cc
浏览文件 @
09189732
...
@@ -46,7 +46,7 @@ class WhileOp : public framework::OperatorBase {
...
@@ -46,7 +46,7 @@ class WhileOp : public framework::OperatorBase {
PADDLE_ENFORCE_EQ
(
cond
.
dims
(),
paddle
::
framework
::
make_ddim
({
1
}));
PADDLE_ENFORCE_EQ
(
cond
.
dims
(),
paddle
::
framework
::
make_ddim
({
1
}));
framework
::
Executor
executor
(
dev_ctx
);
framework
::
Executor
executor
(
dev_ctx
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
Bind
*>
(
kStepBlock
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kStepBlock
);
auto
*
program
=
block
->
Program
();
auto
*
program
=
block
->
Program
();
auto
step_scopes
=
auto
step_scopes
=
...
@@ -82,7 +82,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -82,7 +82,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
"(StepScopeVar) A vector of local scope, which size equals the "
"(StepScopeVar) A vector of local scope, which size equals the "
"step number of While Op. The i'th scope storages temporary "
"step number of While Op. The i'th scope storages temporary "
"variables generated in the i'th step."
);
"variables generated in the i'th step."
);
AddAttr
<
framework
::
BlockDesc
Bind
*>
(
kStepBlock
,
AddAttr
<
framework
::
BlockDesc
*>
(
kStepBlock
,
"The step block inside WhileOp"
);
"The step block inside WhileOp"
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
)DOC"
);
)DOC"
);
...
@@ -99,7 +99,7 @@ class WhileGradOp : public framework::OperatorBase {
...
@@ -99,7 +99,7 @@ class WhileGradOp : public framework::OperatorBase {
void
Run
(
const
framework
::
Scope
&
scope
,
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
framework
::
Executor
executor
(
dev_ctx
);
framework
::
Executor
executor
(
dev_ctx
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
Bind
*>
(
kStepBlock
);
auto
*
block
=
Attr
<
framework
::
BlockDesc
*>
(
kStepBlock
);
auto
*
program
=
block
->
Program
();
auto
*
program
=
block
->
Program
();
auto
*
step_scopes
=
auto
*
step_scopes
=
...
@@ -209,8 +209,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
...
@@ -209,8 +209,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
protected:
protected:
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
Apply
()
const
override
{
std
::
unique_ptr
<
framework
::
OpDesc
>
Apply
()
const
override
{
auto
*
grad
=
new
framework
::
OpDesc
Bind
();
auto
*
grad
=
new
framework
::
OpDesc
();
grad
->
SetType
(
"while_grad"
);
grad
->
SetType
(
"while_grad"
);
grad
->
SetInput
(
kParameters
,
Input
(
kParameters
));
grad
->
SetInput
(
kParameters
,
Input
(
kParameters
));
...
@@ -279,14 +279,14 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
...
@@ -279,14 +279,14 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
// while operator could be renamed.
// while operator could be renamed.
grad
->
SetAttr
(
"original_output_grad"
,
extra_inputs_list
);
grad
->
SetAttr
(
"original_output_grad"
,
extra_inputs_list
);
return
std
::
unique_ptr
<
framework
::
OpDesc
Bind
>
(
grad
);
return
std
::
unique_ptr
<
framework
::
OpDesc
>
(
grad
);
}
}
};
};
class
WhileGradOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
WhileGradOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
const
framework
::
OpDesc
Bind
&
op_desc
,
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
Bind
*
block
)
const
override
{
framework
::
BlockDesc
*
block
)
const
override
{
auto
p_names
=
op_desc
.
Input
(
kParameters
);
auto
p_names
=
op_desc
.
Input
(
kParameters
);
auto
pg_names
=
op_desc
.
Output
(
framework
::
GradVarName
(
kParameters
));
auto
pg_names
=
op_desc
.
Output
(
framework
::
GradVarName
(
kParameters
));
...
...
paddle/pybind/protobuf.cc
浏览文件 @
09189732
...
@@ -108,21 +108,21 @@ static py::bytes SerializeMessage(T &self) {
...
@@ -108,21 +108,21 @@ static py::bytes SerializeMessage(T &self) {
// Bind Methods
// Bind Methods
void
BindProgramDesc
(
py
::
module
&
m
)
{
void
BindProgramDesc
(
py
::
module
&
m
)
{
py
::
class_
<
ProgramDesc
Bind
>
(
m
,
"ProgramDesc"
,
""
)
py
::
class_
<
ProgramDesc
>
(
m
,
"ProgramDesc"
,
""
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
<>
())
.
def
(
"__init__"
,
.
def
(
"__init__"
,
[](
ProgramDesc
Bind
&
self
,
const
ProgramDescBind
&
other
)
{
[](
ProgramDesc
&
self
,
const
ProgramDesc
&
other
)
{
new
(
&
self
)
ProgramDesc
Bind
(
other
);
new
(
&
self
)
ProgramDesc
(
other
);
})
})
.
def
(
"__init__"
,
.
def
(
"__init__"
,
[](
ProgramDesc
Bind
&
self
,
const
py
::
bytes
&
binary_str
)
{
[](
ProgramDesc
&
self
,
const
py
::
bytes
&
binary_str
)
{
std
::
string
str
(
binary_str
);
std
::
string
str
(
binary_str
);
new
(
&
self
)
ProgramDesc
Bind
(
str
);
new
(
&
self
)
ProgramDesc
(
str
);
})
})
.
def
(
"append_block"
,
&
ProgramDesc
Bind
::
AppendBlock
,
.
def
(
"append_block"
,
&
ProgramDesc
::
AppendBlock
,
py
::
return_value_policy
::
reference
)
py
::
return_value_policy
::
reference
)
.
def
(
"append_backward"
,
.
def
(
"append_backward"
,
[](
ProgramDesc
Bind
&
program_desc
,
const
VarDescBind
&
target
,
[](
ProgramDesc
&
program_desc
,
const
VarDesc
&
target
,
const
std
::
unordered_set
<
std
::
string
>
&
no_grad_vars
)
{
const
std
::
unordered_set
<
std
::
string
>
&
no_grad_vars
)
{
ParamGradInfoMap
param_grad_map
=
ParamGradInfoMap
param_grad_map
=
AppendBackward
(
program_desc
,
target
,
no_grad_vars
);
AppendBackward
(
program_desc
,
target
,
no_grad_vars
);
...
@@ -138,12 +138,12 @@ void BindProgramDesc(py::module &m) {
...
@@ -138,12 +138,12 @@ void BindProgramDesc(py::module &m) {
}
}
return
retv
;
return
retv
;
})
})
.
def
(
"block"
,
&
ProgramDesc
Bind
::
MutableBlock
,
.
def
(
"block"
,
&
ProgramDesc
::
MutableBlock
,
py
::
return_value_policy
::
reference
)
py
::
return_value_policy
::
reference
)
.
def
(
"num_blocks"
,
&
ProgramDesc
Bind
::
Size
)
.
def
(
"num_blocks"
,
&
ProgramDesc
::
Size
)
.
def
(
"serialize_to_string"
,
SerializeMessage
<
ProgramDesc
Bind
>
)
.
def
(
"serialize_to_string"
,
SerializeMessage
<
ProgramDesc
>
)
.
def
(
"parse_from_string"
,
.
def
(
"parse_from_string"
,
[](
ProgramDesc
Bind
&
program_desc
,
const
std
::
string
&
data
)
{
[](
ProgramDesc
&
program_desc
,
const
std
::
string
&
data
)
{
proto
::
ProgramDesc
*
desc
=
program_desc
.
Proto
();
proto
::
ProgramDesc
*
desc
=
program_desc
.
Proto
();
PADDLE_ENFORCE
(
desc
->
ParseFromString
(
data
),
PADDLE_ENFORCE
(
desc
->
ParseFromString
(
data
),
"Fail to parse ProgramDesc from string. This could "
"Fail to parse ProgramDesc from string. This could "
...
@@ -152,35 +152,34 @@ void BindProgramDesc(py::module &m) {
...
@@ -152,35 +152,34 @@ void BindProgramDesc(py::module &m) {
}
}
void
BindBlockDesc
(
py
::
module
&
m
)
{
void
BindBlockDesc
(
py
::
module
&
m
)
{
py
::
class_
<
BlockDesc
Bind
>
(
m
,
"BlockDesc"
,
""
)
py
::
class_
<
BlockDesc
>
(
m
,
"BlockDesc"
,
""
)
.
def_property_readonly
(
"id"
,
&
BlockDesc
Bind
::
ID
)
.
def_property_readonly
(
"id"
,
&
BlockDesc
::
ID
)
.
def_property_readonly
(
"parent"
,
&
BlockDesc
Bind
::
Parent
)
.
def_property_readonly
(
"parent"
,
&
BlockDesc
::
Parent
)
.
def
(
"append_op"
,
&
BlockDesc
Bind
::
AppendOp
,
.
def
(
"append_op"
,
&
BlockDesc
::
AppendOp
,
py
::
return_value_policy
::
reference
)
py
::
return_value_policy
::
reference
)
.
def
(
"prepend_op"
,
&
BlockDesc
Bind
::
PrependOp
,
.
def
(
"prepend_op"
,
&
BlockDesc
::
PrependOp
,
py
::
return_value_policy
::
reference
)
py
::
return_value_policy
::
reference
)
.
def
(
"var"
,
.
def
(
"var"
,
[](
BlockDesc
Bind
&
self
,
py
::
bytes
byte_name
)
{
[](
BlockDesc
&
self
,
py
::
bytes
byte_name
)
{
std
::
string
name
=
byte_name
;
std
::
string
name
=
byte_name
;
return
self
.
Var
(
name
);
return
self
.
Var
(
name
);
},
},
py
::
return_value_policy
::
reference
)
py
::
return_value_policy
::
reference
)
.
def
(
"has_var"
,
.
def
(
"has_var"
,
[](
BlockDesc
Bind
&
self
,
py
::
bytes
byte_name
)
{
[](
BlockDesc
&
self
,
py
::
bytes
byte_name
)
{
std
::
string
name
=
byte_name
;
std
::
string
name
=
byte_name
;
return
self
.
HasVar
(
name
);
return
self
.
HasVar
(
name
);
})
})
.
def
(
"find_var"
,
.
def
(
"find_var"
,
[](
BlockDesc
Bind
&
self
,
py
::
bytes
byte_name
)
{
[](
BlockDesc
&
self
,
py
::
bytes
byte_name
)
{
std
::
string
name
=
byte_name
;
std
::
string
name
=
byte_name
;
return
self
.
FindVar
(
name
);
return
self
.
FindVar
(
name
);
},
},
py
::
return_value_policy
::
reference
)
py
::
return_value_policy
::
reference
)
.
def
(
"all_vars"
,
&
BlockDescBind
::
AllVars
,
.
def
(
"all_vars"
,
&
BlockDesc
::
AllVars
,
py
::
return_value_policy
::
reference
)
py
::
return_value_policy
::
reference
)
.
def
(
"op_size"
,
&
BlockDesc
::
OpSize
)
.
def
(
"op_size"
,
&
BlockDescBind
::
OpSize
)
.
def
(
"op"
,
&
BlockDesc
::
Op
,
py
::
return_value_policy
::
reference
)
.
def
(
"op"
,
&
BlockDescBind
::
Op
,
py
::
return_value_policy
::
reference
)
.
def
(
"serialize_to_string"
,
SerializeMessage
<
BlockDesc
>
);
.
def
(
"serialize_to_string"
,
SerializeMessage
<
BlockDescBind
>
);
}
}
void
BindVarDsec
(
py
::
module
&
m
)
{
void
BindVarDsec
(
py
::
module
&
m
)
{
...
@@ -193,25 +192,25 @@ void BindVarDsec(py::module &m) {
...
@@ -193,25 +192,25 @@ void BindVarDsec(py::module &m) {
.
value
(
"FP32"
,
proto
::
DataType
::
FP32
)
.
value
(
"FP32"
,
proto
::
DataType
::
FP32
)
.
value
(
"FP64"
,
proto
::
DataType
::
FP64
);
.
value
(
"FP64"
,
proto
::
DataType
::
FP64
);
py
::
class_
<
VarDesc
Bind
>
var_desc
(
m
,
"VarDesc"
,
""
);
py
::
class_
<
VarDesc
>
var_desc
(
m
,
"VarDesc"
,
""
);
var_desc
var_desc
.
def
(
"name"
,
.
def
(
"name"
,
[](
const
VarDesc
Bind
&
self
)
{
[](
const
VarDesc
&
self
)
{
py
::
bytes
name
=
self
.
Name
();
py
::
bytes
name
=
self
.
Name
();
return
name
;
return
name
;
},
},
py
::
return_value_policy
::
reference
)
py
::
return_value_policy
::
reference
)
.
def
(
"set_shape"
,
&
VarDesc
Bind
::
SetShape
)
.
def
(
"set_shape"
,
&
VarDesc
::
SetShape
)
.
def
(
"set_dtype"
,
&
VarDesc
Bind
::
SetDataType
)
.
def
(
"set_dtype"
,
&
VarDesc
::
SetDataType
)
.
def
(
"shape"
,
&
VarDesc
Bind
::
Shape
,
py
::
return_value_policy
::
reference
)
.
def
(
"shape"
,
&
VarDesc
::
Shape
,
py
::
return_value_policy
::
reference
)
.
def
(
"dtype"
,
&
VarDesc
Bind
::
GetDataType
)
.
def
(
"dtype"
,
&
VarDesc
::
GetDataType
)
.
def
(
"lod_level"
,
&
VarDesc
Bind
::
GetLodLevel
)
.
def
(
"lod_level"
,
&
VarDesc
::
GetLodLevel
)
.
def
(
"set_lod_level"
,
&
VarDesc
Bind
::
SetLoDLevel
)
.
def
(
"set_lod_level"
,
&
VarDesc
::
SetLoDLevel
)
.
def
(
"type"
,
&
VarDesc
Bind
::
GetType
)
.
def
(
"type"
,
&
VarDesc
::
GetType
)
.
def
(
"set_type"
,
&
VarDesc
Bind
::
SetType
)
.
def
(
"set_type"
,
&
VarDesc
::
SetType
)
.
def
(
"serialize_to_string"
,
SerializeMessage
<
VarDesc
Bind
>
)
.
def
(
"serialize_to_string"
,
SerializeMessage
<
VarDesc
>
)
.
def
(
"persistable"
,
&
VarDesc
Bind
::
Persistable
)
.
def
(
"persistable"
,
&
VarDesc
::
Persistable
)
.
def
(
"set_persistable"
,
&
VarDesc
Bind
::
SetPersistable
);
.
def
(
"set_persistable"
,
&
VarDesc
::
SetPersistable
);
py
::
enum_
<
proto
::
VarDesc
::
VarType
>
(
var_desc
,
"VarType"
,
""
)
py
::
enum_
<
proto
::
VarDesc
::
VarType
>
(
var_desc
,
"VarType"
,
""
)
.
value
(
"LOD_TENSOR"
,
proto
::
VarDesc
::
LOD_TENSOR
)
.
value
(
"LOD_TENSOR"
,
proto
::
VarDesc
::
LOD_TENSOR
)
...
@@ -235,26 +234,26 @@ void BindOpDesc(py::module &m) {
...
@@ -235,26 +234,26 @@ void BindOpDesc(py::module &m) {
.
value
(
"BOOLS"
,
proto
::
AttrType
::
BOOLEANS
)
.
value
(
"BOOLS"
,
proto
::
AttrType
::
BOOLEANS
)
.
value
(
"BLOCK"
,
proto
::
AttrType
::
BLOCK
);
.
value
(
"BLOCK"
,
proto
::
AttrType
::
BLOCK
);
py
::
class_
<
OpDesc
Bind
>
op_desc
(
m
,
"OpDesc"
,
""
);
py
::
class_
<
OpDesc
>
op_desc
(
m
,
"OpDesc"
,
""
);
op_desc
.
def
(
"type"
,
&
OpDesc
Bind
::
Type
)
op_desc
.
def
(
"type"
,
&
OpDesc
::
Type
)
.
def
(
"set_type"
,
&
OpDesc
Bind
::
SetType
)
.
def
(
"set_type"
,
&
OpDesc
::
SetType
)
.
def
(
"input"
,
&
OpDesc
Bind
::
Input
)
.
def
(
"input"
,
&
OpDesc
::
Input
)
.
def
(
"input_names"
,
&
OpDesc
Bind
::
InputNames
)
.
def
(
"input_names"
,
&
OpDesc
::
InputNames
)
.
def
(
"set_input"
,
&
OpDesc
Bind
::
SetInput
)
.
def
(
"set_input"
,
&
OpDesc
::
SetInput
)
.
def
(
"output"
,
&
OpDesc
Bind
::
Output
)
.
def
(
"output"
,
&
OpDesc
::
Output
)
.
def
(
"output_names"
,
&
OpDesc
Bind
::
OutputNames
)
.
def
(
"output_names"
,
&
OpDesc
::
OutputNames
)
.
def
(
"set_output"
,
&
OpDesc
Bind
::
SetOutput
)
.
def
(
"set_output"
,
&
OpDesc
::
SetOutput
)
.
def
(
"has_attr"
,
&
OpDesc
Bind
::
HasAttr
)
.
def
(
"has_attr"
,
&
OpDesc
::
HasAttr
)
.
def
(
"attr_type"
,
&
OpDesc
Bind
::
GetAttrType
)
.
def
(
"attr_type"
,
&
OpDesc
::
GetAttrType
)
.
def
(
"attr_names"
,
&
OpDesc
Bind
::
AttrNames
)
.
def
(
"attr_names"
,
&
OpDesc
::
AttrNames
)
.
def
(
"set_attr"
,
&
OpDesc
Bind
::
SetAttr
)
.
def
(
"set_attr"
,
&
OpDesc
::
SetAttr
)
.
def
(
"attr"
,
&
OpDesc
Bind
::
GetAttr
)
.
def
(
"attr"
,
&
OpDesc
::
GetAttr
)
.
def
(
"set_block_attr"
,
&
OpDesc
Bind
::
SetBlockAttr
)
.
def
(
"set_block_attr"
,
&
OpDesc
::
SetBlockAttr
)
.
def
(
"block_attr"
,
&
OpDesc
Bind
::
GetBlockAttr
)
.
def
(
"block_attr"
,
&
OpDesc
::
GetBlockAttr
)
.
def
(
"check_attrs"
,
&
OpDesc
Bind
::
CheckAttrs
)
.
def
(
"check_attrs"
,
&
OpDesc
::
CheckAttrs
)
.
def
(
"infer_shape"
,
&
OpDesc
Bind
::
InferShape
)
.
def
(
"infer_shape"
,
&
OpDesc
::
InferShape
)
.
def
(
"infer_var_type"
,
&
OpDesc
Bind
::
InferVarType
)
.
def
(
"infer_var_type"
,
&
OpDesc
::
InferVarType
)
.
def
(
"serialize_to_string"
,
SerializeMessage
<
OpDesc
Bind
>
);
.
def
(
"serialize_to_string"
,
SerializeMessage
<
OpDesc
>
);
}
}
}
// namespace pybind
}
// namespace pybind
...
...
paddle/pybind/pybind.cc
浏览文件 @
09189732
...
@@ -266,36 +266,36 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -266,36 +266,36 @@ All parameter, weight, gradient are variables in Paddle.
return
ret_values
;
return
ret_values
;
});
});
m
.
def
(
"get_grad_op_descs"
,
m
.
def
(
"get_grad_op_descs"
,
[](
const
OpDesc
Bind
&
op_desc
,
[](
const
OpDesc
&
op_desc
,
const
std
::
unordered_set
<
std
::
string
>
&
no_grad_set
,
const
std
::
unordered_set
<
std
::
string
>
&
no_grad_set
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
grad_to_var
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>
&
grad_to_var
,
const
std
::
vector
<
BlockDesc
Bind
*>
&
grad_sub_block
)
{
const
std
::
vector
<
BlockDesc
*>
&
grad_sub_block
)
{
std
::
vector
<
std
::
unique_ptr
<
OpDesc
Bind
>>
grad_op_descs
=
std
::
vector
<
std
::
unique_ptr
<
OpDesc
>>
grad_op_descs
=
framework
::
OpInfoMap
::
Instance
()
framework
::
OpInfoMap
::
Instance
()
.
Get
(
op_desc
.
Type
())
.
Get
(
op_desc
.
Type
())
.
GradOpMaker
()(
op_desc
,
no_grad_set
,
&
grad_to_var
,
.
GradOpMaker
()(
op_desc
,
no_grad_set
,
&
grad_to_var
,
grad_sub_block
);
grad_sub_block
);
std
::
vector
<
OpDesc
Bind
*>
grad_op_desc_ptrs
(
grad_op_descs
.
size
());
std
::
vector
<
OpDesc
*>
grad_op_desc_ptrs
(
grad_op_descs
.
size
());
std
::
transform
(
std
::
transform
(
grad_op_descs
.
begin
(),
grad_op_descs
.
end
(),
grad_op_descs
.
begin
(),
grad_op_descs
.
end
(),
grad_op_desc_ptrs
.
begin
(),
grad_op_desc_ptrs
.
begin
(),
[](
std
::
unique_ptr
<
OpDesc
Bind
>
&
p
)
{
return
p
.
release
();
});
[](
std
::
unique_ptr
<
OpDesc
>
&
p
)
{
return
p
.
release
();
});
return
grad_op_desc_ptrs
;
return
grad_op_desc_ptrs
;
});
});
m
.
def
(
"prune"
,
[](
const
ProgramDesc
Bind
&
origin
,
m
.
def
(
"prune"
,
[](
const
ProgramDesc
&
origin
,
const
std
::
vector
<
std
::
array
<
size_t
,
2
>>
&
targets
)
{
const
std
::
vector
<
std
::
array
<
size_t
,
2
>>
&
targets
)
{
ProgramDesc
Bind
prog_with_targets
(
origin
);
ProgramDesc
prog_with_targets
(
origin
);
for
(
const
auto
&
t
:
targets
)
{
for
(
const
auto
&
t
:
targets
)
{
prog_with_targets
.
MutableBlock
(
t
[
0
])
->
Op
(
t
[
1
])
->
MarkAsTarget
();
prog_with_targets
.
MutableBlock
(
t
[
0
])
->
Op
(
t
[
1
])
->
MarkAsTarget
();
}
}
proto
::
ProgramDesc
pruned_desc
;
proto
::
ProgramDesc
pruned_desc
;
Prune
(
*
prog_with_targets
.
Proto
(),
&
pruned_desc
);
Prune
(
*
prog_with_targets
.
Proto
(),
&
pruned_desc
);
return
new
ProgramDesc
Bind
(
pruned_desc
);
return
new
ProgramDesc
(
pruned_desc
);
});
});
m
.
def
(
"inference_optimize"
,
[](
ProgramDesc
Bind
&
origin
)
{
m
.
def
(
"inference_optimize"
,
[](
ProgramDesc
&
origin
)
{
proto
::
ProgramDesc
pruned_desc
;
proto
::
ProgramDesc
pruned_desc
;
InferenceOptimize
(
*
(
origin
.
Proto
()),
&
pruned_desc
);
InferenceOptimize
(
*
(
origin
.
Proto
()),
&
pruned_desc
);
return
new
ProgramDesc
Bind
(
pruned_desc
);
return
new
ProgramDesc
(
pruned_desc
);
});
});
m
.
def_submodule
(
m
.
def_submodule
(
"var_names"
,
"var_names"
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录