Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4a604c26
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4a604c26
编写于
8月 14, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Polish Our code by YuYang's review
上级
88a3d8dd
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
138 addition
and
155 deletion
+138
-155
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+14
-12
paddle/framework/ddim.cc
paddle/framework/ddim.cc
+0
-7
paddle/framework/ddim.h
paddle/framework/ddim.h
+0
-2
paddle/framework/grad_op_builder.cc
paddle/framework/grad_op_builder.cc
+0
-3
paddle/framework/grad_op_builder_test.cc
paddle/framework/grad_op_builder_test.cc
+6
-6
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+15
-18
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+27
-26
paddle/framework/operator.cc
paddle/framework/operator.cc
+44
-13
paddle/framework/operator.h
paddle/framework/operator.h
+5
-32
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+23
-22
paddle/operators/mean_op.cc
paddle/operators/mean_op.cc
+1
-1
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+3
-3
paddle/operators/recurrent_op_test.cc
paddle/operators/recurrent_op_test.cc
+0
-2
python/paddle/v2/framework/tests/test_add_two_op.py
python/paddle/v2/framework/tests/test_add_two_op.py
+0
-8
未找到文件。
paddle/framework/backward_test.cc
浏览文件 @
4a604c26
...
@@ -39,9 +39,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
...
@@ -39,9 +39,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public:
public:
RowWiseAddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
RowWiseAddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input X of Add"
).
Ignore
Gradient
();
AddInput
(
"X"
,
"Input X of Add"
).
No
Gradient
();
AddInput
(
"b"
,
"Bias of Add"
).
Ignore
Gradient
();
AddInput
(
"b"
,
"Bias of Add"
).
No
Gradient
();
AddOutput
(
"Out"
,
"Out of Add"
).
Ignore
Gradient
();
AddOutput
(
"Out"
,
"Out of Add"
).
No
Gradient
();
AddComment
(
"Add Op"
);
AddComment
(
"Add Op"
);
}
}
};
};
...
@@ -111,8 +111,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
...
@@ -111,8 +111,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
AddInput
(
"X"
,
"x"
);
AddInput
(
"X"
,
"x"
);
AddInput
(
"W"
,
"w"
);
AddInput
(
"W"
,
"w"
);
AddInput
(
"b"
,
"b"
);
AddInput
(
"b"
,
"b"
);
AddOutput
(
"mul_result"
,
""
).
Set
Temporary
();
AddOutput
(
"mul_result"
,
""
).
Set
Intermediate
();
AddOutput
(
"add_result"
,
""
).
Set
Temporary
();
AddOutput
(
"add_result"
,
""
).
Set
Intermediate
();
AddOutput
(
"Out"
,
""
);
AddOutput
(
"Out"
,
""
);
AddComment
(
""
);
AddComment
(
""
);
}
}
...
@@ -143,7 +143,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
...
@@ -143,7 +143,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
public:
public:
AddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
AddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"x"
).
Set
Multip
le
();
AddInput
(
"X"
,
"x"
).
Set
Duplicab
le
();
AddOutput
(
"Y"
,
"y"
);
AddOutput
(
"Y"
,
"y"
);
AddComment
(
""
);
AddComment
(
""
);
}
}
...
@@ -392,18 +392,20 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
...
@@ -392,18 +392,20 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
auto
bwd_net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
auto
bwd_net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
ASSERT_EQ
(
bwd_net
->
ops_
.
size
(),
3UL
);
ASSERT_EQ
(
bwd_net
->
ops_
.
size
(),
3UL
);
auto
&
grad_fc
=
*
bwd_net
->
ops_
[
0
];
auto
&
grad_fc
=
*
bwd_net
->
ops_
[
0
];
EXPECT_EQ
(
grad_fc
.
inputs_
[
"all"
].
size
(),
const
char
*
all
=
paddle
::
operators
::
NetOp
::
kAll
;
EXPECT_EQ
(
grad_fc
.
inputs_
[
all
].
size
(),
2UL
/* external input number */
2UL
/* external input number */
+
1UL
/* external output number*/
+
1UL
/* external output number*/
+
1UL
/* number of gradient of external output*/
+
1UL
/* number of gradient of external output*/
+
2U
/* internal variable number*/
);
+
2U
/* internal variable number*/
);
EXPECT_EQ
(
grad_fc
.
outputs_
[
"all"
].
size
(),
EXPECT_EQ
(
grad_fc
.
outputs_
[
all
].
size
(),
2UL
/* input number of mul*/
2UL
/* input number of mul*/
+
2UL
/* input number of rowwise_add
+
2UL
/* input number of rowwise_add
*/
*/
+
1UL
/* input number of sigmod */
);
+
1UL
/* input number of sigmod */
);
EXPECT_EQ
(
bwd_net
->
ops_
[
1
]
->
inputs_
[
"all"
].
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
1
]
->
inputs_
[
all
].
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
1
]
->
outputs_
[
"all"
].
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
1
]
->
outputs_
[
all
].
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
2
]
->
inputs_
[
"all"
].
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
2
]
->
inputs_
[
all
].
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
2
]
->
outputs_
[
"all"
].
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
2
]
->
outputs_
[
all
].
size
(),
0UL
);
}
}
paddle/framework/ddim.cc
浏览文件 @
4a604c26
...
@@ -283,12 +283,5 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
...
@@ -283,12 +283,5 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDim
::
DDim
(
std
::
initializer_list
<
int
>
init_list
)
{
DDim
::
DDim
(
std
::
initializer_list
<
int
>
init_list
)
{
*
this
=
make_ddim
(
init_list
);
*
this
=
make_ddim
(
init_list
);
}
}
std
::
string
DDim
::
DebugString
()
const
{
std
::
ostringstream
ss
;
ss
<<
*
this
;
return
ss
.
str
();
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/ddim.h
浏览文件 @
4a604c26
...
@@ -72,8 +72,6 @@ struct DDim {
...
@@ -72,8 +72,6 @@ struct DDim {
DDim
operator
*
(
DDim
d
)
const
;
DDim
operator
*
(
DDim
d
)
const
;
ssize_t
size
()
const
;
ssize_t
size
()
const
;
std
::
string
DebugString
()
const
;
};
};
/**
/**
...
...
paddle/framework/grad_op_builder.cc
浏览文件 @
4a604c26
...
@@ -18,9 +18,6 @@ permissions and limitations under the License. */
...
@@ -18,9 +18,6 @@ permissions and limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
OpRegistry
;
enum
class
OpArgType
{
IN
,
OUT
};
enum
class
OpArgType
{
IN
,
OUT
};
static
void
TransOpArg
(
const
OperatorBase
*
src_op
,
OperatorBase
*
dst_op
,
static
void
TransOpArg
(
const
OperatorBase
*
src_op
,
OperatorBase
*
dst_op
,
...
...
paddle/framework/grad_op_builder_test.cc
浏览文件 @
4a604c26
...
@@ -21,10 +21,10 @@ class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
...
@@ -21,10 +21,10 @@ class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
MutiInOutOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
MutiInOutOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"In1"
,
"a single input"
);
AddInput
(
"In1"
,
"a single input"
);
AddInput
(
"In2_mult"
,
"a multiple input"
).
Set
Multip
le
();
AddInput
(
"In2_mult"
,
"a multiple input"
).
Set
Duplicab
le
();
AddInput
(
"In3"
,
"another single input"
);
AddInput
(
"In3"
,
"another single input"
);
AddOutput
(
"Out1"
,
"a single output"
);
AddOutput
(
"Out1"
,
"a single output"
);
AddOutput
(
"Out2_mult"
,
"a multiple output"
).
Set
Multip
le
();
AddOutput
(
"Out2_mult"
,
"a multiple output"
).
Set
Duplicab
le
();
AddComment
(
"test op with multiple inputs and outputs"
);
AddComment
(
"test op with multiple inputs and outputs"
);
}
}
};
};
...
@@ -34,10 +34,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
...
@@ -34,10 +34,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
IOIgnoredOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
IOIgnoredOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"In1"
,
"a single input"
);
AddInput
(
"In1"
,
"a single input"
);
AddInput
(
"In2_mult"
,
"a multiple input"
).
Set
Multiple
().
Ignore
Gradient
();
AddInput
(
"In2_mult"
,
"a multiple input"
).
Set
Duplicable
().
No
Gradient
();
AddInput
(
"In3_mult"
,
"another multiple input"
).
Set
Multip
le
();
AddInput
(
"In3_mult"
,
"another multiple input"
).
Set
Duplicab
le
();
AddOutput
(
"Out1_mult"
,
"a multiple output"
).
Set
Multip
le
();
AddOutput
(
"Out1_mult"
,
"a multiple output"
).
Set
Duplicab
le
();
AddOutput
(
"Out2"
,
"a single output"
).
Ignore
Gradient
();
AddOutput
(
"Out2"
,
"a single output"
).
No
Gradient
();
AddComment
(
"op with inputs and outputs ignored in gradient calculating"
);
AddComment
(
"op with inputs and outputs ignored in gradient calculating"
);
}
}
};
};
...
...
paddle/framework/op_registry.h
浏览文件 @
4a604c26
...
@@ -47,17 +47,17 @@ class OpProtoAndCheckerMaker {
...
@@ -47,17 +47,17 @@ class OpProtoAndCheckerMaker {
struct
VariableBuilder
{
struct
VariableBuilder
{
OpProto
::
Var
*
var_
;
OpProto
::
Var
*
var_
;
VariableBuilder
&
Set
Multip
le
()
{
VariableBuilder
&
Set
Duplicab
le
()
{
var_
->
set_duplicable
(
true
);
var_
->
set_duplicable
(
true
);
return
*
this
;
return
*
this
;
}
}
VariableBuilder
&
Set
Temporary
()
{
VariableBuilder
&
Set
Intermediate
()
{
var_
->
set_intermediate
(
true
);
var_
->
set_intermediate
(
true
);
return
*
this
;
return
*
this
;
}
}
VariableBuilder
&
Ignore
Gradient
()
{
VariableBuilder
&
No
Gradient
()
{
var_
->
set_no_gradient
(
true
);
var_
->
set_no_gradient
(
true
);
return
*
this
;
return
*
this
;
}
}
...
@@ -118,7 +118,7 @@ class OpProtoAndCheckerMaker {
...
@@ -118,7 +118,7 @@ class OpProtoAndCheckerMaker {
class
OpRegistry
{
class
OpRegistry
{
using
OpCreator
=
std
::
function
<
OperatorBase
*
()
>
;
using
OpCreator
=
std
::
function
<
OperatorBase
*
()
>
;
using
VarNameMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
using
VarNameMap
=
OperatorBase
::
VarNameMap
;
public:
public:
template
<
typename
OpType
,
typename
ProtoMakerType
>
template
<
typename
OpType
,
typename
ProtoMakerType
>
...
@@ -164,25 +164,22 @@ class OpRegistry {
...
@@ -164,25 +164,22 @@ class OpRegistry {
return
std
::
shared_ptr
<
OperatorBase
>
(
op
);
return
std
::
shared_ptr
<
OperatorBase
>
(
op
);
}
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
)
{
static
VarNameMap
ConvertOpDescVarsToVarNameMap
(
VarNameMap
inputs
;
const
google
::
protobuf
::
RepeatedPtrField
<
OpDesc
::
Var
>&
op_desc_vars
)
{
for
(
auto
&
input
:
op_desc
.
inputs
())
{
VarNameMap
ret_val
;
auto
&
var_names
=
inputs
[
input
.
parameter
()];
for
(
auto
&
var
:
op_desc_vars
)
{
auto
&
var_names_in_proto
=
input
.
arguments
();
auto
&
var_names
=
ret_val
[
var
.
parameter
()];
auto
&
var_names_in_proto
=
var
.
arguments
();
var_names
.
reserve
(
static_cast
<
size_t
>
(
var_names_in_proto
.
size
()));
var_names
.
reserve
(
static_cast
<
size_t
>
(
var_names_in_proto
.
size
()));
std
::
copy
(
var_names_in_proto
.
begin
(),
var_names_in_proto
.
end
(),
std
::
copy
(
var_names_in_proto
.
begin
(),
var_names_in_proto
.
end
(),
std
::
back_inserter
(
var_names
));
std
::
back_inserter
(
var_names
));
}
}
return
ret_val
;
VarNameMap
outputs
;
for
(
auto
&
output
:
op_desc
.
outputs
())
{
auto
&
var_names
=
outputs
[
output
.
parameter
()];
auto
&
var_names_in_proto
=
output
.
arguments
();
var_names
.
reserve
(
static_cast
<
size_t
>
(
var_names_in_proto
.
size
()));
std
::
copy
(
var_names_in_proto
.
begin
(),
var_names_in_proto
.
end
(),
std
::
back_inserter
(
var_names
));
}
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
)
{
VarNameMap
inputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
inputs
());
VarNameMap
outputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
outputs
());
AttributeMap
attrs
;
AttributeMap
attrs
;
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
attrs
[
attr
.
name
()]
=
GetAttrValue
(
attr
);
attrs
[
attr
.
name
()]
=
GetAttrValue
(
attr
);
...
...
paddle/framework/op_registry_test.cc
浏览文件 @
4a604c26
...
@@ -38,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -38,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
public:
MyTestOpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
MyTestOpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"input"
,
"input of cosine op"
).
Set
Multip
le
();
AddInput
(
"input"
,
"input of cosine op"
).
Set
Duplicab
le
();
AddOutput
(
"output"
,
"output of cosine op"
).
Set
Temporary
();
AddOutput
(
"output"
,
"output of cosine op"
).
Set
Intermediate
();
auto
my_checker
=
[](
int
i
)
{
auto
my_checker
=
[](
int
i
)
{
PADDLE_ENFORCE
(
i
%
2
==
0
,
"'test_attr' must be even!"
);
PADDLE_ENFORCE
(
i
%
2
==
0
,
"'test_attr' must be even!"
);
};
};
...
@@ -51,6 +51,15 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -51,6 +51,15 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
static
void
ConstructVars
(
const
std
::
string
&
param_name
,
std
::
initializer_list
<
const
char
*>
arguments
,
paddle
::
framework
::
OpDesc
::
Var
*
var
)
{
var
->
set_parameter
(
param_name
);
for
(
auto
&
arg_name
:
arguments
)
{
*
var
->
mutable_arguments
()
->
Add
()
=
arg_name
;
}
}
REGISTER_OP
(
cos_sim
,
paddle
::
framework
::
CosineOp
,
REGISTER_OP
(
cos_sim
,
paddle
::
framework
::
CosineOp
,
paddle
::
framework
::
CosineOpProtoAndCheckerMaker
);
paddle
::
framework
::
CosineOpProtoAndCheckerMaker
);
REGISTER_OP
(
my_test_op
,
paddle
::
framework
::
MyTestOp
,
REGISTER_OP
(
my_test_op
,
paddle
::
framework
::
MyTestOp
,
...
@@ -59,13 +68,11 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
...
@@ -59,13 +68,11 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
TEST
(
OpRegistry
,
CreateOp
)
{
TEST
(
OpRegistry
,
CreateOp
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
set_type
(
"cos_sim"
);
auto
input
=
op_desc
.
add_inputs
();
auto
*
input
=
op_desc
.
add_inputs
();
input
->
set_parameter
(
"input"
);
ConstructVars
(
"input"
,
{
"aa"
},
input
);
*
input
->
mutable_arguments
()
->
Add
()
=
"aa"
;
auto
output
=
op_desc
.
add_outputs
();
auto
*
output
=
op_desc
.
add_outputs
();
output
->
set_parameter
(
"output"
);
ConstructVars
(
"output"
,
{
"bb"
},
output
);
*
output
->
mutable_arguments
()
->
Add
()
=
"bb"
;
float
scale
=
3.3
;
float
scale
=
3.3
;
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
...
@@ -85,13 +92,11 @@ TEST(OpRegistry, CreateOp) {
...
@@ -85,13 +92,11 @@ TEST(OpRegistry, CreateOp) {
TEST
(
OpRegistry
,
IllegalAttr
)
{
TEST
(
OpRegistry
,
IllegalAttr
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
set_type
(
"cos_sim"
);
auto
input
=
op_desc
.
add_inputs
();
auto
*
input
=
op_desc
.
add_inputs
();
input
->
set_parameter
(
"input"
);
ConstructVars
(
"input"
,
{
"aa"
},
input
);
*
input
->
mutable_arguments
()
->
Add
()
=
"aa"
;
auto
output
=
op_desc
.
add_outputs
();
auto
*
output
=
op_desc
.
add_outputs
();
output
->
set_parameter
(
"output"
);
ConstructVars
(
"output"
,
{
"bb"
},
output
);
*
output
->
mutable_arguments
()
->
Add
()
=
"bb"
;
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_name
(
"scale"
);
...
@@ -115,13 +120,11 @@ TEST(OpRegistry, IllegalAttr) {
...
@@ -115,13 +120,11 @@ TEST(OpRegistry, IllegalAttr) {
TEST
(
OpRegistry
,
DefaultValue
)
{
TEST
(
OpRegistry
,
DefaultValue
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
set_type
(
"cos_sim"
);
auto
input
=
op_desc
.
add_inputs
();
auto
*
input
=
op_desc
.
add_inputs
();
input
->
set_parameter
(
"input"
);
ConstructVars
(
"input"
,
{
"aa"
},
input
);
*
input
->
mutable_arguments
()
->
Add
()
=
"aa"
;
auto
output
=
op_desc
.
add_outputs
();
auto
*
output
=
op_desc
.
add_outputs
();
output
->
set_parameter
(
"output"
);
ConstructVars
(
"output"
,
{
"bb"
},
output
);
*
output
->
mutable_arguments
()
->
Add
()
=
"bb"
;
ASSERT_TRUE
(
op_desc
.
IsInitialized
());
ASSERT_TRUE
(
op_desc
.
IsInitialized
());
...
@@ -136,13 +139,11 @@ TEST(OpRegistry, DefaultValue) {
...
@@ -136,13 +139,11 @@ TEST(OpRegistry, DefaultValue) {
TEST
(
OpRegistry
,
CustomChecker
)
{
TEST
(
OpRegistry
,
CustomChecker
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"my_test_op"
);
op_desc
.
set_type
(
"my_test_op"
);
auto
input
=
op_desc
.
add_inputs
();
auto
*
input
=
op_desc
.
add_inputs
();
input
->
set_parameter
(
"input"
);
ConstructVars
(
"input"
,
{
"ii"
},
input
);
*
input
->
mutable_arguments
()
->
Add
()
=
"ii"
;
auto
output
=
op_desc
.
add_outputs
();
auto
*
output
=
op_desc
.
add_outputs
();
output
->
set_parameter
(
"output"
);
ConstructVars
(
"output"
,
{
"oo"
},
output
);
*
output
->
mutable_arguments
()
->
Add
()
=
"oo"
;
// attr 'test_attr' is not set
// attr 'test_attr' is not set
bool
caught
=
false
;
bool
caught
=
false
;
...
...
paddle/framework/operator.cc
浏览文件 @
4a604c26
...
@@ -42,33 +42,35 @@ std::unordered_map<std::string, OpProto>& OpProtos() {
...
@@ -42,33 +42,35 @@ std::unordered_map<std::string, OpProto>& OpProtos() {
}
}
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
auto
it
=
inputs_
.
find
(
name
);
auto
&
ins
=
Inputs
(
name
);
PADDLE_ENFORCE
(
it
!=
inputs_
.
end
(),
"Op %s does not have input %s"
,
type_
,
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
1UL
,
name
);
PADDLE_ENFORCE_EQ
(
it
->
second
.
size
(),
1UL
,
"Op %s input %s should contain only one variable"
,
type_
,
"Op %s input %s should contain only one variable"
,
type_
,
name
);
name
);
return
i
t
->
second
[
0
];
return
i
ns
[
0
];
}
}
const
std
::
vector
<
std
::
string
>&
OperatorBase
::
Inputs
(
const
std
::
vector
<
std
::
string
>&
OperatorBase
::
Inputs
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
return
inputs_
.
at
(
name
);
auto
it
=
inputs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
inputs_
.
end
(),
"Op %s do not have input %s"
,
type_
,
name
);
return
it
->
second
;
}
}
const
std
::
string
&
OperatorBase
::
Output
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
OperatorBase
::
Output
(
const
std
::
string
&
name
)
const
{
auto
it
=
outputs_
.
find
(
name
);
auto
&
outs
=
Outputs
(
name
);
PADDLE_ENFORCE
(
it
!=
outputs_
.
end
(),
"Op %s does not have output %s"
,
type_
,
PADDLE_ENFORCE_EQ
(
outs
.
size
(),
1UL
,
name
);
"Op %s output %s should contain only one variable"
,
type_
,
PADDLE_ENFORCE_EQ
(
it
->
second
.
size
(),
1UL
,
"Op %s input %s should contain only one variable"
,
type_
,
name
);
name
);
return
it
->
second
[
0
];
return
outs
[
0
];
}
}
const
std
::
vector
<
std
::
string
>&
OperatorBase
::
Outputs
(
const
std
::
vector
<
std
::
string
>&
OperatorBase
::
Outputs
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
name
)
const
{
return
outputs_
.
at
(
name
);
auto
it
=
outputs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
outputs_
.
end
(),
"Op %s does not have output %s"
,
type_
,
name
);
return
it
->
second
;
}
}
std
::
string
OperatorBase
::
DebugString
()
const
{
std
::
string
OperatorBase
::
DebugString
()
const
{
...
@@ -120,5 +122,34 @@ void OperatorBase::Rename(const std::string& old_name,
...
@@ -120,5 +122,34 @@ void OperatorBase::Rename(const std::string& old_name,
}
}
}
}
std
::
vector
<
std
::
string
>
OperatorBase
::
OutputVars
(
bool
has_intermediate
)
const
{
std
::
vector
<
std
::
string
>
ret_val
;
if
(
has_intermediate
)
{
// push all outputs into ret_val
for
(
auto
&
o
:
outputs_
)
{
ret_val
.
reserve
(
ret_val
.
size
()
+
o
.
second
.
size
());
ret_val
.
insert
(
ret_val
.
end
(),
o
.
second
.
begin
(),
o
.
second
.
end
());
}
return
ret_val
;
}
auto
it
=
OpProtos
().
find
(
type_
);
PADDLE_ENFORCE
(
it
!=
OpProtos
().
end
(),
"Operator %s not registered, cannot figure out intermediate outputs"
,
type_
);
// get all OpProto::Var for outputs
for
(
auto
&
o
:
it
->
second
.
outputs
())
{
// ignore all intermediate output
if
(
o
.
intermediate
())
continue
;
auto
out
=
outputs_
.
find
(
o
.
name
());
if
(
out
!=
outputs_
.
end
())
{
ret_val
.
reserve
(
ret_val
.
size
()
+
out
->
second
.
size
());
ret_val
.
insert
(
ret_val
.
end
(),
out
->
second
.
begin
(),
out
->
second
.
end
());
}
}
return
ret_val
;
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/operator.h
浏览文件 @
4a604c26
...
@@ -116,34 +116,7 @@ class OperatorBase {
...
@@ -116,34 +116,7 @@ class OperatorBase {
//! TODO add a vector_view to prevent memory copy.
//! TODO add a vector_view to prevent memory copy.
const
std
::
vector
<
std
::
string
>&
Outputs
(
const
std
::
string
&
name
)
const
;
const
std
::
vector
<
std
::
string
>&
Outputs
(
const
std
::
string
&
name
)
const
;
virtual
std
::
vector
<
std
::
string
>
OutputVars
(
bool
has_intermediate
)
const
{
virtual
std
::
vector
<
std
::
string
>
OutputVars
(
bool
has_intermediate
)
const
;
std
::
vector
<
std
::
string
>
ret_val
;
if
(
has_intermediate
)
{
// push all outputs into ret_val
for
(
auto
&
o
:
outputs_
)
{
ret_val
.
reserve
(
ret_val
.
size
()
+
o
.
second
.
size
());
ret_val
.
insert
(
ret_val
.
end
(),
o
.
second
.
begin
(),
o
.
second
.
end
());
}
return
ret_val
;
}
auto
it
=
OpProtos
().
find
(
type_
);
PADDLE_ENFORCE
(
it
!=
OpProtos
().
end
(),
"Operator %s not registered, cannot figure out intermediate outputs"
,
type_
);
// get all OpProto::Var for outputs
for
(
auto
&
o
:
it
->
second
.
outputs
())
{
// ignore all intermediate output
if
(
o
.
intermediate
())
continue
;
auto
out
=
outputs_
.
find
(
o
.
name
());
if
(
out
!=
outputs_
.
end
())
{
ret_val
.
reserve
(
ret_val
.
size
()
+
out
->
second
.
size
());
ret_val
.
insert
(
ret_val
.
end
(),
out
->
second
.
begin
(),
out
->
second
.
end
());
}
}
return
ret_val
;
}
std
::
string
Type
()
const
{
return
type_
;
}
std
::
string
Type
()
const
{
return
type_
;
}
const
AttributeMap
&
Attrs
()
const
{
return
attrs_
;
}
const
AttributeMap
&
Attrs
()
const
{
return
attrs_
;
}
...
@@ -154,11 +127,11 @@ class OperatorBase {
...
@@ -154,11 +127,11 @@ class OperatorBase {
// I (Inputs)
// I (Inputs)
// O (Outputs)
// O (Outputs)
// OG (Output Gradients)
// OG (Output Gradients)
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
inputs_
;
VarNameMap
inputs_
;
// NOTE: in case of OpGrad, outputs_ contains
// NOTE: in case of OpGrad, outputs_ contains
// IG (Inputs Gradients)
// IG (Inputs Gradients)
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
outputs_
;
VarNameMap
outputs_
;
AttributeMap
attrs_
;
AttributeMap
attrs_
;
};
};
...
@@ -177,11 +150,11 @@ class InferShapeContext {
...
@@ -177,11 +150,11 @@ class InferShapeContext {
:
op_
(
op
),
scope_
(
scope
)
{}
:
op_
(
op
),
scope_
(
scope
)
{}
size_t
InputSize
(
const
std
::
string
&
name
)
const
{
size_t
InputSize
(
const
std
::
string
&
name
)
const
{
return
op_
.
inputs_
.
at
(
name
).
size
();
return
op_
.
Inputs
(
name
).
size
();
}
}
size_t
OutputSize
(
const
std
::
string
&
name
)
const
{
size_t
OutputSize
(
const
std
::
string
&
name
)
const
{
return
op_
.
outputs_
.
at
(
name
).
size
();
return
op_
.
Outputs
(
name
).
size
();
}
}
const
Variable
*
InputVar
(
const
std
::
string
&
name
)
const
{
const
Variable
*
InputVar
(
const
std
::
string
&
name
)
const
{
...
...
paddle/framework/operator_test.cc
浏览文件 @
4a604c26
...
@@ -56,19 +56,28 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -56,19 +56,28 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
static
void
ConstructVars
(
const
std
::
string
&
param_name
,
std
::
initializer_list
<
const
char
*>
arguments
,
paddle
::
framework
::
OpDesc
::
Var
*
var
)
{
var
->
set_parameter
(
param_name
);
for
(
auto
&
arg_name
:
arguments
)
{
*
var
->
mutable_arguments
()
->
Add
()
=
arg_name
;
}
}
REGISTER_OP
(
test_operator
,
paddle
::
framework
::
OpWithoutKernelTest
,
REGISTER_OP
(
test_operator
,
paddle
::
framework
::
OpWithoutKernelTest
,
paddle
::
framework
::
OpeWithoutKernelTestProtoAndCheckerMaker
);
paddle
::
framework
::
OpeWithoutKernelTestProtoAndCheckerMaker
);
TEST
(
OperatorBase
,
all
)
{
TEST
(
OperatorBase
,
all
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"test_operator"
);
op_desc
.
set_type
(
"test_operator"
);
auto
*
ipt
=
op_desc
.
mutable_inputs
()
->
Add
();
auto
*
ipt
=
op_desc
.
mutable_inputs
()
->
Add
();
*
ipt
->
mutable_arguments
()
->
Add
()
=
"IN1"
;
ConstructVars
(
"IN1"
,
{
"input"
},
ipt
);
ipt
->
set_parameter
(
"input"
);
auto
*
output
=
op_desc
.
mutable_outputs
()
->
Add
();
auto
*
output
=
op_desc
.
mutable_outputs
()
->
Add
();
*
output
->
mutable_arguments
()
->
Add
()
=
"OUT1"
;
ConstructVars
(
"OUT1"
,
{
"output"
},
output
)
;
output
->
set_parameter
(
"output"
);
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
...
@@ -127,9 +136,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
...
@@ -127,9 +136,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
OpKernelTestMultiInputsProtoAndCheckerMaker
(
OpProto
*
proto
,
OpKernelTestMultiInputsProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"xs"
,
"inputs of test op"
).
Set
Multip
le
();
AddInput
(
"xs"
,
"inputs of test op"
).
Set
Duplicab
le
();
AddInput
(
"k"
,
"input of test op"
);
AddInput
(
"k"
,
"input of test op"
);
AddOutput
(
"ys"
,
"outputs of test op"
).
Set
Multip
le
();
AddOutput
(
"ys"
,
"outputs of test op"
).
Set
Duplicab
le
();
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
.
SetDefault
(
1.0
)
.
SetDefault
(
1.0
)
.
LargerThan
(
0.0
);
.
LargerThan
(
0.0
);
...
@@ -187,12 +196,10 @@ TEST(OpKernel, all) {
...
@@ -187,12 +196,10 @@ TEST(OpKernel, all) {
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"op_with_kernel"
);
op_desc
.
set_type
(
"op_with_kernel"
);
auto
*
ipt
=
op_desc
.
mutable_inputs
()
->
Add
();
auto
*
ipt
=
op_desc
.
mutable_inputs
()
->
Add
();
*
ipt
->
mutable_arguments
()
->
Add
()
=
"IN1"
;
ConstructVars
(
"IN1"
,
{
"x"
},
ipt
);
ipt
->
set_parameter
(
"x"
);
auto
*
output
=
op_desc
.
mutable_outputs
()
->
Add
();
auto
*
output
=
op_desc
.
mutable_outputs
()
->
Add
();
*
output
->
mutable_arguments
()
->
Add
()
=
"OUT1"
;
ConstructVars
(
"OUT1"
,
{
"y"
},
output
);
output
->
set_parameter
(
"y"
);
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_name
(
"scale"
);
...
@@ -219,18 +226,12 @@ TEST(OpKernel, multi_inputs) {
...
@@ -219,18 +226,12 @@ TEST(OpKernel, multi_inputs) {
OpDesc
op_desc
;
OpDesc
op_desc
;
op_desc
.
set_type
(
"op_multi_inputs_with_kernel"
);
op_desc
.
set_type
(
"op_multi_inputs_with_kernel"
);
auto
x
=
op_desc
.
mutable_inputs
()
->
Add
();
auto
*
x
=
op_desc
.
mutable_inputs
()
->
Add
();
x
->
set_parameter
(
"xs"
);
ConstructVars
(
"xs"
,
{
"x0"
,
"x1"
,
"x2"
},
x
);
*
x
->
mutable_arguments
()
->
Add
()
=
"x0"
;
auto
*
k
=
op_desc
.
mutable_inputs
()
->
Add
();
*
x
->
mutable_arguments
()
->
Add
()
=
"x1"
;
ConstructVars
(
"k"
,
{
"k0"
},
k
);
*
x
->
mutable_arguments
()
->
Add
()
=
"x2"
;
auto
*
y
=
op_desc
.
mutable_outputs
()
->
Add
();
auto
k
=
op_desc
.
mutable_inputs
()
->
Add
();
ConstructVars
(
"ys"
,
{
"y0"
,
"y1"
},
y
);
k
->
set_parameter
(
"k"
);
*
k
->
mutable_arguments
()
->
Add
()
=
"k0"
;
auto
y
=
op_desc
.
mutable_outputs
()
->
Add
();
y
->
set_parameter
(
"ys"
);
*
y
->
mutable_arguments
()
->
Add
()
=
"y0"
;
*
y
->
mutable_arguments
()
->
Add
()
=
"y1"
;
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_name
(
"scale"
);
...
...
paddle/operators/mean_op.cc
浏览文件 @
4a604c26
...
@@ -32,7 +32,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -32,7 +32,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
MeanOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
MeanOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input of mean op"
);
AddInput
(
"X"
,
"The input of mean op"
);
AddOutput
(
"Out"
,
"The output of mean op"
).
Ignore
Gradient
();
AddOutput
(
"Out"
,
"The output of mean op"
).
No
Gradient
();
AddComment
(
"Mean Operator"
);
AddComment
(
"Mean Operator"
);
}
}
};
};
...
...
paddle/operators/recurrent_op.cc
浏览文件 @
4a604c26
...
@@ -152,13 +152,13 @@ class RecurrentAlgorithmProtoAndCheckerMaker
...
@@ -152,13 +152,13 @@ class RecurrentAlgorithmProtoAndCheckerMaker
// inputs and outputs stored in proto
// inputs and outputs stored in proto
AddInput
(
name
.
inlinks
,
AddInput
(
name
.
inlinks
,
"the inputs that need to be segmented for each step."
)
"the inputs that need to be segmented for each step."
)
.
Set
Multip
le
();
.
Set
Duplicab
le
();
AddInput
(
name
.
boot_memories
,
"variables to initialize memories."
)
AddInput
(
name
.
boot_memories
,
"variables to initialize memories."
)
.
Set
Multip
le
();
.
Set
Duplicab
le
();
AddInput
(
name
.
step_net
,
"network shared by all steps."
);
AddInput
(
name
.
step_net
,
"network shared by all steps."
);
AddOutput
(
name
.
outlinks
,
"the outputs that need to concated for all steps."
)
AddOutput
(
name
.
outlinks
,
"the outputs that need to concated for all steps."
)
.
Set
Multip
le
();
.
Set
Duplicab
le
();
AddOutput
(
name
.
step_scopes
,
"step scopes"
);
AddOutput
(
name
.
step_scopes
,
"step scopes"
);
// Attributes stored in AttributeMap
// Attributes stored in AttributeMap
...
...
paddle/operators/recurrent_op_test.cc
浏览文件 @
4a604c26
...
@@ -26,8 +26,6 @@ namespace paddle {
...
@@ -26,8 +26,6 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
framework
;
// using framework::make_ddim;
// using framework::DDim;
class
RecurrentGradientAlgorithmTest
:
public
::
testing
::
Test
{
class
RecurrentGradientAlgorithmTest
:
public
::
testing
::
Test
{
protected:
protected:
...
...
python/paddle/v2/framework/tests/test_add_two_op.py
浏览文件 @
4a604c26
...
@@ -19,13 +19,5 @@ class TestAddOp(unittest.TestCase):
...
@@ -19,13 +19,5 @@ class TestAddOp(unittest.TestCase):
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
+
self
.
inputs
[
'Y'
]}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
+
self
.
inputs
[
'Y'
]}
#class TestAddGradOp(unittest.TestCase):
# def test_add_grad(self):
# op = Operator('add_two', X="X", Y="Y", Out="Out")
# backward_op = core.Operator.backward(op, set())
# self.assertEqual(backward_op.type(), "add_two_grad")
# expected = '''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).'''
# self.assertEqual(expected, str(backward_op))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录