Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
d0b25ac9
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d0b25ac9
编写于
7月 28, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix some unittest error
上级
8bf0ca0f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
32 addition
and
16 deletion
+32
-16
paddle/framework/backward.cc
paddle/framework/backward.cc
+9
-4
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+20
-10
paddle/framework/operator.cc
paddle/framework/operator.cc
+2
-2
paddle/framework/operator.h
paddle/framework/operator.h
+1
-0
未找到文件。
paddle/framework/backward.cc
浏览文件 @
d0b25ac9
...
...
@@ -72,7 +72,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
return
EmptyOp
();
}
auto
*
net
=
new
NetOp
();
auto
net
=
std
::
make_shared
<
NetOp
>
();
if
(
forwardOp
.
IsNetOp
())
{
//! TODO(dzh)
...
...
@@ -84,7 +84,8 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
auto
&
forwardNet
=
static_cast
<
const
NetOp
&>
(
forwardOp
);
// travesal subnet/op
for
(
auto
it
=
forwardNet
.
ops_
.
end
();
it
!=
forwardNet
.
ops_
.
begin
();
--
it
)
{
for
(
auto
it
=
forwardNet
.
ops_
.
rbegin
();
it
!=
forwardNet
.
ops_
.
rend
();
++
it
)
{
auto
fwd
=
*
it
;
// for (auto& fwd : forwardNet.ops_) {
// auto bwd = Backward(*fwd, no_grad_names);
...
...
@@ -115,7 +116,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
insert_postion
.
push_back
(
{
dup_op
.
back
(),
OpRegistry
::
CreateOp
(
"
A
dd"
,
{
dup_outputs
},
{
name
},
"
a
dd"
,
{
dup_outputs
},
{
name
},
{{
"input_format"
,
std
::
vector
<
int
>
{
0
,
(
int
)
dup_outputs
.
size
()}}})});
}
...
...
@@ -142,11 +143,15 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
grad_output
=
OperatorBase
::
EMPTY_VAR_NAME
();
}
}
if
(
net
->
ops_
.
empty
())
{
// Current no aux op is added to network
return
grad_op
;
}
net
->
AddOp
(
grad_op
);
}
net
->
CompleteAddOp
();
return
std
::
shared_ptr
<
OperatorBase
>
(
net
)
;
return
net
;
}
extern
std
::
shared_ptr
<
OperatorBase
>
Backward
(
...
...
paddle/framework/backward_test.cc
浏览文件 @
d0b25ac9
...
...
@@ -63,14 +63,22 @@ class FcOp : public NetOp {
public:
void
Init
()
override
{
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{
Input
(
"X"
),
Input
(
"W"
)},
{
Output
(
"
before_ac
t"
)},
{}));
{
Output
(
"
mul_resul
t"
)},
{}));
auto
b_name
=
Input
(
"b"
);
std
::
string
before_act
=
"mul_result"
;
if
(
b_name
!=
EMPTY_VAR_NAME
())
{
AddOp
(
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
Output
(
"before_act"
),
b_name
},
{
Output
(
"before_act"
)},
{}));
AddOp
(
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
Output
(
"mul_result"
),
b_name
},
{
Output
(
"add_result"
)},
{}));
before_act
=
"add_result"
;
}
else
{
auto
out_varname
=
Output
(
"add_result"
);
if
(
out_varname
!=
EMPTY_VAR_NAME
())
{
this
->
Rename
(
out_varname
,
EMPTY_VAR_NAME
());
}
}
AddOp
(
OpRegistry
::
CreateOp
(
"sigmoid"
,
{
Output
(
"before_act"
)},
{
Output
(
"Out"
)},
{}));
AddOp
(
OpRegistry
::
CreateOp
(
"sigmoid"
,
{
Output
(
before_act
)},
{
Output
(
"Out"
)},
{}));
CompleteAddOp
(
false
);
}
};
...
...
@@ -82,7 +90,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
AddInput
(
"X"
,
"x"
);
AddInput
(
"W"
,
"w"
);
AddInput
(
"b"
,
"b"
);
AddOutput
(
"before_act"
,
"before act"
).
SetTemporary
();
AddOutput
(
"mul_result"
,
""
).
SetTemporary
();
AddOutput
(
"add_result"
,
""
).
SetTemporary
();
AddOutput
(
"Out"
,
""
);
AddComment
(
""
);
}
...
...
@@ -153,7 +162,7 @@ TEST(Backward, simple_op_grad) {
TEST
(
Backward
,
net_fc_backward_normal
)
{
std
::
shared_ptr
<
f
::
OperatorBase
>
fwd
=
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"X"
,
"w"
,
"b"
},
{
"
out"
,
"tmp_forward
"
},
{});
"fc"
,
{
"X"
,
"w"
,
"b"
},
{
"
mul_result"
,
"add_result"
,
"out
"
},
{});
ASSERT_NE
(
fwd
,
nullptr
);
std
::
shared_ptr
<
f
::
OperatorBase
>
gop
=
f
::
Backward
(
*
fwd
,
{});
ASSERT_TRUE
(
gop
->
IsNetOp
());
...
...
@@ -176,7 +185,7 @@ TEST(Backward, net_fc_backward_normal) {
TEST
(
Backward
,
net_fc_backward_not_have_b
)
{
std
::
shared_ptr
<
f
::
OperatorBase
>
fwd
=
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"X"
,
"w"
,
f
::
OperatorBase
::
EMPTY_VAR_NAME
()},
{
"
out"
,
"tmp_forward
"
},
{});
{
"
mul_result"
,
"add_result"
,
"tmp
"
},
{});
ASSERT_NE
(
fwd
,
nullptr
);
std
::
shared_ptr
<
f
::
OperatorBase
>
gop
=
f
::
Backward
(
*
fwd
,
{});
ASSERT_TRUE
(
gop
->
IsNetOp
());
...
...
@@ -196,9 +205,9 @@ TEST(Backward, net_fc_backward_not_have_b) {
TEST
(
Backward
,
net_input_of_network_not_need_grad
)
{
f
::
NetOp
net
;
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"X"
,
"W1"
,
"b1"
},
{
"
hidden0"
,
"tmp
0"
},
{}));
{
"
mul_tmp_0"
,
"add_tmp_0"
,
"hidden
0"
},
{}));
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"hidden0"
,
"W2"
,
"b2"
},
{
"
hidden1"
,
"tmp
1"
},
{}));
{
"
mul_tmp_1"
,
"add_tmp_1"
,
"hidden
1"
},
{}));
net
.
CompleteAddOp
();
auto
bwd
=
Backward
(
net
,
{
"X"
});
// X@GRAD is not need.
ASSERT_TRUE
(
bwd
->
IsNetOp
());
...
...
@@ -235,6 +244,7 @@ TEST(Backward, net_shared_weight) {
ASSERT_TRUE
(
bwd
->
IsNetOp
());
auto
bwd_net
=
static_cast
<
f
::
NetOp
*>
(
bwd
.
get
());
ASSERT_EQ
(
3UL
,
bwd_net
->
ops_
.
size
());
LOG
(
INFO
)
<<
bwd_net
->
DebugString
();
ASSERT_EQ
(
"add_grad"
,
bwd_net
->
ops_
[
2
]
->
type_
);
}
...
...
paddle/framework/operator.cc
浏览文件 @
d0b25ac9
...
...
@@ -52,7 +52,7 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
PADDLE_ENFORCE
(
in_out_idxs_
!=
nullptr
,
"IO Idx could not be nullptr"
);
auto
input_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
);
auto
offset
=
in_out_idxs_
->
at
(
name
);
PADDLE_ENFORCE
(
input_format
.
at
((
size_t
)
offset
+
1
)
<=
inputs_
.
size
(),
PADDLE_ENFORCE
(
input_format
.
at
((
size_t
)
offset
+
1
)
<=
(
int
)
inputs_
.
size
(),
"Input Out Of Range"
);
return
std
::
vector
<
std
::
string
>
{
...
...
@@ -78,7 +78,7 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
PADDLE_ENFORCE
(
in_out_idxs_
!=
nullptr
,
"InOut Indice could not be nullptr"
);
auto
output_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
);
auto
offset
=
in_out_idxs_
->
at
(
name
);
PADDLE_ENFORCE
(
output_format
.
at
((
size_t
)
offset
+
1
)
<=
outputs_
.
size
(),
PADDLE_ENFORCE
(
output_format
.
at
((
size_t
)
offset
+
1
)
<=
(
int
)
outputs_
.
size
(),
"Output Out of Range"
);
return
std
::
vector
<
std
::
string
>
{
outputs_
.
begin
()
+
output_format
.
at
(
offset
),
...
...
paddle/framework/operator.h
浏览文件 @
d0b25ac9
...
...
@@ -101,6 +101,7 @@ class OperatorBase {
//! Get a input with argument's name described in `op_proto`
const
std
::
string
&
Input
(
const
std
::
string
&
name
)
const
;
//! Get a input which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
std
::
vector
<
std
::
string
>
Inputs
(
const
std
::
string
&
name
)
const
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录