PaddlePaddle / Paddle, commit 1ed5f02d
Authored on Aug 14, 2017 by Yu Yang; committed via GitHub on Aug 14, 2017.
Parents: 88a3d8dd, f09cb657

Merge pull request #14 from reyoung/feature/refactorize_framework_proto

Polish Our code by YuYang's review
Showing 14 changed files with 128 additions and 164 deletions (+128, −164).
Changed files:

- paddle/framework/backward_test.cc (+14, −12)
- paddle/framework/ddim.cc (+0, −7)
- paddle/framework/ddim.h (+0, −2)
- paddle/framework/grad_op_builder.cc (+0, −3)
- paddle/framework/grad_op_builder_test.cc (+6, −6)
- paddle/framework/op_registry.h (+18, −18)
- paddle/framework/op_registry_test.cc (+19, −30)
- paddle/framework/operator.cc (+44, −13)
- paddle/framework/operator.h (+5, −32)
- paddle/framework/operator_test.cc (+18, −27)
- paddle/operators/mean_op.cc (+1, −1)
- paddle/operators/recurrent_op.cc (+3, −3)
- paddle/operators/recurrent_op_test.cc (+0, −2)
- python/paddle/v2/framework/tests/test_add_two_op.py (+0, −8)
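The bulk of the diff is a mechanical rename of the fluent setters on OpProtoAndCheckerMaker::VariableBuilder: SetMultiple() becomes AsDuplicable(), SetTemporary() becomes AsIntermediate(), and IgnoreGradient() becomes AsNoGradient(). As a quick orientation before the per-file diffs, here is a hedged sketch (not itself part of the commit) of how an op maker reads against the renamed API; the class name and description strings are made up for illustration.

// Illustrative only: an op maker written against the renamed VariableBuilder
// API introduced by this commit. "ExampleOpMaker" and the description strings
// are hypothetical; the method names mirror the diffs below.
class ExampleOpMaker : public OpProtoAndCheckerMaker {
 public:
  ExampleOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "a repeatable input").AsDuplicable();         // was SetMultiple()
    AddOutput("Tmp", "a scratch output").AsIntermediate();      // was SetTemporary()
    AddOutput("Out", "output skipped by grad").AsNoGradient();  // was IgnoreGradient()
    AddComment("example op");
  }
};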
paddle/framework/backward_test.cc (+14, −12)

@@ -39,9 +39,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
  public:
   RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "Input X of Add").IgnoreGradient();
-    AddInput("b", "Bias of Add").IgnoreGradient();
-    AddOutput("Out", "Out of Add").IgnoreGradient();
+    AddInput("X", "Input X of Add").AsNoGradient();
+    AddInput("b", "Bias of Add").AsNoGradient();
+    AddOutput("Out", "Out of Add").AsNoGradient();
     AddComment("Add Op");
   }
 };

@@ -111,8 +111,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
     AddInput("X", "x");
     AddInput("W", "w");
     AddInput("b", "b");
-    AddOutput("mul_result", "").SetTemporary();
-    AddOutput("add_result", "").SetTemporary();
+    AddOutput("mul_result", "").AsIntermediate();
+    AddOutput("add_result", "").AsIntermediate();
     AddOutput("Out", "");
     AddComment("");
   }

@@ -143,7 +143,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
  public:
   AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "x").SetMultiple();
+    AddInput("X", "x").AsDuplicable();
     AddOutput("Y", "y");
     AddComment("");
   }

@@ -392,18 +392,20 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
   auto bwd_net = static_cast<ops::NetOp *>(backward.get());
   ASSERT_EQ(bwd_net->ops_.size(), 3UL);
   auto &grad_fc = *bwd_net->ops_[0];
-  EXPECT_EQ(grad_fc.inputs_["all"].size(),
+  const char *all = paddle::operators::NetOp::kAll;
+  EXPECT_EQ(grad_fc.inputs_[all].size(),
             2UL       /* external input number */
                 + 1UL /* external output number*/
                 + 1UL /* number of gradient of external output*/
                 + 2U /* internal variable number*/);
-  EXPECT_EQ(grad_fc.outputs_["all"].size(),
+  EXPECT_EQ(grad_fc.outputs_[all].size(),
             2UL       /* input number of mul*/
                 + 2UL /* input number of rowwise_add
                       */
                 + 1UL /* input number of sigmod */);
-  EXPECT_EQ(bwd_net->ops_[1]->inputs_["all"].size(), 0UL);
-  EXPECT_EQ(bwd_net->ops_[1]->outputs_["all"].size(), 0UL);
-  EXPECT_EQ(bwd_net->ops_[2]->inputs_["all"].size(), 0UL);
-  EXPECT_EQ(bwd_net->ops_[2]->outputs_["all"].size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[1]->inputs_[all].size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[1]->outputs_[all].size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[2]->inputs_[all].size(), 0UL);
+  EXPECT_EQ(bwd_net->ops_[2]->outputs_[all].size(), 0UL);
 }
paddle/framework/ddim.cc (+0, −7)

@@ -283,12 +283,5 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
 DDim::DDim(std::initializer_list<int> init_list) {
   *this = make_ddim(init_list);
 }
-
-std::string DDim::DebugString() const {
-  std::ostringstream ss;
-  ss << *this;
-  return ss.str();
-}
-
 }  // namespace framework
 }  // namespace paddle
paddle/framework/ddim.h (+0, −2)

@@ -72,8 +72,6 @@ struct DDim {
   DDim operator*(DDim d) const;

   ssize_t size() const;
-
-  std::string DebugString() const;
 };

 /**
paddle/framework/grad_op_builder.cc (+0, −3)

@@ -18,9 +18,6 @@ permissions and limitations under the License. */
 namespace paddle {
 namespace framework {
-
-class OpRegistry;
-
 enum class OpArgType { IN, OUT };

 static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
paddle/framework/grad_op_builder_test.cc (+6, −6)

@@ -21,10 +21,10 @@ class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
   MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("In1", "a single input");
-    AddInput("In2_mult", "a multiple input").SetMultiple();
+    AddInput("In2_mult", "a multiple input").AsDuplicable();
     AddInput("In3", "another single input");
     AddOutput("Out1", "a single output");
-    AddOutput("Out2_mult", "a multiple output").SetMultiple();
+    AddOutput("Out2_mult", "a multiple output").AsDuplicable();
     AddComment("test op with multiple inputs and outputs");
   }
 };

@@ -34,10 +34,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
   IOIgnoredOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("In1", "a single input");
-    AddInput("In2_mult", "a multiple input").SetMultiple().IgnoreGradient();
-    AddInput("In3_mult", "another multiple input").SetMultiple();
-    AddOutput("Out1_mult", "a multiple output").SetMultiple();
-    AddOutput("Out2", "a single output").IgnoreGradient();
+    AddInput("In2_mult", "a multiple input").AsDuplicable().AsNoGradient();
+    AddInput("In3_mult", "another multiple input").AsDuplicable();
+    AddOutput("Out1_mult", "a multiple output").AsDuplicable();
+    AddOutput("Out2", "a single output").AsNoGradient();
     AddComment("op with inputs and outputs ignored in gradient calculating");
   }
 };
paddle/framework/op_registry.h (+18, −18)

@@ -47,17 +47,20 @@ class OpProtoAndCheckerMaker {
   struct VariableBuilder {
     OpProto::Var* var_;

-    VariableBuilder& SetMultiple() {
+    VariableBuilder& AsDuplicable() {
       var_->set_duplicable(true);
       return *this;
     }

-    VariableBuilder& SetTemporary() {
+    VariableBuilder& AsIntermediate() {
       var_->set_intermediate(true);
       return *this;
     }

-    VariableBuilder& IgnoreGradient() {
+    // TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it
+    // means that input/output is not needed when calculate gradient. It does
+    // not mean no gradient when backward. It should be changed soon.
+    VariableBuilder& AsNoGradient() {
       var_->set_no_gradient(true);
       return *this;
     }

@@ -118,7 +121,7 @@ class OpProtoAndCheckerMaker {
 class OpRegistry {
   using OpCreator = std::function<OperatorBase*()>;
-  using VarNameMap = std::map<std::string, std::vector<std::string>>;
+  using VarNameMap = OperatorBase::VarNameMap;

  public:
   template <typename OpType, typename ProtoMakerType>

@@ -164,25 +167,22 @@ class OpRegistry {
     return std::shared_ptr<OperatorBase>(op);
   }

-  static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
-    VarNameMap inputs;
-    for (auto& input : op_desc.inputs()) {
-      auto& var_names = inputs[input.parameter()];
-      auto& var_names_in_proto = input.arguments();
-      var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
-      std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
-                std::back_inserter(var_names));
-    }
-
-    VarNameMap outputs;
-    for (auto& output : op_desc.outputs()) {
-      auto& var_names = outputs[output.parameter()];
-      auto& var_names_in_proto = output.arguments();
-      var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
-      std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
-                std::back_inserter(var_names));
-    }
+  static VarNameMap ConvertOpDescVarsToVarNameMap(
+      const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars) {
+    VarNameMap ret_val;
+    for (auto& var : op_desc_vars) {
+      auto& var_names = ret_val[var.parameter()];
+      auto& var_names_in_proto = var.arguments();
+      var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
+      std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
+                std::back_inserter(var_names));
+    }
+    return ret_val;
+  }
+
+  static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
+    VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
+    VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
     AttributeMap attrs;
     for (auto& attr : op_desc.attrs()) {
       attrs[attr.name()] = GetAttrValue(attr);
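The CreateOp change above factors the two identical copy loops into a single ConvertOpDescVarsToVarNameMap helper that turns the repeated OpDesc::Var messages into a VarNameMap (parameter name mapped to its list of argument names). Below is a minimal standalone sketch of that conversion over plain standard-library types, to show the shape of the resulting map; the struct and the sample values are hypothetical, not PaddlePaddle code.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-in for the protobuf-generated OpDesc::Var message (assumption).
struct Var {
  std::string parameter;
  std::vector<std::string> arguments;
};

using VarNameMap = std::map<std::string, std::vector<std::string>>;

// Mirrors what ConvertOpDescVarsToVarNameMap does: group the argument
// names of each Var under its parameter name.
VarNameMap Convert(const std::vector<Var>& vars) {
  VarNameMap ret_val;
  for (const auto& var : vars) {
    auto& names = ret_val[var.parameter];
    names.insert(names.end(), var.arguments.begin(), var.arguments.end());
  }
  return ret_val;
}

int main() {
  VarNameMap inputs = Convert({{"X", {"x0", "x1"}}, {"b", {"bias"}}});
  for (const auto& kv : inputs) {
    std::cout << kv.first << " -> " << kv.second.size() << " argument(s)\n";
  }
  return 0;
}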
paddle/framework/op_registry_test.cc (+19, −30)

@@ -38,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
  public:
   MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("input", "input of cosine op").SetMultiple();
-    AddOutput("output", "output of cosine op").SetTemporary();
+    AddInput("input", "input of cosine op").AsDuplicable();
+    AddOutput("output", "output of cosine op").AsIntermediate();
     auto my_checker = [](int i) {
       PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
     };

@@ -51,6 +51,15 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 }  // namespace framework
 }  // namespace paddle

+static void BuildVar(const std::string& param_name,
+                     std::initializer_list<const char*> arguments,
+                     paddle::framework::OpDesc::Var* var) {
+  var->set_parameter(param_name);
+  for (auto& arg_name : arguments) {
+    var->add_arguments(arg_name);
+  }
+}
+
 REGISTER_OP(cos_sim, paddle::framework::CosineOp,
             paddle::framework::CosineOpProtoAndCheckerMaker);
 REGISTER_OP(my_test_op, paddle::framework::MyTestOp,

@@ -59,13 +68,8 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
 TEST(OpRegistry, CreateOp) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("cos_sim");
-  auto input = op_desc.add_inputs();
-  input->set_parameter("input");
-  *input->mutable_arguments()->Add() = "aa";
-
-  auto output = op_desc.add_outputs();
-  output->set_parameter("output");
-  *output->mutable_arguments()->Add() = "bb";
+  BuildVar("input", {"aa"}, op_desc.add_inputs());
+  BuildVar("output", {"bb"}, op_desc.add_outputs());

   float scale = 3.3;
   auto attr = op_desc.mutable_attrs()->Add();

@@ -85,13 +89,8 @@ TEST(OpRegistry, CreateOp) {
 TEST(OpRegistry, IllegalAttr) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("cos_sim");
-  auto input = op_desc.add_inputs();
-  input->set_parameter("input");
-  *input->mutable_arguments()->Add() = "aa";
-
-  auto output = op_desc.add_outputs();
-  output->set_parameter("output");
-  *output->mutable_arguments()->Add() = "bb";
+  BuildVar("input", {"aa"}, op_desc.add_inputs());
+  BuildVar("output", {"bb"}, op_desc.add_outputs());

   auto attr = op_desc.mutable_attrs()->Add();
   attr->set_name("scale");

@@ -115,13 +114,8 @@ TEST(OpRegistry, IllegalAttr) {
 TEST(OpRegistry, DefaultValue) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("cos_sim");
-  auto input = op_desc.add_inputs();
-  input->set_parameter("input");
-  *input->mutable_arguments()->Add() = "aa";
-
-  auto output = op_desc.add_outputs();
-  output->set_parameter("output");
-  *output->mutable_arguments()->Add() = "bb";
+  BuildVar("input", {"aa"}, op_desc.add_inputs());
+  BuildVar("output", {"bb"}, op_desc.add_outputs());

   ASSERT_TRUE(op_desc.IsInitialized());

@@ -136,13 +130,8 @@ TEST(OpRegistry, DefaultValue) {
 TEST(OpRegistry, CustomChecker) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("my_test_op");
-  auto input = op_desc.add_inputs();
-  input->set_parameter("input");
-  *input->mutable_arguments()->Add() = "ii";
-
-  auto output = op_desc.add_outputs();
-  output->set_parameter("output");
-  *output->mutable_arguments()->Add() = "oo";
+  BuildVar("input", {"ii"}, op_desc.add_inputs());
+  BuildVar("output", {"oo"}, op_desc.add_outputs());

   // attr 'test_attr' is not set
   bool caught = false;
paddle/framework/operator.cc (+44, −13)

@@ -42,33 +42,35 @@ std::unordered_map<std::string, OpProto>& OpProtos() {
 }

 const std::string& OperatorBase::Input(const std::string& name) const {
-  auto it = inputs_.find(name);
-  PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have input %s", type_,
-                 name);
-  PADDLE_ENFORCE_EQ(it->second.size(), 1UL,
-                    "Op %s input %s should contain only one variable", type_,
-                    name);
-  return it->second[0];
+  auto& ins = Inputs(name);
+  PADDLE_ENFORCE_EQ(ins.size(), 1UL,
+                    "Op %s input %s should contain only one variable", type_,
+                    name);
+  return ins[0];
 }

 const std::vector<std::string>& OperatorBase::Inputs(
     const std::string& name) const {
-  return inputs_.at(name);
+  auto it = inputs_.find(name);
+  PADDLE_ENFORCE(it != inputs_.end(), "Op %s do not have input %s", type_,
+                 name);
+  return it->second;
 }

 const std::string& OperatorBase::Output(const std::string& name) const {
-  auto it = outputs_.find(name);
-  PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output %s", type_,
-                 name);
-  PADDLE_ENFORCE_EQ(it->second.size(), 1UL,
-                    "Op %s input %s should contain only one variable", type_,
-                    name);
-  return it->second[0];
+  auto& outs = Outputs(name);
+  PADDLE_ENFORCE_EQ(outs.size(), 1UL,
+                    "Op %s output %s should contain only one variable", type_,
+                    name);
+  return outs[0];
 }

 const std::vector<std::string>& OperatorBase::Outputs(
     const std::string& name) const {
-  return outputs_.at(name);
+  auto it = outputs_.find(name);
+  PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output %s", type_,
+                 name);
+  return it->second;
 }

 std::string OperatorBase::DebugString() const {

@@ -120,5 +122,34 @@ void OperatorBase::Rename(const std::string& old_name,
   }
 }

+std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
+  std::vector<std::string> ret_val;
+  if (has_intermediate) {
+    // push all outputs into ret_val
+    for (auto& o : outputs_) {
+      ret_val.reserve(ret_val.size() + o.second.size());
+      ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
+    }
+    return ret_val;
+  }
+  auto it = OpProtos().find(type_);
+  PADDLE_ENFORCE(
+      it != OpProtos().end(),
+      "Operator %s not registered, cannot figure out intermediate outputs",
+      type_);
+
+  // get all OpProto::Var for outputs
+  for (auto& o : it->second.outputs()) {
+    // ignore all intermediate output
+    if (o.intermediate()) continue;
+    auto out = outputs_.find(o.name());
+    if (out != outputs_.end()) {
+      ret_val.reserve(ret_val.size() + out->second.size());
+      ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
+    }
+  }
+  return ret_val;
+}
+
 }  // namespace framework
 }  // namespace paddle
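The new out-of-line OperatorBase::OutputVars(bool has_intermediate) either returns every output name or, when has_intermediate is false, consults the registered OpProto and drops outputs marked intermediate. Here is a standalone sketch of the same filtering over plain containers; everything below is illustrative rather than PaddlePaddle code, and the OpProto lookup is replaced by an explicit set of intermediate parameter names.

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

using VarNameMap = std::map<std::string, std::vector<std::string>>;

// Mirrors the filtering in OperatorBase::OutputVars: keep everything when
// has_intermediate is true, otherwise skip outputs flagged as intermediate.
std::vector<std::string> OutputVars(const VarNameMap& outputs,
                                    const std::set<std::string>& intermediate,
                                    bool has_intermediate) {
  std::vector<std::string> ret_val;
  for (const auto& o : outputs) {
    if (!has_intermediate && intermediate.count(o.first)) continue;
    ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
  }
  return ret_val;
}

int main() {
  VarNameMap outputs = {{"mul_result", {"tmp0"}}, {"Out", {"fc_out"}}};
  auto external = OutputVars(outputs, {"mul_result"}, /*has_intermediate=*/false);
  std::cout << external.size() << " external output(s)\n";  // prints "1 external output(s)"
  return 0;
}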
paddle/framework/operator.h (+5, −32)

@@ -116,34 +116,7 @@ class OperatorBase {
   //! TODO add a vector_view to prevent memory copy.
   const std::vector<std::string>& Outputs(const std::string& name) const;

-  virtual std::vector<std::string> OutputVars(bool has_intermediate) const {
-    std::vector<std::string> ret_val;
-    if (has_intermediate) {
-      // push all outputs into ret_val
-      for (auto& o : outputs_) {
-        ret_val.reserve(ret_val.size() + o.second.size());
-        ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
-      }
-      return ret_val;
-    }
-    auto it = OpProtos().find(type_);
-    PADDLE_ENFORCE(
-        it != OpProtos().end(),
-        "Operator %s not registered, cannot figure out intermediate outputs",
-        type_);
-
-    // get all OpProto::Var for outputs
-    for (auto& o : it->second.outputs()) {
-      // ignore all intermediate output
-      if (o.intermediate()) continue;
-      auto out = outputs_.find(o.name());
-      if (out != outputs_.end()) {
-        ret_val.reserve(ret_val.size() + out->second.size());
-        ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
-      }
-    }
-    return ret_val;
-  }
+  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;

   std::string Type() const { return type_; }
   const AttributeMap& Attrs() const { return attrs_; }

@@ -154,11 +127,11 @@ class OperatorBase {
   // I (Inputs)
   // O (Outputs)
   // OG (Output Gradients)
-  std::map<std::string, std::vector<std::string>> inputs_;
+  VarNameMap inputs_;

   // NOTE: in case of OpGrad, outputs_ contains
   // IG (Inputs Gradients)
-  std::map<std::string, std::vector<std::string>> outputs_;
+  VarNameMap outputs_;
   AttributeMap attrs_;
 };

@@ -177,11 +150,11 @@ class InferShapeContext {
       : op_(op), scope_(scope) {}

   size_t InputSize(const std::string& name) const {
-    return op_.inputs_.at(name).size();
+    return op_.Inputs(name).size();
   }

   size_t OutputSize(const std::string& name) const {
-    return op_.outputs_.at(name).size();
+    return op_.Outputs(name).size();
   }

   const Variable* InputVar(const std::string& name) const {
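The inputs_ and outputs_ members now use the VarNameMap alias, which OpRegistry also switches to above (using VarNameMap = OperatorBase::VarNameMap). The definition of that alias inside OperatorBase is not shown in this diff; presumably it names the same type the members were declared with before the change, along the lines of this sketch.

// Assumption, not shown in this diff: the VarNameMap alias inside
// OperatorBase is expected to be the map type the members used previously.
using VarNameMap = std::map<std::string, std::vector<std::string>>;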
paddle/framework/operator_test.cc (+18, −27)

@@ -56,19 +56,24 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 }  // namespace framework
 }  // namespace paddle

+static void BuildVar(const std::string& param_name,
+                     std::initializer_list<const char*> arguments,
+                     paddle::framework::OpDesc::Var* var) {
+  var->set_parameter(param_name);
+  for (auto& arg_name : arguments) {
+    *var->mutable_arguments()->Add() = arg_name;
+  }
+}
+
 REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest,
             paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker);

 TEST(OperatorBase, all) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("test_operator");
-  auto* ipt = op_desc.mutable_inputs()->Add();
-  *ipt->mutable_arguments()->Add() = "IN1";
-  ipt->set_parameter("input");
-
-  auto* output = op_desc.mutable_outputs()->Add();
-  *output->mutable_arguments()->Add() = "OUT1";
-  output->set_parameter("output");
+  BuildVar("IN1", {"input"}, op_desc.add_inputs());
+  BuildVar("OUT1", {"output"}, op_desc.add_outputs());

   auto attr = op_desc.mutable_attrs()->Add();
   attr->set_name("scale");
   attr->set_type(paddle::framework::AttrType::FLOAT);

@@ -127,9 +132,9 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
   OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
                                               OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("xs", "inputs of test op").SetMultiple();
+    AddInput("xs", "inputs of test op").AsDuplicable();
     AddInput("k", "input of test op");
-    AddOutput("ys", "outputs of test op").SetMultiple();
+    AddOutput("ys", "outputs of test op").AsDuplicable();
     AddAttr<float>("scale", "scale of cosine op")
         .SetDefault(1.0)
         .LargerThan(0.0);

@@ -186,13 +191,8 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
 TEST(OpKernel, all) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("op_with_kernel");
-  auto* ipt = op_desc.mutable_inputs()->Add();
-  *ipt->mutable_arguments()->Add() = "IN1";
-  ipt->set_parameter("x");
-
-  auto* output = op_desc.mutable_outputs()->Add();
-  *output->mutable_arguments()->Add() = "OUT1";
-  output->set_parameter("y");
+  BuildVar("IN1", {"x"}, op_desc.add_inputs());
+  BuildVar("OUT1", {"y"}, op_desc.add_outputs());

   auto attr = op_desc.mutable_attrs()->Add();
   attr->set_name("scale");

@@ -219,18 +219,9 @@ TEST(OpKernel, multi_inputs) {
   OpDesc op_desc;
   op_desc.set_type("op_multi_inputs_with_kernel");
-  auto x = op_desc.mutable_inputs()->Add();
-  x->set_parameter("xs");
-  *x->mutable_arguments()->Add() = "x0";
-  *x->mutable_arguments()->Add() = "x1";
-  *x->mutable_arguments()->Add() = "x2";
-  auto k = op_desc.mutable_inputs()->Add();
-  k->set_parameter("k");
-  *k->mutable_arguments()->Add() = "k0";
-  auto y = op_desc.mutable_outputs()->Add();
-  y->set_parameter("ys");
-  *y->mutable_arguments()->Add() = "y0";
-  *y->mutable_arguments()->Add() = "y1";
+  BuildVar("xs", {"x0", "x1", "x2"}, op_desc.add_inputs());
+  BuildVar("k", {"k0"}, op_desc.add_inputs());
+  BuildVar("ys", {"y0", "y1"}, op_desc.add_outputs());

   auto attr = op_desc.mutable_attrs()->Add();
   attr->set_name("scale");
paddle/operators/mean_op.cc (+1, −1)

@@ -32,7 +32,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
   MeanOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input of mean op");
-    AddOutput("Out", "The output of mean op").IgnoreGradient();
+    AddOutput("Out", "The output of mean op").AsNoGradient();
     AddComment("Mean Operator");
   }
 };
paddle/operators/recurrent_op.cc (+3, −3)

@@ -152,13 +152,13 @@ class RecurrentAlgorithmProtoAndCheckerMaker
     // inputs and outputs stored in proto
     AddInput(name.inlinks,
              "the inputs that need to be segmented for each step.")
-        .SetMultiple();
+        .AsDuplicable();
     AddInput(name.boot_memories, "variables to initialize memories.")
-        .SetMultiple();
+        .AsDuplicable();
     AddInput(name.step_net, "network shared by all steps.");

     AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
-        .SetMultiple();
+        .AsDuplicable();
     AddOutput(name.step_scopes, "step scopes");

     // Attributes stored in AttributeMap
paddle/operators/recurrent_op_test.cc (+0, −2)

@@ -26,8 +26,6 @@ namespace paddle {
 namespace operators {

 using namespace paddle::framework;
-// using framework::make_ddim;
-// using framework::DDim;

 class RecurrentGradientAlgorithmTest : public ::testing::Test {
  protected:
python/paddle/v2/framework/tests/test_add_two_op.py (+0, −8)

@@ -19,13 +19,5 @@ class TestAddOp(unittest.TestCase):
         self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']}

-
-#class TestAddGradOp(unittest.TestCase):
-#    def test_add_grad(self):
-#        op = Operator('add_two', X="X", Y="Y", Out="Out")
-#        backward_op = core.Operator.backward(op, set())
-#        self.assertEqual(backward_op.type(), "add_two_grad")
-#        expected = '''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).'''
-#        self.assertEqual(expected, str(backward_op))

 if __name__ == '__main__':
     unittest.main()