Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
704245ae
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
704245ae
编写于
9月 28, 2017
作者:
F
fengjiayi
提交者:
GitHub
9月 28, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4486 from Canpio/dev_backward_for_op_desc
Add grad_op_desc_builder
上级
01d9c465
099b2c19
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
139 addition
and
1 deletion
+139
-1
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+1
-1
paddle/framework/grad_op_builder.cc
paddle/framework/grad_op_builder.cc
+39
-0
paddle/framework/grad_op_builder.h
paddle/framework/grad_op_builder.h
+3
-0
paddle/framework/grad_op_builder_test.cc
paddle/framework/grad_op_builder_test.cc
+79
-0
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+11
-0
paddle/framework/op_desc.h
paddle/framework/op_desc.h
+6
-0
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
704245ae
...
...
@@ -26,7 +26,7 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library
(
operator SRCS operator.cc DEPS op_info device_context tensor scope
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_library
(
grad_op_builder SRCS grad_op_builder.cc DEPS operator
)
cc_library
(
grad_op_builder SRCS grad_op_builder.cc DEPS operator
proto_desc
)
cc_library
(
op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
cc_test
(
grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op
)
...
...
paddle/framework/grad_op_builder.cc
浏览文件 @
704245ae
...
...
@@ -54,5 +54,44 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
return
grad_info
.
Creator
()(
info
.
grad_op_type_
,
inputs
,
outputs
,
op
->
Attrs
());
}
static
void
TransOpDescArg
(
const
OpDescBind
*
src_op
,
const
OpArgType
&
src_type
,
bool
is_grad
,
OpDescBind
*
dst_op
,
const
OpArgType
&
dst_type
)
{
PADDLE_ENFORCE
(
dst_op
!=
nullptr
,
"Protobuf desc of gradient op must be initialized first."
);
const
auto
&
proto
=
OpInfoMap
::
Instance
().
Get
(
src_op
->
Type
()).
Proto
();
const
auto
&
src_arg_list
=
src_type
==
OpArgType
::
IN
?
proto
.
inputs
()
:
proto
.
outputs
();
for
(
const
auto
&
arg
:
src_arg_list
)
{
if
(
arg
.
not_in_gradient
()
&&
!
is_grad
)
continue
;
const
std
::
string
src_name
=
arg
.
name
();
std
::
vector
<
std
::
string
>
vars
=
src_type
==
OpArgType
::
IN
?
src_op
->
Input
(
src_name
)
:
src_op
->
Output
(
src_name
);
if
(
is_grad
)
{
for
(
std
::
string
&
var
:
vars
)
{
var
=
GradVarName
(
var
);
}
}
std
::
string
dst_name
=
is_grad
?
GradVarName
(
src_name
)
:
src_name
;
dst_type
==
OpArgType
::
IN
?
dst_op
->
SetInput
(
dst_name
,
vars
)
:
dst_op
->
SetOutput
(
dst_name
,
vars
);
}
}
void
CompleteGradOpDesc
(
const
OpDescBind
*
forw_op
,
OpDescBind
*
grad_op
)
{
auto
&
info
=
OpInfoMap
::
Instance
().
Get
(
forw_op
->
Type
());
PADDLE_ENFORCE
(
info
.
HasGradientOp
());
grad_op
->
SetType
(
info
.
grad_op_type_
);
TransOpDescArg
(
forw_op
,
OpArgType
::
IN
,
false
,
grad_op
,
OpArgType
::
IN
);
TransOpDescArg
(
forw_op
,
OpArgType
::
OUT
,
false
,
grad_op
,
OpArgType
::
IN
);
TransOpDescArg
(
forw_op
,
OpArgType
::
OUT
,
true
,
grad_op
,
OpArgType
::
IN
);
TransOpDescArg
(
forw_op
,
OpArgType
::
IN
,
true
,
grad_op
,
OpArgType
::
OUT
);
grad_op
->
SetAttrMap
(
forw_op
->
GetAttrMap
());
}
}
// namespace framework
}
// namespace paddle
paddle/framework/grad_op_builder.h
浏览文件 @
704245ae
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
...
...
@@ -21,5 +22,7 @@ namespace framework {
OperatorBase
*
BuildGradOp
(
const
OperatorBase
*
op
);
void
CompleteGradOpDesc
(
const
OpDescBind
*
forw_op
,
OpDescBind
*
grad_op
);
}
// namespace framework
}
// namespace paddle
paddle/framework/grad_op_builder_test.cc
浏览文件 @
704245ae
...
...
@@ -120,3 +120,82 @@ TEST(GradOpBuilder, IOIgnoredInGradient) {
std
::
vector
<
std
::
string
>
(
{
f
::
GradVarName
(
"in3_1"
),
f
::
GradVarName
(
"in3_2"
)}));
}
TEST
(
GradOpDescBuilder
,
MutiInOut
)
{
f
::
OpDescBind
*
forw_op
=
new
f
::
OpDescBind
();
forw_op
->
SetType
(
"mult_io"
);
forw_op
->
SetInput
(
"In1"
,
{
"in1"
});
forw_op
->
SetInput
(
"In2_mult"
,
{
"in2_1"
,
"in2_2"
,
"in2_3"
});
forw_op
->
SetInput
(
"In3"
,
{
"in3"
});
forw_op
->
SetOutput
(
"Out1"
,
{
"out1"
});
forw_op
->
SetOutput
(
"Out2_mult"
,
{
"out2_1"
,
"out2_2"
});
f
::
OpDescBind
*
grad_op
=
new
f
::
OpDescBind
();
f
::
CompleteGradOpDesc
(
forw_op
,
grad_op
);
EXPECT_EQ
(
grad_op
->
Type
(),
"mult_io_grad"
);
ASSERT_EQ
(
grad_op
->
InputNames
().
size
(),
3UL
+
2UL
+
2UL
);
EXPECT_EQ
(
grad_op
->
Input
(
"In1"
),
std
::
vector
<
std
::
string
>
({
"in1"
}));
EXPECT_EQ
(
grad_op
->
Input
(
"In2_mult"
),
std
::
vector
<
std
::
string
>
({
"in2_1"
,
"in2_2"
,
"in2_3"
}));
EXPECT_EQ
(
grad_op
->
Input
(
"In3"
),
std
::
vector
<
std
::
string
>
({
"in3"
}));
EXPECT_EQ
(
grad_op
->
Input
(
"Out1"
),
std
::
vector
<
std
::
string
>
({
"out1"
}));
EXPECT_EQ
(
grad_op
->
Input
(
"Out2_mult"
),
std
::
vector
<
std
::
string
>
({
"out2_1"
,
"out2_2"
}));
EXPECT_EQ
(
grad_op
->
Input
(
f
::
GradVarName
(
"Out1"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"out1"
)}));
EXPECT_EQ
(
grad_op
->
Input
(
f
::
GradVarName
(
"Out2_mult"
)),
std
::
vector
<
std
::
string
>
(
{
f
::
GradVarName
(
"out2_1"
),
f
::
GradVarName
(
"out2_2"
)}));
ASSERT_EQ
(
grad_op
->
OutputNames
().
size
(),
3UL
);
EXPECT_EQ
(
grad_op
->
Output
(
f
::
GradVarName
(
"In1"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"in1"
)}));
EXPECT_EQ
(
grad_op
->
Output
(
f
::
GradVarName
(
"In2_mult"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"in2_1"
),
f
::
GradVarName
(
"in2_2"
),
f
::
GradVarName
(
"in2_3"
)}));
EXPECT_EQ
(
grad_op
->
Output
(
f
::
GradVarName
(
"In3"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"in3"
)}));
delete
forw_op
;
delete
grad_op
;
}
TEST
(
GradOpDescBuilder
,
IOIgnoredInGradient
)
{
f
::
OpDescBind
*
forw_op
=
new
f
::
OpDescBind
();
forw_op
->
SetType
(
"io_ignored"
);
forw_op
->
SetInput
(
"In1"
,
{
"in1"
});
forw_op
->
SetInput
(
"In2_mult"
,
{
"in2_1"
,
"in2_2"
});
forw_op
->
SetInput
(
"In3_mult"
,
{
"in3_1"
,
"in3_2"
});
forw_op
->
SetOutput
(
"Out1_mult"
,
{
"out1_1"
,
"out1_2"
});
forw_op
->
SetOutput
(
"Out2"
,
{
"out2"
});
f
::
OpDescBind
*
grad_op
=
new
f
::
OpDescBind
();
f
::
CompleteGradOpDesc
(
forw_op
,
grad_op
);
EXPECT_EQ
(
grad_op
->
Type
(),
"io_ignored_grad"
);
// 'In2' and 'Out2' are ignored in gradient calculating
ASSERT_EQ
(
grad_op
->
InputNames
().
size
(),
2UL
+
1UL
+
2UL
);
EXPECT_EQ
(
grad_op
->
Input
(
"In1"
),
std
::
vector
<
std
::
string
>
({
"in1"
}));
EXPECT_EQ
(
grad_op
->
Input
(
"In3_mult"
),
std
::
vector
<
std
::
string
>
({
"in3_1"
,
"in3_2"
}));
EXPECT_EQ
(
grad_op
->
Input
(
"Out1_mult"
),
std
::
vector
<
std
::
string
>
({
"out1_1"
,
"out1_2"
}));
EXPECT_EQ
(
grad_op
->
Input
(
f
::
GradVarName
(
"Out1_mult"
)),
std
::
vector
<
std
::
string
>
(
{
f
::
GradVarName
(
"out1_1"
),
f
::
GradVarName
(
"out1_2"
)}));
EXPECT_EQ
(
grad_op
->
Input
(
f
::
GradVarName
(
"Out2"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"out2"
)}));
ASSERT_EQ
(
grad_op
->
OutputNames
().
size
(),
3UL
);
EXPECT_EQ
(
grad_op
->
Output
(
f
::
GradVarName
(
"In1"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"in1"
)}));
EXPECT_EQ
(
grad_op
->
Output
(
f
::
GradVarName
(
"In2_mult"
)),
std
::
vector
<
std
::
string
>
(
{
f
::
GradVarName
(
"in2_1"
),
f
::
GradVarName
(
"in2_2"
)}));
EXPECT_EQ
(
grad_op
->
Output
(
f
::
GradVarName
(
"In3_mult"
)),
std
::
vector
<
std
::
string
>
(
{
f
::
GradVarName
(
"in3_1"
),
f
::
GradVarName
(
"in3_2"
)}));
delete
forw_op
;
delete
grad_op
;
}
\ No newline at end of file
paddle/framework/op_desc.cc
浏览文件 @
704245ae
...
...
@@ -89,6 +89,12 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
need_update_
=
true
;
}
void
OpDescBind
::
SetAttrMap
(
const
std
::
unordered_map
<
std
::
string
,
Attribute
>
&
attr_map
)
{
attrs_
=
attr_map
;
need_update_
=
true
;
}
Attribute
OpDescBind
::
GetAttr
(
const
std
::
string
&
name
)
const
{
auto
it
=
attrs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
attrs_
.
end
(),
"Attribute %s is not found"
,
name
);
...
...
@@ -101,6 +107,11 @@ int OpDescBind::GetBlockAttr(const std::string &name) const {
return
boost
::
get
<
BlockDesc
*>
(
it
->
second
)
->
idx
();
}
const
std
::
unordered_map
<
std
::
string
,
Attribute
>
&
OpDescBind
::
GetAttrMap
()
const
{
return
attrs_
;
}
void
OpDescBind
::
Sync
()
{
if
(
need_update_
)
{
this
->
op_desc_
.
mutable_inputs
()
->
Clear
();
...
...
paddle/framework/op_desc.h
浏览文件 @
704245ae
...
...
@@ -60,10 +60,16 @@ class OpDescBind {
void
SetBlockAttr
(
const
std
::
string
&
name
,
BlockDescBind
&
block
);
// Only be used in C++
void
SetAttrMap
(
const
std
::
unordered_map
<
std
::
string
,
Attribute
>
&
attr_map
);
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
;
int
GetBlockAttr
(
const
std
::
string
&
name
)
const
;
// Only be used in C++
const
std
::
unordered_map
<
std
::
string
,
Attribute
>
&
GetAttrMap
()
const
;
private:
struct
SetAttrDescVisitor
:
public
boost
::
static_visitor
<
void
>
{
explicit
SetAttrDescVisitor
(
OpDesc
::
Attr
*
attr
)
:
attr_
(
attr
)
{}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录