Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
f77c63b8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f77c63b8
编写于
7月 26, 2017
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'feature/backward' of
https://github.com/reyoung/Paddle
into feature/backward
上级
e32e3068
831d4e1c
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
152 addition
and
24 deletion
+152
-24
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+1
-1
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+136
-6
paddle/framework/grad_op_builder.cc
paddle/framework/grad_op_builder.cc
+9
-10
paddle/framework/grad_op_builder.h
paddle/framework/grad_op_builder.h
+2
-2
paddle/framework/grad_op_builder_test.cc
paddle/framework/grad_op_builder_test.cc
+1
-1
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+3
-4
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
f77c63b8
...
...
@@ -33,4 +33,4 @@ cc_library(net SRCS net.cc DEPS op_registry)
cc_test
(
net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op
)
cc_library
(
backward SRCS backward.cc DEPS net
)
cc_test
(
backward_test SRCS backward_test.cc DEPS
net
)
cc_test
(
backward_test SRCS backward_test.cc DEPS
backward
)
paddle/framework/backward_test.cc
浏览文件 @
f77c63b8
...
...
@@ -12,8 +12,11 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/backward.h"
#include <gtest/gtest.h>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -24,10 +27,9 @@ class EmptyOp : public OperatorBase {
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
};
class
RowwiseAddOp
:
public
EmptyOp
{};
class
RowwiseAddOpMaker
:
public
OpProtoAndCheckerMaker
{
class
RowWiseAddOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
Row
w
iseAddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
Row
W
iseAddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input X of Add"
).
IgnoreGradient
();
AddInput
(
"b"
,
"Bias of Add"
).
IgnoreGradient
();
...
...
@@ -36,15 +38,143 @@ class RowwiseAddOpMaker : public OpProtoAndCheckerMaker {
}
};
class
RowwiseAddGradOp
:
public
EmptyOp
{};
class
MulOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
MulOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"A"
,
"A"
);
AddInput
(
"B"
,
"B"
);
AddOutput
(
"Out"
,
"Out"
);
AddComment
(
"Mul"
);
}
};
class
SigmoidOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
SigmoidOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"X"
);
AddOutput
(
"Y"
,
"Y"
);
AddComment
(
"Sigmoid"
);
}
};
class
FcOp
:
public
NetOp
{
public:
void
Init
()
override
{
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{
Input
(
"X"
),
Input
(
"W"
)},
{
Output
(
"before_act"
)},
{}));
auto
b_name
=
Input
(
"b"
);
if
(
b_name
!=
EMPTY_VAR_NAME
())
{
AddOp
(
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
Output
(
"before_act"
),
b_name
},
{
Output
(
"before_act"
)},
{}));
}
AddOp
(
OpRegistry
::
CreateOp
(
"sigmoid"
,
{
Output
(
"before_act"
)},
{
Output
(
"Out"
)},
{}));
CompleteAddOp
(
false
);
}
};
class
FcOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
FcOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"x"
);
AddInput
(
"W"
,
"w"
);
AddInput
(
"b"
,
"b"
);
AddOutput
(
"before_act"
,
"before act"
).
SetTemporary
();
AddOutput
(
"Out"
,
""
);
AddComment
(
""
);
}
};
class
ManyOutputOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
ManyOutputOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"x"
,
"x"
);
AddOutput
(
"y"
,
"y"
);
AddOutput
(
"z"
,
"z"
);
AddComment
(
""
);
}
};
class
FillZeroOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
FillZeroOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"x"
,
"x"
);
AddOutput
(
"out"
,
"out"
);
AddComment
(
""
);
}
};
}
// namespace framework
}
// namespace paddle
namespace
f
=
paddle
::
framework
;
REGISTER_OP
(
rowwise_add
,
f
::
RowwiseAddOp
,
f
::
RowwiseAddOpMaker
);
REGISTER_GRADIENT_OP
(
rowwise_add
,
rowwise_add_grad
,
f
::
RowwiseAddGradOp
);
using
EnforceNotMet
=
paddle
::
platform
::
EnforceNotMet
;
REGISTER_OP
(
rowwise_add
,
f
::
EmptyOp
,
f
::
RowWiseAddOpMaker
);
REGISTER_GRADIENT_OP
(
rowwise_add
,
rowwise_add_grad
,
f
::
EmptyOp
);
REGISTER_OP
(
mul
,
f
::
EmptyOp
,
f
::
MulOpMaker
);
REGISTER_GRADIENT_OP
(
mul
,
mul_grad
,
f
::
EmptyOp
);
REGISTER_OP
(
sigmoid
,
f
::
EmptyOp
,
f
::
SigmoidOpMaker
);
REGISTER_GRADIENT_OP
(
sigmoid
,
sigmoid_grad
,
f
::
EmptyOp
);
REGISTER_OP
(
fc
,
f
::
FcOp
,
f
::
FcOpMaker
);
REGISTER_OP
(
many_output_op
,
f
::
EmptyOp
,
f
::
ManyOutputOpMaker
);
REGISTER_GRADIENT_OP
(
many_output_op
,
many_output_op_grad
,
f
::
EmptyOp
);
REGISTER_OP
(
fill_zeros_like
,
f
::
EmptyOp
,
f
::
FillZeroOpMaker
);
TEST
(
Backward
,
simple_grad
)
{
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
"X"
,
"b"
},
{
"Out"
},
{});
ASSERT_NE
(
fwd
,
nullptr
);
auto
gop
=
f
::
OpRegistry
::
CreateGradOp
(
*
fwd
);
ASSERT_EQ
(
"Out"
+
f
::
OperatorBase
::
GRAD_VAR_SUFFIX
(),
gop
->
inputs_
[
0
]);
ASSERT_EQ
(
"rowwise_add_grad"
,
gop
->
type_
);
ASSERT_EQ
(
"X"
+
f
::
OperatorBase
::
GRAD_VAR_SUFFIX
(),
gop
->
outputs_
[
0
]);
ASSERT_EQ
(
"b"
+
f
::
OperatorBase
::
GRAD_VAR_SUFFIX
(),
gop
->
outputs_
[
1
]);
// LOG(INFO) << gop->Output("X" + "@GRAD");
}
TEST
(
Backward
,
not_for_network
)
{
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"X"
,
"W"
,
"b"
},
{
"Out"
,
"tmp_out"
},
{{
"temporary_index"
,
std
::
vector
<
int
>
{
1
}}});
ASSERT_THROW
(
f
::
OpRegistry
::
CreateGradOp
(
*
fwd
),
EnforceNotMet
);
}
TEST
(
Backward
,
all_input_are_not_need
)
{
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
"X"
,
"b"
},
{
"Out"
},
{});
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"X"
,
"b"
});
ASSERT_TRUE
(
backward
->
IsNetOp
());
auto
net
=
static_cast
<
f
::
NetOp
*>
(
backward
.
get
());
ASSERT_TRUE
(
net
->
ops_
.
empty
());
}
TEST
(
Backward
,
all_output_are_not_need
)
{
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
"X"
,
"b"
},
{
"Out"
},
{});
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"Out"
});
ASSERT_TRUE
(
backward
->
IsNetOp
());
auto
net
=
static_cast
<
f
::
NetOp
*>
(
backward
.
get
());
ASSERT_TRUE
(
net
->
ops_
.
empty
());
}
TEST
(
Backward
,
part_of_output_are_not_need
)
{
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"many_output_op"
,
{
"X"
},
{
"Y"
,
"Z"
},
{});
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"Z"
});
ASSERT_TRUE
(
backward
->
IsNetOp
());
auto
net
=
static_cast
<
f
::
NetOp
*>
(
backward
.
get
());
ASSERT_EQ
(
net
->
ops_
.
size
(),
2
);
auto
&
fill_zero
=
*
net
->
ops_
[
0
];
ASSERT_EQ
(
"fill_zeros_like"
,
fill_zero
.
type_
);
ASSERT_EQ
(
1
,
fill_zero
.
inputs_
.
size
());
ASSERT_EQ
(
"Z"
,
fill_zero
.
inputs_
[
0
]);
ASSERT_EQ
(
1
,
fill_zero
.
outputs_
.
size
());
ASSERT_EQ
(
"Z@ZERO"
,
fill_zero
.
outputs_
[
0
]);
auto
&
d_many_out
=
*
net
->
ops_
[
1
];
ASSERT_EQ
(
"many_output_op_grad"
,
d_many_out
.
type_
);
ASSERT_EQ
(
1
+
2
+
2
,
d_many_out
.
inputs_
.
size
());
// I/O/OG
ASSERT_EQ
(
"Z@ZERO"
,
d_many_out
.
Input
(
"z@GRAD"
));
}
\ No newline at end of file
paddle/framework/grad_op_builder.cc
浏览文件 @
f77c63b8
...
...
@@ -20,7 +20,7 @@ namespace framework {
OperatorBase
*
GradOpBuilder
::
Build
()
{
BuildOpInOutArgList
();
std
::
string
grad_op_type
=
OpRegistry
::
grad_ops
().
at
(
op_
->
type_
);
std
::
string
grad_op_type
=
OpRegistry
::
grad_ops
().
at
(
op_
.
type_
);
OperatorBase
*
grad_op
=
OpRegistry
::
op_creators
().
at
(
grad_op_type
)();
grad_op
->
type_
=
grad_op_type
;
CompleteGradOp
(
grad_op
);
...
...
@@ -39,15 +39,15 @@ OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
}
void
GradOpBuilder
::
BuildOpInOutArgList
()
{
const
OpProto
&
op_proto
=
OpRegistry
::
protos
().
at
(
op_
->
type_
);
const
auto
&
var_map
=
*
(
OpRegistry
::
VarIndexMaps
().
at
(
op_
->
type_
));
const
OpProto
&
op_proto
=
OpRegistry
::
protos
().
at
(
op_
.
type_
);
const
auto
&
var_map
=
*
(
OpRegistry
::
VarIndexMaps
().
at
(
op_
.
type_
));
const
std
::
vector
<
int
>&
in_format
=
op_
->
attrs_
.
count
(
"input_format"
)
?
op_
->
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
)
op_
.
attrs_
.
count
(
"input_format"
)
?
op_
.
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
)
:
std
::
vector
<
int
>
();
const
std
::
vector
<
int
>&
out_format
=
op_
->
attrs_
.
count
(
"output_format"
)
?
op_
->
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
)
op_
.
attrs_
.
count
(
"output_format"
)
?
op_
.
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
)
:
std
::
vector
<
int
>
();
for
(
const
auto
&
var
:
op_proto
.
inputs
())
{
arg_list_
.
emplace_back
(
...
...
@@ -70,8 +70,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
}
(
*
varmap
)[
var_name
]
=
idx
++
;
size_t
pre_sz
=
in_out
.
size
();
auto
base_it
=
arg
->
type_
==
IN
?
op_
->
inputs_
.
begin
()
:
op_
->
outputs_
.
begin
();
auto
base_it
=
arg
->
type_
==
IN
?
op_
.
inputs_
.
begin
()
:
op_
.
outputs_
.
begin
();
std
::
copy
(
base_it
+
arg
->
begin_idx_
,
base_it
+
arg
->
end_idx_
,
std
::
back_inserter
(
in_out
));
if
(
is_grad
)
{
...
...
@@ -83,7 +82,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
}
void
GradOpBuilder
::
CompleteGradOp
(
OperatorBase
*
grad_op
)
const
{
grad_op
->
attrs_
=
op_
->
attrs_
;
grad_op
->
attrs_
=
op_
.
attrs_
;
grad_op
->
attrs_
.
erase
(
"input_format"
);
grad_op
->
attrs_
.
erase
(
"output_format"
);
VarIndexMap
*
grad_varmap
=
new
VarIndexMap
();
...
...
paddle/framework/grad_op_builder.h
浏览文件 @
f77c63b8
...
...
@@ -29,7 +29,7 @@ class GradOpBuilder {
using
VarIndexMap
=
std
::
unordered_map
<
std
::
string
,
int
>
;
public:
GradOpBuilder
(
const
OperatorBase
*
op
)
:
op_
(
op
)
{}
GradOpBuilder
(
const
OperatorBase
&
op
)
:
op_
(
op
)
{}
OperatorBase
*
Build
();
private:
...
...
@@ -40,7 +40,7 @@ class GradOpBuilder {
std
::
vector
<
int
>&
format
,
VarIndexMap
*
varmap
,
int
&
idx
,
bool
is_grad
)
const
;
void
CompleteGradOp
(
OperatorBase
*
grad_op
)
const
;
const
OperatorBase
*
op_
;
const
OperatorBase
&
op_
;
std
::
vector
<
std
::
shared_ptr
<
OpInOutArg
>>
arg_list_
;
};
...
...
paddle/framework/grad_op_builder_test.cc
浏览文件 @
f77c63b8
...
...
@@ -11,7 +11,7 @@ namespace framework {
TEST
(
GradOpBuilder
,
AddTwo
)
{
std
::
shared_ptr
<
OperatorBase
>
add_op
(
OpRegistry
::
CreateOp
(
"add_two"
,
{
"x"
,
"y"
},
{
"out"
},
{}));
std
::
shared_ptr
<
OperatorBase
>
grad_add_op
=
OpRegistry
::
CreateGradOp
(
add_op
);
std
::
shared_ptr
<
OperatorBase
>
grad_add_op
=
OpRegistry
::
CreateGradOp
(
*
add_op
);
EXPECT_EQ
(
static_cast
<
int
>
(
grad_add_op
->
inputs_
.
size
()),
4
);
EXPECT_EQ
(
static_cast
<
int
>
(
grad_add_op
->
outputs_
.
size
()),
2
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"X"
),
"x"
);
...
...
paddle/framework/op_registry.h
浏览文件 @
f77c63b8
...
...
@@ -303,11 +303,10 @@ class OpRegistry {
return
CreateOp
(
op_desc
.
type
(),
inputs
,
outputs
,
attrs
);
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateGradOp
(
std
::
shared_ptr
<
OperatorBase
>
op
)
{
PADDLE_ENFORCE
(
!
op
->
IsNetOp
(),
static
std
::
shared_ptr
<
OperatorBase
>
CreateGradOp
(
const
OperatorBase
&
op
)
{
PADDLE_ENFORCE
(
!
op
.
IsNetOp
(),
"Use framework::Backward to get backward ops"
);
GradOpBuilder
builder
(
op
.
get
()
);
GradOpBuilder
builder
(
op
);
std
::
shared_ptr
<
OperatorBase
>
grad_op
(
builder
.
Build
());
grad_op
->
Init
();
return
grad_op
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录