Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
dfb4ea76
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dfb4ea76
编写于
8月 11, 2017
作者:
Q
qingqing01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make unit test of backward_test pass.
上级
88104905
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
249 addition
and
216 deletion
+249
-216
paddle/framework/backward.cc
paddle/framework/backward.cc
+6
-6
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+242
-209
paddle/framework/operator.cc
paddle/framework/operator.cc
+1
-1
未找到文件。
paddle/framework/backward.cc
浏览文件 @
dfb4ea76
...
@@ -25,7 +25,7 @@ template <typename Map, typename T>
...
@@ -25,7 +25,7 @@ template <typename Map, typename T>
static
void
ForEachVarName
(
Map
&
names
,
T
callback
)
{
static
void
ForEachVarName
(
Map
&
names
,
T
callback
)
{
for
(
auto
&
name
:
names
)
{
for
(
auto
&
name
:
names
)
{
for
(
auto
&
n
:
name
.
second
)
{
for
(
auto
&
n
:
name
.
second
)
{
if
(
callback
(
n
))
break
;
if
(
callback
(
n
))
return
;
}
}
}
}
}
}
...
@@ -33,12 +33,12 @@ static void ForEachVarName(Map& names, T callback) {
...
@@ -33,12 +33,12 @@ static void ForEachVarName(Map& names, T callback) {
static
bool
AllInSet
(
static
bool
AllInSet
(
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>&
names
,
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>&
names
,
const
std
::
string
&
suffix
,
const
std
::
unordered_set
<
std
::
string
>&
set
)
{
const
std
::
string
&
suffix
,
const
std
::
unordered_set
<
std
::
string
>&
set
)
{
bool
ret_val
=
true
;
bool
all_in_set
=
true
;
ForEachVarName
(
names
,
[
&
ret_val
,
&
set
,
&
suffix
](
const
std
::
string
&
n
)
{
ForEachVarName
(
names
,
[
&
all_in_set
,
&
set
,
&
suffix
](
const
std
::
string
&
n
)
{
ret_val
=
set
.
find
(
n
+
suffix
)
=
=
set
.
end
();
all_in_set
=
set
.
find
(
n
+
suffix
)
!
=
set
.
end
();
return
!
ret_val
;
return
!
all_in_set
;
});
});
return
ret_val
;
return
all_in_set
;
}
}
static
std
::
shared_ptr
<
OperatorBase
>
NOP
()
{
static
std
::
shared_ptr
<
OperatorBase
>
NOP
()
{
...
...
paddle/framework/backward_test.cc
浏览文件 @
dfb4ea76
...
@@ -82,11 +82,11 @@ class FcOp : public operators::NetOp {
...
@@ -82,11 +82,11 @@ class FcOp : public operators::NetOp {
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
Input
(
"X"
)}},
{
"Y"
,
{
Input
(
"W"
)}}},
{{
"X"
,
{
Input
(
"X"
)}},
{
"Y"
,
{
Input
(
"W"
)}}},
{{
"Out"
,
{
Output
(
"mul_result"
)}}},
{}));
{{
"Out"
,
{
Output
(
"mul_result"
)}}},
{}));
auto
b_name
=
Input
(
"b"
);
auto
input_b
=
Inputs
(
"b"
);
std
::
string
before_act
=
"mul_result"
;
std
::
string
before_act
=
"mul_result"
;
if
(
b_name
!=
kEmptyVarName
)
{
if
(
input_b
.
size
()
!=
0
)
{
AddOp
(
OpRegistry
::
CreateOp
(
AddOp
(
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{{
"X"
,
{
Output
(
"mul_result"
)}},
{
"b"
,
{
b_name
}}},
"rowwise_add"
,
{{
"X"
,
{
Output
(
"mul_result"
)}},
{
"b"
,
{
input_b
[
0
]
}}},
{{
"Out"
,
{
Output
(
"add_result"
)}}},
{}));
{{
"Out"
,
{
Output
(
"add_result"
)}}},
{}));
before_act
=
"add_result"
;
before_act
=
"add_result"
;
}
else
{
}
else
{
...
@@ -166,209 +166,242 @@ REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
...
@@ -166,209 +166,242 @@ REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP
(
many_output_op
,
f
::
EmptyOp
,
f
::
ManyOutputOpMaker
);
REGISTER_OP
(
many_output_op
,
f
::
EmptyOp
,
f
::
ManyOutputOpMaker
);
REGISTER_GRADIENT_OP
(
many_output_op
,
many_output_op_grad
,
f
::
EmptyOp
);
REGISTER_GRADIENT_OP
(
many_output_op
,
many_output_op_grad
,
f
::
EmptyOp
);
// TEST(Backward, simple_op_grad) {
TEST
(
Backward
,
simple_op_grad
)
{
// auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
// ASSERT_NE(fwd, nullptr);
"rowwise_add"
,
{{
"X"
,
{
"x"
}},
{
"b"
,
{
"b"
}}},
{{
"Out"
,
{
"out"
}}},
{});
// auto gop = f::OpRegistry::CreateGradOp(*fwd);
ASSERT_NE
(
fwd
,
nullptr
);
// ASSERT_EQ(4UL, gop->inputs_.size());
auto
gop
=
f
::
OpRegistry
::
CreateGradOp
(
*
fwd
);
// ASSERT_EQ(f::kEmptyVarName, gop->inputs_[0]);
ASSERT_EQ
(
1UL
,
gop
->
inputs_
.
size
());
// ASSERT_EQ("rowwise_add_grad", gop->type_);
ASSERT_EQ
(
"rowwise_add_grad"
,
gop
->
type_
);
// ASSERT_EQ(f::GradVarName("X"), gop->outputs_[0]);
ASSERT_EQ
(
f
::
GradVarName
(
"x"
),
gop
->
Output
(
f
::
GradVarName
(
"X"
)));
// ASSERT_EQ(f::GradVarName("b"), gop->outputs_[1]);
ASSERT_EQ
(
f
::
GradVarName
(
"b"
),
gop
->
Output
(
f
::
GradVarName
(
"b"
)));
//
}
// ASSERT_EQ(f::GradVarName("X"), gop->Output(f::GradVarName("X")));
//}
TEST
(
Backward
,
simple_op_not_need_grad
)
{
//
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
// TEST(Backward, simple_op_not_need_grad) {
"rowwise_add"
,
{{
"X"
,
{
"x"
}},
{
"b"
,
{
"b"
}}},
{{
"Out"
,
{
"out"
}}},
{});
// auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE
(
fwd
,
nullptr
);
// ASSERT_NE(fwd, nullptr);
auto
gop
=
f
::
Backward
(
*
fwd
,
{
"x"
});
// auto gop = f::Backward(*fwd, {"X"});
ASSERT_EQ
(
gop
->
Output
(
f
::
GradVarName
(
"X"
)),
f
::
kEmptyVarName
);
// ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
// f::GradVarName("X")),
auto
no_input_gop
=
f
::
Backward
(
*
fwd
,
{
"x"
,
"b"
});
// gop->outputs_.end());
ASSERT_NE
(
no_input_gop
,
nullptr
);
//
ASSERT_TRUE
(
no_input_gop
->
IsNetOp
());
// auto no_input_gop = f::Backward(*fwd, {"X", "b"});
ASSERT_EQ
(
0UL
,
// ASSERT_NE(no_input_gop, nullptr);
std
::
static_pointer_cast
<
ops
::
NetOp
>
(
no_input_gop
)
->
ops_
.
size
());
// ASSERT_TRUE(no_input_gop->IsNetOp());
}
// ASSERT_EQ(0UL,
// std::static_pointer_cast<ops::NetOp>(no_input_gop)->ops_.size());
TEST
(
Backward
,
net_fc_backward_normal
)
{
//}
std
::
shared_ptr
<
f
::
OperatorBase
>
fwd
=
//
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w"
}},
{
"b"
,
{
"b"
}}},
// TEST(Backward, net_fc_backward_normal) {
{{
"mul_result"
,
{
"mul_res"
}},
// std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
{
"add_result"
,
{
"add_re"
}},
// "fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {});
{
"Out"
,
{
"out"
}}},
// ASSERT_NE(fwd, nullptr);
{});
// std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_NE
(
fwd
,
nullptr
);
// ASSERT_TRUE(gop->IsNetOp());
std
::
shared_ptr
<
f
::
OperatorBase
>
gop
=
f
::
Backward
(
*
fwd
,
{});
// auto net = static_cast<ops::NetOp *>(gop.get());
ASSERT_TRUE
(
gop
->
IsNetOp
());
//
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
gop
.
get
());
// ASSERT_NO_THROW(net->DebugString());
//
ASSERT_NO_THROW
(
net
->
DebugString
());
// ASSERT_EQ(3UL, net->ops_.size());
//
ASSERT_EQ
(
3UL
,
net
->
ops_
.
size
());
// f::OperatorBase &d_sigmoid = *net->ops_[0];
// ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
f
::
OperatorBase
&
d_sigmoid
=
*
net
->
ops_
[
0
];
//
ASSERT_EQ
(
"sigmoid_grad"
,
d_sigmoid
.
type_
);
// f::OperatorBase &d_add = *net->ops_[1];
// ASSERT_EQ("rowwise_add_grad", d_add.type_);
f
::
OperatorBase
&
d_add
=
*
net
->
ops_
[
1
];
//
ASSERT_EQ
(
"rowwise_add_grad"
,
d_add
.
type_
);
// f::OperatorBase &d_mul = *net->ops_[2];
// ASSERT_EQ("mul_grad", d_mul.type_);
f
::
OperatorBase
&
d_mul
=
*
net
->
ops_
[
2
];
//}
ASSERT_EQ
(
"mul_grad"
,
d_mul
.
type_
);
//
}
// TEST(Backward, net_fc_backward_not_have_b) {
// std::shared_ptr<f::OperatorBase> fwd =
TEST
(
Backward
,
net_fc_backward_not_have_b
)
{
// f::OpRegistry::CreateOp("fc", {"X", "w", f::kEmptyVarName},
std
::
shared_ptr
<
f
::
OperatorBase
>
fwd
=
// {"mul_result", "add_result", "tmp"}, {});
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w"
}},
{
"b"
,
{}}},
// ASSERT_NE(fwd, nullptr);
{{
"mul_result"
,
{
"mul_res"
}},
// std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
{
"add_result"
,
{
"add_res"
}},
// ASSERT_TRUE(gop->IsNetOp());
{
"Out"
,
{
"tmp"
}}},
// auto net = static_cast<ops::NetOp *>(gop.get());
{});
//
ASSERT_NE
(
fwd
,
nullptr
);
// ASSERT_NO_THROW(net->DebugString());
std
::
shared_ptr
<
f
::
OperatorBase
>
gop
=
f
::
Backward
(
*
fwd
,
{});
//
ASSERT_TRUE
(
gop
->
IsNetOp
());
// ASSERT_EQ(2UL, net->ops_.size());
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
gop
.
get
());
//
// f::OperatorBase &d_sigmoid = *net->ops_[0];
ASSERT_NO_THROW
(
net
->
DebugString
());
// ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
//
ASSERT_EQ
(
2UL
,
net
->
ops_
.
size
());
// f::OperatorBase &d_mul = *net->ops_[1];
// ASSERT_EQ("mul_grad", d_mul.type_);
f
::
OperatorBase
&
d_sigmoid
=
*
net
->
ops_
[
0
];
//}
ASSERT_EQ
(
"sigmoid_grad"
,
d_sigmoid
.
type_
);
//
// TEST(Backward, net_input_of_network_not_need_grad) {
f
::
OperatorBase
&
d_mul
=
*
net
->
ops_
[
1
];
// ops::NetOp net;
ASSERT_EQ
(
"mul_grad"
,
d_mul
.
type_
);
// net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"},
}
// {"mul_tmp_0", "add_tmp_0", "hidden0"},
// {}));
TEST
(
Backward
,
net_input_of_network_not_need_grad
)
{
// net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"},
ops
::
NetOp
net
;
// {"mul_tmp_1", "add_tmp_1", "hidden1"},
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
// {}));
"fc"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"W1"
}},
{
"b"
,
{
"b1"
}}},
// net.CompleteAddOp();
{{
"mul_result"
,
{
"mul_tmp_0"
}},
// auto bwd = Backward(net, {"X"}); // X@GRAD is not need.
{
"add_result"
,
{
"add_tmp_0"
}},
// ASSERT_TRUE(bwd->IsNetOp());
{
"Out"
,
{
"hidden0"
}}},
// auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
{}));
//
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
// std::unordered_set<std::string> all_output =
"fc"
,
{{
"X"
,
{
"hidden0"
}},
{
"W"
,
{
"W2"
}},
{
"b"
,
{
"b2"
}}},
// std::unordered_set<std::string>(
{{
"mul_result"
,
{
"mul_tmp_1"
}},
// bwd_net->outputs_.begin(), bwd_net->outputs_.end());
{
"add_result"
,
{
"add_tmp_1"
}},
// all_output.erase(f::kEmptyVarName);
{
"Out"
,
{
"hidden1"
}}},
//
{}));
// for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
net
.
CompleteAddOp
();
// ASSERT_NE(all_output.find(f::GradVarName(out)), all_output.end());
auto
bwd
=
Backward
(
net
,
{
"x"
});
// x@GRAD is not need.
// }
ASSERT_TRUE
(
bwd
->
IsNetOp
());
//
auto
bwd_net
=
static_cast
<
ops
::
NetOp
*>
(
bwd
.
get
());
// // Not Generated X
// ASSERT_EQ(all_output.find(f::GradVarName("X")), all_output.end());
auto
output_vars
=
bwd_net
->
OutputVars
(
true
);
//
std
::
unordered_set
<
std
::
string
>
all_outputs
=
// ASSERT_EQ(2UL, bwd_net->ops_.size());
std
::
unordered_set
<
std
::
string
>
(
output_vars
.
begin
(),
output_vars
.
end
());
// ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
all_outputs
.
erase
(
f
::
kEmptyVarName
);
// auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
// ASSERT_EQ(3UL, first_fc_grad->ops_.size());
for
(
auto
&
out
:
{
"W1"
,
"b1"
,
"hidden0"
,
"W2"
,
"b2"
})
{
// ASSERT_EQ(f::kEmptyVarName,
ASSERT_NE
(
all_outputs
.
find
(
f
::
GradVarName
(
out
)),
all_outputs
.
end
());
// first_fc_grad->ops_[2]->Output(f::GradVarName("A")));
}
//}
//
// Not Generated X
// TEST(Backward, net_shared_weight) {
ASSERT_EQ
(
all_outputs
.
find
(
f
::
GradVarName
(
"X"
)),
all_outputs
.
end
());
// ops::NetOp net;
// net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {}));
ASSERT_EQ
(
2UL
,
bwd_net
->
ops_
.
size
());
// net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {}));
ASSERT_TRUE
(
bwd_net
->
ops_
[
1
]
->
IsNetOp
());
// net.CompleteAddOp();
auto
first_fc_grad
=
static_cast
<
ops
::
NetOp
*>
(
bwd_net
->
ops_
[
1
].
get
());
//
ASSERT_EQ
(
3UL
,
first_fc_grad
->
ops_
.
size
());
// auto bwd = f::Backward(net, {});
ASSERT_EQ
(
f
::
kEmptyVarName
,
// ASSERT_TRUE(bwd->IsNetOp());
first_fc_grad
->
ops_
[
2
]
->
Output
(
f
::
GradVarName
(
"X"
)));
// auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
}
// ASSERT_EQ(3UL, bwd_net->ops_.size());
// ASSERT_EQ("add", bwd_net->ops_[2]->type_);
TEST
(
Backward
,
net_shared_weight
)
{
//}
ops
::
NetOp
net
;
//
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
"x"
}},
{
"Y"
,
{
"w"
}}},
// TEST(Backward, op_register_grad_not_for_network) {
{{
"Out"
,
{
"out"
}}},
{}));
// auto fwd = f::OpRegistry::CreateOp(
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
"out"
}},
{
"Y"
,
{
"w"
}}},
// "fc", {"X", "W", "b"}, {"mul_out", "add_out", "out1"},
{{
"Out"
,
{
"FinalOut"
}}},
{}));
// {{"temporary_index", std::vector<int>{0, 1}}});
net
.
CompleteAddOp
();
//
// ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
auto
bwd
=
f
::
Backward
(
net
,
{});
//}
ASSERT_TRUE
(
bwd
->
IsNetOp
());
//
auto
bwd_net
=
static_cast
<
ops
::
NetOp
*>
(
bwd
.
get
());
// TEST(Backward, op_all_input_are_not_need) {
ASSERT_EQ
(
3UL
,
bwd_net
->
ops_
.
size
());
// auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_EQ
(
"add"
,
bwd_net
->
ops_
[
2
]
->
type_
);
// auto backward = f::Backward(*fwd, {"X", "b"});
}
// ASSERT_TRUE(backward->IsNetOp());
// auto net = static_cast<ops::NetOp *>(backward.get());
TEST
(
Backward
,
op_register_grad_not_for_network
)
{
// ASSERT_TRUE(net->ops_.empty());
auto
fwd
=
//}
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w"
}},
{
"b"
,
{
"b"
}}},
//
{{
"mul_result"
,
{
"mul_out"
}},
// TEST(Backward, op_all_output_are_not_need) {
{
"add_result"
,
{
"add_out"
}},
// auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
{
"Out"
,
{
"out1"
}}},
// auto backward = f::Backward(*fwd, {"Out"});
{{
"temporary_index"
,
std
::
vector
<
int
>
{
0
,
1
}}});
// ASSERT_TRUE(backward->IsNetOp());
// auto net = static_cast<ops::NetOp *>(backward.get());
ASSERT_THROW
(
f
::
OpRegistry
::
CreateGradOp
(
*
fwd
),
EnforceNotMet
);
// ASSERT_TRUE(net->ops_.empty());
}
//}
//
TEST
(
Backward
,
op_all_input_are_not_need
)
{
// TEST(Backward, op_part_of_output_are_not_need) {
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
// auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
"rowwise_add"
,
{{
"X"
,
{
"x"
}},
{
"b"
,
{
"b"
}}},
{{
"Out"
,
{
"out"
}}},
{});
// auto backward = f::Backward(*fwd, {"Z"});
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"x"
,
"b"
});
// ASSERT_TRUE(backward->IsNetOp());
ASSERT_TRUE
(
backward
->
IsNetOp
());
// auto net = static_cast<ops::NetOp *>(backward.get());
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
// ASSERT_EQ(net->ops_.size(), 2UL);
ASSERT_TRUE
(
net
->
ops_
.
empty
());
//
}
// auto &fill_zero = *net->ops_[0];
// ASSERT_EQ("fill_zeros_like", fill_zero.type_);
TEST
(
Backward
,
op_all_output_are_not_need
)
{
// ASSERT_EQ(1UL, fill_zero.inputs_.size());
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
// ASSERT_EQ("Z", fill_zero.inputs_[0]);
"rowwise_add"
,
{{
"X"
,
{
"x"
}},
{
"b"
,
{
"b"
}}},
{{
"Out"
,
{
"out"
}}},
{});
// ASSERT_EQ(1UL, fill_zero.outputs_.size());
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"out"
});
// ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.outputs_[0]);
ASSERT_TRUE
(
backward
->
IsNetOp
());
//
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
// auto &d_many_out = *net->ops_[1];
ASSERT_TRUE
(
net
->
ops_
.
empty
());
// ASSERT_EQ("many_output_op_grad", d_many_out.type_);
}
// ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG
// ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix,
TEST
(
Backward
,
op_part_of_output_are_not_need
)
{
// d_many_out.Input(f::GradVarName("z")));
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"many_output_op"
,
{{
"x"
,
{
"X"
}}},
// ASSERT_EQ(f::GradVarName("Y"), d_many_out.Input(f::GradVarName("y")));
{{
"y"
,
{
"Y"
}},
{
"z"
,
{
"Z"
}}},
{});
// ASSERT_EQ(f::GradVarName("X"), d_many_out.Output(f::GradVarName("x")));
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"Z"
});
//}
ASSERT_TRUE
(
backward
->
IsNetOp
());
//
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
// TEST(Backward, op_part_of_input_are_not_need) {
ASSERT_EQ
(
net
->
ops_
.
size
(),
2UL
);
// auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
// auto backward = f::Backward(*fwd, {"a"});
auto
&
fill_zero
=
*
net
->
ops_
[
0
];
// auto &grad_mul = *backward;
ASSERT_EQ
(
"fill_zeros_like"
,
fill_zero
.
type_
);
// ASSERT_EQ(grad_mul.type_, "mul_grad");
ASSERT_EQ
(
1UL
,
fill_zero
.
Inputs
(
"Src"
).
size
());
// ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
ASSERT_EQ
(
"Z"
,
fill_zero
.
Input
(
"Src"
));
// ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
ASSERT_EQ
(
1UL
,
fill_zero
.
Outputs
(
"Dst"
).
size
());
// ASSERT_EQ(grad_mul.Output(f::GradVarName("A")), f::kEmptyVarName);
ASSERT_EQ
(
std
::
string
(
"Z"
)
+
f
::
kZeroVarSuffix
,
fill_zero
.
Output
(
"Dst"
));
// ASSERT_EQ(grad_mul.Output(f::GradVarName("B")), f::GradVarName("b"));
// ASSERT_EQ(grad_mul.Input(f::GradVarName("Out")), f::GradVarName("out"));
auto
&
d_many_out
=
*
net
->
ops_
[
1
];
// ASSERT_EQ(grad_mul.Input("A"), "a");
ASSERT_EQ
(
"many_output_op_grad"
,
d_many_out
.
type_
);
// ASSERT_EQ(grad_mul.Input("B"), "b");
ASSERT_EQ
(
1UL
+
2UL
+
2UL
,
d_many_out
.
inputs_
.
size
());
// I/O/OG
// ASSERT_EQ(grad_mul.Input("Out"), "out");
ASSERT_EQ
(
std
::
string
(
"Z"
)
+
f
::
kZeroVarSuffix
,
//}
d_many_out
.
Input
(
f
::
GradVarName
(
"z"
)));
//
ASSERT_EQ
(
f
::
GradVarName
(
"Y"
),
d_many_out
.
Input
(
f
::
GradVarName
(
"y"
)));
// TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
ASSERT_EQ
(
f
::
GradVarName
(
"X"
),
d_many_out
.
Output
(
f
::
GradVarName
(
"x"
)));
// ops::NetOp net;
}
// net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"},
// {"mul_out1", "add_out1", "out1"}, {}));
TEST
(
Backward
,
op_part_of_input_are_not_need
)
{
// net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"},
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
"a"
}},
{
"Y"
,
{
"b"
}}},
// {"mul_out2", "tmp_out2", "out2"}, {}));
{{
"Out"
,
{
"out"
}}},
{});
// net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"},
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"a"
});
// {"mul_out3", "tmp_out3", "out3"}, {}));
auto
&
grad_mul
=
*
backward
;
// net.CompleteAddOp();
ASSERT_EQ
(
grad_mul
.
type_
,
"mul_grad"
);
// auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
ASSERT_EQ
(
grad_mul
.
inputs_
.
size
(),
2UL
+
1UL
+
1UL
);
// ASSERT_TRUE(backward->IsNetOp());
ASSERT_EQ
(
grad_mul
.
outputs_
.
size
(),
2UL
);
// auto bwd_net = static_cast<ops::NetOp *>(backward.get());
ASSERT_EQ
(
grad_mul
.
Output
(
f
::
GradVarName
(
"X"
)),
f
::
kEmptyVarName
);
// ASSERT_EQ(bwd_net->ops_.size(), 3UL);
ASSERT_EQ
(
grad_mul
.
Output
(
f
::
GradVarName
(
"Y"
)),
f
::
GradVarName
(
"b"
));
// auto &grad_fc = *bwd_net->ops_[0];
ASSERT_EQ
(
grad_mul
.
Input
(
f
::
GradVarName
(
"Out"
)),
f
::
GradVarName
(
"out"
));
// EXPECT_EQ(grad_fc.inputs_.size(),
ASSERT_EQ
(
grad_mul
.
Input
(
"X"
),
"a"
);
// 3UL /* external input number */
ASSERT_EQ
(
grad_mul
.
Input
(
"Y"
),
"b"
);
// + 1UL /* external output number*/
ASSERT_EQ
(
grad_mul
.
Input
(
"Out"
),
"out"
);
// + 1UL /* number of gradient of external output*/
}
// + 2U /* internal variable number*/);
// EXPECT_EQ(grad_fc.outputs_.size(), 2UL /* input number of mul*/
TEST
(
Backward
,
linear_net_intermediate_variable_has_no_grad
)
{
// + 2UL /* input number of rowwise_add
ops
::
NetOp
net
;
// */
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
// + 1UL /* input number of sigmod */);
"fc"
,
{{
"X"
,
{
"x1"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
// EXPECT_EQ(bwd_net->ops_[1]->inputs_.size(), 0UL);
{{
"mul_result"
,
{
"mul_out1"
}},
// EXPECT_EQ(bwd_net->ops_[1]->outputs_.size(), 0UL);
{
"add_result"
,
{
"add_out1"
}},
// EXPECT_EQ(bwd_net->ops_[2]->inputs_.size(), 0UL);
{
"Out"
,
{
"out1"
}}},
// EXPECT_EQ(bwd_net->ops_[2]->outputs_.size(), 0UL);
{}));
//}
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"out1"
}},
{
"W"
,
{
"w2"
}},
{
"b"
,
{
"b2"
}}},
{{
"mul_result"
,
{
"mul_out2"
}},
{
"add_result"
,
{
"tmp_out2"
}},
{
"Out"
,
{
"out2"
}}},
{}));
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"out2"
}},
{
"W"
,
{
"w3"
}},
{
"b"
,
{
"b3"
}}},
{{
"mul_result"
,
{
"mul_out3"
}},
{
"add_result"
,
{
"tmp_out3"
}},
{
"Out"
,
{
"out3"
}}},
{}));
net
.
CompleteAddOp
();
auto
backward
=
f
::
Backward
(
net
,
{
"mul_out2"
,
"tmp_out2"
,
"out2"
});
ASSERT_TRUE
(
backward
->
IsNetOp
());
auto
bwd_net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
ASSERT_EQ
(
bwd_net
->
ops_
.
size
(),
3UL
);
auto
&
grad_fc
=
*
bwd_net
->
ops_
[
0
];
EXPECT_EQ
(
grad_fc
.
inputs_
[
"all"
].
size
(),
2UL
/* external input number */
+
1UL
/* external output number*/
+
1UL
/* number of gradient of external output*/
+
2U
/* internal variable number*/
);
EXPECT_EQ
(
grad_fc
.
outputs_
[
"all"
].
size
(),
2UL
/* input number of mul*/
+
2UL
/* input number of rowwise_add
*/
+
1UL
/* input number of sigmod */
);
EXPECT_EQ
(
bwd_net
->
ops_
[
1
]
->
inputs_
[
"all"
].
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
1
]
->
outputs_
[
"all"
].
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
2
]
->
inputs_
[
"all"
].
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
2
]
->
outputs_
[
"all"
].
size
(),
0UL
);
}
paddle/framework/operator.cc
浏览文件 @
dfb4ea76
...
@@ -43,7 +43,7 @@ std::unordered_map<std::string, OpProto>& OpProtos() {
...
@@ -43,7 +43,7 @@ std::unordered_map<std::string, OpProto>& OpProtos() {
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
auto
it
=
inputs_
.
find
(
name
);
auto
it
=
inputs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
inputs_
.
end
(),
"Op %s does not have
out
put %s"
,
type_
,
PADDLE_ENFORCE
(
it
!=
inputs_
.
end
(),
"Op %s does not have
in
put %s"
,
type_
,
name
);
name
);
PADDLE_ENFORCE_EQ
(
it
->
second
.
size
(),
1UL
,
PADDLE_ENFORCE_EQ
(
it
->
second
.
size
(),
1UL
,
"Op %s input %s should contain only one variable"
,
type_
,
"Op %s input %s should contain only one variable"
,
type_
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录