Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
81f5f861
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
81f5f861
编写于
8月 14, 2017
作者:
Y
Yu Yang
提交者:
GitHub
8月 14, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3322 from wangkuiyi/refactorize_framework_proto
Refactorize framework/*.proto
上级
8747d60d
64a4dfef
变更
51
隐藏空白更改
内联
并排
Showing
51 changed file
with
821 addition
and
1318 deletion
+821
-1318
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+6
-10
paddle/framework/attribute.cc
paddle/framework/attribute.cc
+1
-1
paddle/framework/attribute.h
paddle/framework/attribute.h
+2
-3
paddle/framework/backward.cc
paddle/framework/backward.cc
+41
-24
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+123
-82
paddle/framework/ddim.cc
paddle/framework/ddim.cc
+0
-1
paddle/framework/framework.proto
paddle/framework/framework.proto
+82
-0
paddle/framework/grad_op_builder.cc
paddle/framework/grad_op_builder.cc
+23
-82
paddle/framework/grad_op_builder_test.cc
paddle/framework/grad_op_builder_test.cc
+26
-30
paddle/framework/op_desc.proto
paddle/framework/op_desc.proto
+0
-56
paddle/framework/op_desc_test.cc
paddle/framework/op_desc_test.cc
+0
-35
paddle/framework/op_proto.proto
paddle/framework/op_proto.proto
+0
-116
paddle/framework/op_proto_test.cc
paddle/framework/op_proto_test.cc
+0
-31
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+43
-119
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+21
-24
paddle/framework/operator.cc
paddle/framework/operator.cc
+95
-57
paddle/framework/operator.h
paddle/framework/operator.h
+36
-69
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+29
-56
paddle/framework/pybind.cc
paddle/framework/pybind.cc
+11
-23
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+1
-1
paddle/operators/add_op.cc
paddle/operators/add_op.cc
+5
-8
paddle/operators/add_op.h
paddle/operators/add_op.h
+3
-3
paddle/operators/cross_entropy_op.cc
paddle/operators/cross_entropy_op.cc
+7
-14
paddle/operators/cross_entropy_op.h
paddle/operators/cross_entropy_op.h
+1
-1
paddle/operators/fill_zeros_like_op.cc
paddle/operators/fill_zeros_like_op.cc
+4
-11
paddle/operators/fill_zeros_like_op.h
paddle/operators/fill_zeros_like_op.h
+1
-1
paddle/operators/gaussian_random_op.cc
paddle/operators/gaussian_random_op.cc
+2
-1
paddle/operators/mean_op.cc
paddle/operators/mean_op.cc
+4
-6
paddle/operators/mean_op.h
paddle/operators/mean_op.h
+3
-3
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+5
-6
paddle/operators/mul_op.h
paddle/operators/mul_op.h
+4
-7
paddle/operators/net_op.cc
paddle/operators/net_op.cc
+35
-27
paddle/operators/net_op.h
paddle/operators/net_op.h
+5
-1
paddle/operators/net_op_test.cc
paddle/operators/net_op_test.cc
+14
-24
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+11
-6
paddle/operators/recurrent_op.h
paddle/operators/recurrent_op.h
+3
-1
paddle/operators/recurrent_op_test.cc
paddle/operators/recurrent_op_test.cc
+11
-157
paddle/operators/rowwise_add_op.cc
paddle/operators/rowwise_add_op.cc
+4
-6
paddle/operators/rowwise_add_op.h
paddle/operators/rowwise_add_op.h
+3
-3
paddle/operators/sgd_op.cc
paddle/operators/sgd_op.cc
+4
-8
paddle/operators/sigmoid_op.cc
paddle/operators/sigmoid_op.cc
+1
-3
paddle/operators/sigmoid_op.h
paddle/operators/sigmoid_op.h
+2
-2
paddle/operators/softmax_op.cc
paddle/operators/softmax_op.cc
+3
-11
paddle/operators/uniform_random_op.cc
paddle/operators/uniform_random_op.cc
+2
-2
paddle/operators/uniform_random_op.cu
paddle/operators/uniform_random_op.cu
+1
-1
python/paddle/v2/framework/op.py
python/paddle/v2/framework/op.py
+43
-84
python/paddle/v2/framework/tests/gradient_checker.py
python/paddle/v2/framework/tests/gradient_checker.py
+21
-13
python/paddle/v2/framework/tests/test_add_two_op.py
python/paddle/v2/framework/tests/test_add_two_op.py
+0
-9
python/paddle/v2/framework/tests/test_net.py
python/paddle/v2/framework/tests/test_net.py
+6
-6
python/paddle/v2/framework/tests/test_operator.py
python/paddle/v2/framework/tests/test_operator.py
+70
-69
python/paddle/v2/framework/tests/test_protobuf.py
python/paddle/v2/framework/tests/test_protobuf.py
+3
-4
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
81f5f861
...
@@ -15,23 +15,19 @@ cc_test(variable_test SRCS variable_test.cc)
...
@@ -15,23 +15,19 @@ cc_test(variable_test SRCS variable_test.cc)
cc_library
(
scope SRCS scope.cc
)
cc_library
(
scope SRCS scope.cc
)
cc_test
(
scope_test SRCS scope_test.cc DEPS scope
)
cc_test
(
scope_test SRCS scope_test.cc DEPS scope
)
proto_library
(
attribute_proto SRCS attribute.proto
)
proto_library
(
framework_proto SRCS framework.proto
)
proto_library
(
op_proto SRCS op_proto.proto DEPS attribute_proto
)
proto_library
(
op_desc SRCS op_desc.proto DEPS attribute_proto
)
cc_test
(
op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf
)
cc_test
(
op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf
)
cc_library
(
attribute SRCS attribute.cc DEPS
op_desc op
_proto
)
cc_library
(
attribute SRCS attribute.cc DEPS
framework
_proto
)
cc_library
(
operator SRCS operator.cc DEPS
op_desc
device_context tensor scope attribute
)
cc_library
(
operator SRCS operator.cc DEPS
framework_proto
device_context tensor scope attribute
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_library
(
grad_op_builder SRCS grad_op_builder.cc DEPS op
_proto op
erator
)
cc_library
(
grad_op_builder SRCS grad_op_builder.cc DEPS operator
)
cc_library
(
op_registry SRCS op_registry.cc DEPS
op_desc
grad_op_builder
)
cc_library
(
op_registry SRCS op_registry.cc DEPS grad_op_builder
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
cc_test
(
grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op
)
cc_test
(
grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op
)
py_proto_compile
(
framework_py_proto SRCS
attribute.proto op_proto.proto op_desc
.proto
)
py_proto_compile
(
framework_py_proto SRCS
framework
.proto
)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target
(
framework_py_proto_init ALL COMMAND
${
CMAKE_COMMAND
}
-E touch __init__.py
)
add_custom_target
(
framework_py_proto_init ALL COMMAND
${
CMAKE_COMMAND
}
-E touch __init__.py
)
add_dependencies
(
framework_py_proto framework_py_proto_init
)
add_dependencies
(
framework_py_proto framework_py_proto_init
)
...
...
paddle/framework/attribute.cc
浏览文件 @
81f5f861
...
@@ -44,7 +44,7 @@ AttrType AttrTypeID<std::vector<std::string>>() {
...
@@ -44,7 +44,7 @@ AttrType AttrTypeID<std::vector<std::string>>() {
return
STRINGS
;
return
STRINGS
;
}
}
Attribute
GetAttrValue
(
const
AttrDesc
&
attr_desc
)
{
Attribute
GetAttrValue
(
const
OpDesc
::
Attr
&
attr_desc
)
{
switch
(
attr_desc
.
type
())
{
switch
(
attr_desc
.
type
())
{
case
paddle
::
framework
::
AttrType
::
INT
:
{
case
paddle
::
framework
::
AttrType
::
INT
:
{
return
attr_desc
.
i
();
return
attr_desc
.
i
();
...
...
paddle/framework/attribute.h
浏览文件 @
81f5f861
...
@@ -20,8 +20,7 @@ limitations under the License. */
...
@@ -20,8 +20,7 @@ limitations under the License. */
#include <unordered_set>
#include <unordered_set>
#include <vector>
#include <vector>
#include "paddle/framework/attribute.pb.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/variant.h"
#include "paddle/platform/variant.h"
...
@@ -37,7 +36,7 @@ typedef std::unordered_map<std::string, Attribute> AttributeMap;
...
@@ -37,7 +36,7 @@ typedef std::unordered_map<std::string, Attribute> AttributeMap;
template
<
typename
T
>
template
<
typename
T
>
AttrType
AttrTypeID
();
AttrType
AttrTypeID
();
Attribute
GetAttrValue
(
const
AttrDesc
&
attr_desc
);
Attribute
GetAttrValue
(
const
OpDesc
::
Attr
&
attr_desc
);
// check whether a value(attribute) fit a certain limit
// check whether a value(attribute) fit a certain limit
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/framework/backward.cc
浏览文件 @
81f5f861
...
@@ -21,15 +21,24 @@
...
@@ -21,15 +21,24 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
static
bool
AllInSet
(
const
std
::
vector
<
std
::
string
>&
names
,
template
<
typename
Map
,
typename
T
>
const
std
::
string
&
suffix
,
static
void
ForEachVarName
(
Map
&
names
,
T
callback
)
{
const
std
::
unordered_set
<
std
::
string
>&
set
)
{
for
(
auto
&
name
:
names
)
{
for
(
auto
&
name
:
names
)
{
if
(
set
.
find
(
name
+
suffix
)
==
set
.
end
()
)
{
for
(
auto
&
n
:
name
.
second
)
{
return
false
;
if
(
callback
(
n
))
return
;
}
}
}
}
return
true
;
}
static
bool
AllInSet
(
const
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>&
names
,
const
std
::
string
&
suffix
,
const
std
::
unordered_set
<
std
::
string
>&
set
)
{
bool
all_in_set
=
true
;
ForEachVarName
(
names
,
[
&
all_in_set
,
&
set
,
&
suffix
](
const
std
::
string
&
n
)
{
all_in_set
=
set
.
find
(
n
+
suffix
)
!=
set
.
end
();
return
!
all_in_set
;
});
return
all_in_set
;
}
}
static
std
::
shared_ptr
<
OperatorBase
>
NOP
()
{
static
std
::
shared_ptr
<
OperatorBase
>
NOP
()
{
...
@@ -68,10 +77,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
...
@@ -68,10 +77,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// Then all input gradients cannot be computed at all, and we put them into
// Then all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP.
// `no_grad_names` set. Return an NOP.
if
(
AllInSet
(
forwardOp
.
outputs_
,
kGradVarSuffix
,
no_grad_names
))
{
if
(
AllInSet
(
forwardOp
.
outputs_
,
kGradVarSuffix
,
no_grad_names
))
{
for
(
auto
&
name
:
forwardOp
.
inputs_
)
{
ForEachVarName
(
forwardOp
.
inputs_
,
// Mark all input is not need
[
&
no_grad_names
](
const
std
::
string
&
name
)
->
bool
{
no_grad_names
.
insert
(
name
+
kGradVarSuffix
);
no_grad_names
.
insert
(
GradVarName
(
name
));
}
return
false
;
});
return
NOP
();
return
NOP
();
}
}
...
@@ -93,9 +103,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
...
@@ -93,9 +103,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
auto
fwd
=
*
it
;
auto
fwd
=
*
it
;
auto
bwd
=
BackwardRecursive
(
*
fwd
,
no_grad_names
,
uniq_id
);
auto
bwd
=
BackwardRecursive
(
*
fwd
,
no_grad_names
,
uniq_id
);
net
->
AddOp
(
bwd
);
net
->
AddOp
(
bwd
);
for
(
auto
&
out
:
bwd
->
outputs_
)
{
ForEachVarName
(
bwd
->
outputs_
,
dup_output_ops
[
out
].
emplace_back
(
local_op_id
);
[
&
dup_output_ops
,
local_op_id
](
const
std
::
string
&
out
)
{
}
dup_output_ops
[
out
].
emplace_back
(
local_op_id
);
return
false
;
});
}
}
// Get unique ID for this method.
// Get unique ID for this method.
auto
uid
=
uniq_id
++
;
auto
uid
=
uniq_id
++
;
...
@@ -117,7 +129,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
...
@@ -117,7 +129,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
insert_position
.
push_back
(
insert_position
.
push_back
(
{
dup_op
.
back
(),
{
dup_op
.
back
(),
OpRegistry
::
CreateOp
(
OpRegistry
::
CreateOp
(
"add"
,
{
dup_outputs
},
{
name
},
"add"
,
{
{
"X"
,
{
dup_outputs
}}},
{{
"Out"
,
{
name
}}
},
{{
"input_format"
,
{{
"input_format"
,
std
::
vector
<
int
>
{
0
,
static_cast
<
int
>
(
dup_outputs
.
size
())}}})});
std
::
vector
<
int
>
{
0
,
static_cast
<
int
>
(
dup_outputs
.
size
())}}})});
}
}
...
@@ -131,7 +143,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
...
@@ -131,7 +143,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
}
else
{
}
else
{
std
::
shared_ptr
<
OperatorBase
>
grad_op
=
OpRegistry
::
CreateGradOp
(
forwardOp
);
std
::
shared_ptr
<
OperatorBase
>
grad_op
=
OpRegistry
::
CreateGradOp
(
forwardOp
);
for
(
std
::
string
&
grad_input
:
grad_op
->
inputs_
)
{
ForEachVarName
(
grad_op
->
inputs_
,
[
&
no_grad_names
,
&
net
](
std
::
string
&
grad_input
)
{
if
(
no_grad_names
.
count
(
grad_input
))
{
if
(
no_grad_names
.
count
(
grad_input
))
{
// +1 for \0
// +1 for \0
std
::
string
prefix
=
grad_input
.
substr
(
std
::
string
prefix
=
grad_input
.
substr
(
...
@@ -140,16 +154,19 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
...
@@ -140,16 +154,19 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// If part of input gradient of that operator is not calculated, fill
// If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient.
// zero variables to that input gradient.
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"fill_zeros_like"
,
{
prefix
},
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"fill_zeros_like"
,
{
{
"Src"
,
{
prefix
}}
},
{
grad_input
},
{}));
{
{
"Dst"
,
{
grad_input
}}
},
{}));
}
}
}
return
false
;
});
for
(
std
::
string
&
grad_output
:
grad_op
->
outputs_
)
{
if
(
no_grad_names
.
count
(
grad_output
))
{
ForEachVarName
(
grad_op
->
outputs_
,
grad_output
=
kEmptyVarName
;
[
&
no_grad_names
](
std
::
string
&
grad_output
)
{
}
if
(
no_grad_names
.
count
(
grad_output
))
{
}
grad_output
=
kEmptyVarName
;
}
return
false
;
});
if
(
net
->
ops_
.
empty
())
{
// Current no aux op is added to network
if
(
net
->
ops_
.
empty
())
{
// Current no aux op is added to network
return
grad_op
;
return
grad_op
;
...
...
paddle/framework/backward_test.cc
浏览文件 @
81f5f861
...
@@ -30,8 +30,7 @@ using DeviceContext = platform::DeviceContext;
...
@@ -30,8 +30,7 @@ using DeviceContext = platform::DeviceContext;
class
EmptyOp
:
public
OperatorBase
{
class
EmptyOp
:
public
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
EmptyOp
,
OperatorBase
)
DEFINE_OPERATOR_CTOR
(
EmptyOp
,
OperatorBase
);
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
const
DeviceContext
&
dev_ctx
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
const
DeviceContext
&
dev_ctx
)
const
override
{}
};
};
...
@@ -40,9 +39,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
...
@@ -40,9 +39,9 @@ class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public:
public:
RowWiseAddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
RowWiseAddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"Input X of Add"
).
Ignore
Gradient
();
AddInput
(
"X"
,
"Input X of Add"
).
AsNo
Gradient
();
AddInput
(
"b"
,
"Bias of Add"
).
Ignore
Gradient
();
AddInput
(
"b"
,
"Bias of Add"
).
AsNo
Gradient
();
AddOutput
(
"Out"
,
"Out of Add"
).
Ignore
Gradient
();
AddOutput
(
"Out"
,
"Out of Add"
).
AsNo
Gradient
();
AddComment
(
"Add Op"
);
AddComment
(
"Add Op"
);
}
}
};
};
...
@@ -51,8 +50,8 @@ class MulOpMaker : public OpProtoAndCheckerMaker {
...
@@ -51,8 +50,8 @@ class MulOpMaker : public OpProtoAndCheckerMaker {
public:
public:
MulOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
MulOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"
A
"
,
"A"
);
AddInput
(
"
X
"
,
"A"
);
AddInput
(
"
B
"
,
"B"
);
AddInput
(
"
Y
"
,
"B"
);
AddOutput
(
"Out"
,
"Out"
);
AddOutput
(
"Out"
,
"Out"
);
AddComment
(
"Mul"
);
AddComment
(
"Mul"
);
}
}
...
@@ -63,7 +62,7 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
...
@@ -63,7 +62,7 @@ class SigmoidOpMaker : public OpProtoAndCheckerMaker {
SigmoidOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
SigmoidOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"X"
);
AddInput
(
"X"
,
"X"
);
AddOutput
(
"
Y
"
,
"Y"
);
AddOutput
(
"
Out
"
,
"Y"
);
AddComment
(
"Sigmoid"
);
AddComment
(
"Sigmoid"
);
}
}
};
};
...
@@ -73,21 +72,24 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
...
@@ -73,21 +72,24 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
NoGradOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
NoGradOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"X input"
);
AddInput
(
"X"
,
"X input"
);
AddOutput
(
"
Y
"
,
"Y output"
);
AddOutput
(
"
Out
"
,
"Y output"
);
AddComment
(
"NoGradOp, same input output. no Grad"
);
AddComment
(
"NoGradOp, same input output. no Grad"
);
}
}
};
};
class
FcOp
:
public
operators
::
NetOp
{
class
FcOp
:
public
operators
::
NetOp
{
public:
public:
DEFINE_OPERATOR_CTOR
(
FcOp
,
operators
::
NetOp
)
void
Init
()
override
{
void
Init
()
override
{
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{
Input
(
"X"
),
Input
(
"W"
)},
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{
Output
(
"mul_result"
)},
{}));
{{
"X"
,
{
Input
(
"X"
)}},
{
"Y"
,
{
Input
(
"W"
)}}},
auto
b_name
=
Input
(
"b"
);
{{
"Out"
,
{
Output
(
"mul_result"
)}}},
{}));
auto
input_b
=
Inputs
(
"b"
);
std
::
string
before_act
=
"mul_result"
;
std
::
string
before_act
=
"mul_result"
;
if
(
b_name
!=
kEmptyVarName
)
{
if
(
input_b
.
size
()
!=
0
)
{
AddOp
(
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
Output
(
"mul_result"
),
b_name
},
AddOp
(
OpRegistry
::
CreateOp
(
{
Output
(
"add_result"
)},
{}));
"rowwise_add"
,
{{
"X"
,
{
Output
(
"mul_result"
)}},
{
"b"
,
{
input_b
[
0
]}}},
{{
"Out"
,
{
Output
(
"add_result"
)}}},
{}));
before_act
=
"add_result"
;
before_act
=
"add_result"
;
}
else
{
}
else
{
auto
out_varname
=
Output
(
"add_result"
);
auto
out_varname
=
Output
(
"add_result"
);
...
@@ -96,8 +98,8 @@ class FcOp : public operators::NetOp {
...
@@ -96,8 +98,8 @@ class FcOp : public operators::NetOp {
}
}
}
}
AddOp
(
OpRegistry
::
CreateOp
(
"sigmoid"
,
{
Output
(
before_act
)},
{
Output
(
"Out"
)
},
AddOp
(
OpRegistry
::
CreateOp
(
"sigmoid"
,
{
{
"X"
,
{
Output
(
before_act
)}}
},
{}));
{
{
"Out"
,
{
Output
(
"Out"
)}}},
{
}));
CompleteAddOp
(
false
);
CompleteAddOp
(
false
);
}
}
};
};
...
@@ -109,8 +111,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
...
@@ -109,8 +111,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
AddInput
(
"X"
,
"x"
);
AddInput
(
"X"
,
"x"
);
AddInput
(
"W"
,
"w"
);
AddInput
(
"W"
,
"w"
);
AddInput
(
"b"
,
"b"
);
AddInput
(
"b"
,
"b"
);
AddOutput
(
"mul_result"
,
""
).
SetTemporary
();
AddOutput
(
"mul_result"
,
""
).
AsIntermediate
();
AddOutput
(
"add_result"
,
""
).
SetTemporary
();
AddOutput
(
"add_result"
,
""
).
AsIntermediate
();
AddOutput
(
"Out"
,
""
);
AddOutput
(
"Out"
,
""
);
AddComment
(
""
);
AddComment
(
""
);
}
}
...
@@ -141,7 +143,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
...
@@ -141,7 +143,7 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
public:
public:
AddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
AddOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"x"
).
SetMultip
le
();
AddInput
(
"X"
,
"x"
).
AsDuplicab
le
();
AddOutput
(
"Y"
,
"y"
);
AddOutput
(
"Y"
,
"y"
);
AddComment
(
""
);
AddComment
(
""
);
}
}
...
@@ -167,27 +169,24 @@ REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
...
@@ -167,27 +169,24 @@ REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
REGISTER_GRADIENT_OP
(
many_output_op
,
many_output_op_grad
,
f
::
EmptyOp
);
REGISTER_GRADIENT_OP
(
many_output_op
,
many_output_op_grad
,
f
::
EmptyOp
);
TEST
(
Backward
,
simple_op_grad
)
{
TEST
(
Backward
,
simple_op_grad
)
{
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
"X"
,
"b"
},
{
"Out"
},
{});
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{{
"X"
,
{
"x"
}},
{
"b"
,
{
"b"
}}},
{{
"Out"
,
{
"out"
}}},
{});
ASSERT_NE
(
fwd
,
nullptr
);
ASSERT_NE
(
fwd
,
nullptr
);
auto
gop
=
f
::
OpRegistry
::
CreateGradOp
(
*
fwd
);
auto
gop
=
f
::
OpRegistry
::
CreateGradOp
(
*
fwd
);
ASSERT_EQ
(
4UL
,
gop
->
inputs_
.
size
());
ASSERT_EQ
(
1UL
,
gop
->
inputs_
.
size
());
ASSERT_EQ
(
f
::
kEmptyVarName
,
gop
->
inputs_
[
0
]);
ASSERT_EQ
(
"rowwise_add_grad"
,
gop
->
type_
);
ASSERT_EQ
(
"rowwise_add_grad"
,
gop
->
type_
);
ASSERT_EQ
(
f
::
GradVarName
(
"X"
),
gop
->
outputs_
[
0
]);
ASSERT_EQ
(
f
::
GradVarName
(
"x"
),
gop
->
Output
(
f
::
GradVarName
(
"X"
)));
ASSERT_EQ
(
f
::
GradVarName
(
"b"
),
gop
->
outputs_
[
1
]);
ASSERT_EQ
(
f
::
GradVarName
(
"b"
),
gop
->
Output
(
f
::
GradVarName
(
"b"
)));
ASSERT_EQ
(
f
::
GradVarName
(
"X"
),
gop
->
Output
(
f
::
GradVarName
(
"X"
)));
}
}
TEST
(
Backward
,
simple_op_not_need_grad
)
{
TEST
(
Backward
,
simple_op_not_need_grad
)
{
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
"X"
,
"b"
},
{
"Out"
},
{});
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{{
"X"
,
{
"x"
}},
{
"b"
,
{
"b"
}}},
{{
"Out"
,
{
"out"
}}},
{});
ASSERT_NE
(
fwd
,
nullptr
);
ASSERT_NE
(
fwd
,
nullptr
);
auto
gop
=
f
::
Backward
(
*
fwd
,
{
"X"
});
auto
gop
=
f
::
Backward
(
*
fwd
,
{
"x"
});
ASSERT_EQ
(
std
::
find
(
gop
->
outputs_
.
begin
(),
gop
->
outputs_
.
end
(),
ASSERT_EQ
(
gop
->
Output
(
f
::
GradVarName
(
"X"
)),
f
::
kEmptyVarName
);
f
::
GradVarName
(
"X"
)),
gop
->
outputs_
.
end
());
auto
no_input_gop
=
f
::
Backward
(
*
fwd
,
{
"
X
"
,
"b"
});
auto
no_input_gop
=
f
::
Backward
(
*
fwd
,
{
"
x
"
,
"b"
});
ASSERT_NE
(
no_input_gop
,
nullptr
);
ASSERT_NE
(
no_input_gop
,
nullptr
);
ASSERT_TRUE
(
no_input_gop
->
IsNetOp
());
ASSERT_TRUE
(
no_input_gop
->
IsNetOp
());
ASSERT_EQ
(
0UL
,
ASSERT_EQ
(
0UL
,
...
@@ -195,8 +194,12 @@ TEST(Backward, simple_op_not_need_grad) {
...
@@ -195,8 +194,12 @@ TEST(Backward, simple_op_not_need_grad) {
}
}
TEST
(
Backward
,
net_fc_backward_normal
)
{
TEST
(
Backward
,
net_fc_backward_normal
)
{
std
::
shared_ptr
<
f
::
OperatorBase
>
fwd
=
f
::
OpRegistry
::
CreateOp
(
std
::
shared_ptr
<
f
::
OperatorBase
>
fwd
=
"fc"
,
{
"X"
,
"w"
,
"b"
},
{
"mul_result"
,
"add_result"
,
"out"
},
{});
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w"
}},
{
"b"
,
{
"b"
}}},
{{
"mul_result"
,
{
"mul_res"
}},
{
"add_result"
,
{
"add_re"
}},
{
"Out"
,
{
"out"
}}},
{});
ASSERT_NE
(
fwd
,
nullptr
);
ASSERT_NE
(
fwd
,
nullptr
);
std
::
shared_ptr
<
f
::
OperatorBase
>
gop
=
f
::
Backward
(
*
fwd
,
{});
std
::
shared_ptr
<
f
::
OperatorBase
>
gop
=
f
::
Backward
(
*
fwd
,
{});
ASSERT_TRUE
(
gop
->
IsNetOp
());
ASSERT_TRUE
(
gop
->
IsNetOp
());
...
@@ -218,8 +221,11 @@ TEST(Backward, net_fc_backward_normal) {
...
@@ -218,8 +221,11 @@ TEST(Backward, net_fc_backward_normal) {
TEST
(
Backward
,
net_fc_backward_not_have_b
)
{
TEST
(
Backward
,
net_fc_backward_not_have_b
)
{
std
::
shared_ptr
<
f
::
OperatorBase
>
fwd
=
std
::
shared_ptr
<
f
::
OperatorBase
>
fwd
=
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"X"
,
"w"
,
f
::
kEmptyVarName
},
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w"
}},
{
"b"
,
{}}},
{
"mul_result"
,
"add_result"
,
"tmp"
},
{});
{{
"mul_result"
,
{
"mul_res"
}},
{
"add_result"
,
{
"add_res"
}},
{
"Out"
,
{
"tmp"
}}},
{});
ASSERT_NE
(
fwd
,
nullptr
);
ASSERT_NE
(
fwd
,
nullptr
);
std
::
shared_ptr
<
f
::
OperatorBase
>
gop
=
f
::
Backward
(
*
fwd
,
{});
std
::
shared_ptr
<
f
::
OperatorBase
>
gop
=
f
::
Backward
(
*
fwd
,
{});
ASSERT_TRUE
(
gop
->
IsNetOp
());
ASSERT_TRUE
(
gop
->
IsNetOp
());
...
@@ -238,38 +244,49 @@ TEST(Backward, net_fc_backward_not_have_b) {
...
@@ -238,38 +244,49 @@ TEST(Backward, net_fc_backward_not_have_b) {
TEST
(
Backward
,
net_input_of_network_not_need_grad
)
{
TEST
(
Backward
,
net_input_of_network_not_need_grad
)
{
ops
::
NetOp
net
;
ops
::
NetOp
net
;
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"X"
,
"W1"
,
"b1"
},
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
{
"mul_tmp_0"
,
"add_tmp_0"
,
"hidden0"
},
{}));
"fc"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"W1"
}},
{
"b"
,
{
"b1"
}}},
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"hidden0"
,
"W2"
,
"b2"
},
{{
"mul_result"
,
{
"mul_tmp_0"
}},
{
"mul_tmp_1"
,
"add_tmp_1"
,
"hidden1"
},
{}));
{
"add_result"
,
{
"add_tmp_0"
}},
{
"Out"
,
{
"hidden0"
}}},
{}));
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"hidden0"
}},
{
"W"
,
{
"W2"
}},
{
"b"
,
{
"b2"
}}},
{{
"mul_result"
,
{
"mul_tmp_1"
}},
{
"add_result"
,
{
"add_tmp_1"
}},
{
"Out"
,
{
"hidden1"
}}},
{}));
net
.
CompleteAddOp
();
net
.
CompleteAddOp
();
auto
bwd
=
Backward
(
net
,
{
"
X"
});
// X
@GRAD is not need.
auto
bwd
=
Backward
(
net
,
{
"
x"
});
// x
@GRAD is not need.
ASSERT_TRUE
(
bwd
->
IsNetOp
());
ASSERT_TRUE
(
bwd
->
IsNetOp
());
auto
bwd_net
=
static_cast
<
ops
::
NetOp
*>
(
bwd
.
get
());
auto
bwd_net
=
static_cast
<
ops
::
NetOp
*>
(
bwd
.
get
());
std
::
unordered_set
<
std
::
string
>
all_output
=
std
::
unordered_set
<
std
::
string
>
(
auto
output_vars
=
bwd_net
->
OutputVars
(
true
);
bwd_net
->
outputs_
.
begin
(),
bwd_net
->
outputs_
.
end
());
std
::
unordered_set
<
std
::
string
>
all_outputs
=
all_output
.
erase
(
f
::
kEmptyVarName
);
std
::
unordered_set
<
std
::
string
>
(
output_vars
.
begin
(),
output_vars
.
end
());
all_outputs
.
erase
(
f
::
kEmptyVarName
);
for
(
auto
&
out
:
{
"W1"
,
"b1"
,
"hidden0"
,
"W2"
,
"b2"
})
{
for
(
auto
&
out
:
{
"W1"
,
"b1"
,
"hidden0"
,
"W2"
,
"b2"
})
{
ASSERT_NE
(
all_output
.
find
(
f
::
GradVarName
(
out
)),
all_output
.
end
());
ASSERT_NE
(
all_output
s
.
find
(
f
::
GradVarName
(
out
)),
all_outputs
.
end
());
}
}
// Not Generated X
// Not Generated X
ASSERT_EQ
(
all_output
.
find
(
f
::
GradVarName
(
"X"
)),
all_output
.
end
());
ASSERT_EQ
(
all_output
s
.
find
(
f
::
GradVarName
(
"X"
)),
all_outputs
.
end
());
ASSERT_EQ
(
2UL
,
bwd_net
->
ops_
.
size
());
ASSERT_EQ
(
2UL
,
bwd_net
->
ops_
.
size
());
ASSERT_TRUE
(
bwd_net
->
ops_
[
1
]
->
IsNetOp
());
ASSERT_TRUE
(
bwd_net
->
ops_
[
1
]
->
IsNetOp
());
auto
first_fc_grad
=
static_cast
<
ops
::
NetOp
*>
(
bwd_net
->
ops_
[
1
].
get
());
auto
first_fc_grad
=
static_cast
<
ops
::
NetOp
*>
(
bwd_net
->
ops_
[
1
].
get
());
ASSERT_EQ
(
3UL
,
first_fc_grad
->
ops_
.
size
());
ASSERT_EQ
(
3UL
,
first_fc_grad
->
ops_
.
size
());
ASSERT_EQ
(
f
::
kEmptyVarName
,
ASSERT_EQ
(
f
::
kEmptyVarName
,
first_fc_grad
->
ops_
[
2
]
->
Output
(
f
::
GradVarName
(
"
A
"
)));
first_fc_grad
->
ops_
[
2
]
->
Output
(
f
::
GradVarName
(
"
X
"
)));
}
}
TEST
(
Backward
,
net_shared_weight
)
{
TEST
(
Backward
,
net_shared_weight
)
{
ops
::
NetOp
net
;
ops
::
NetOp
net
;
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{
"X"
,
"W"
},
{
"Out"
},
{}));
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
"x"
}},
{
"Y"
,
{
"w"
}}},
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{
"Out"
,
"W"
},
{
"FinalOut"
},
{}));
{{
"Out"
,
{
"out"
}}},
{}));
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
"out"
}},
{
"Y"
,
{
"w"
}}},
{{
"Out"
,
{
"FinalOut"
}}},
{}));
net
.
CompleteAddOp
();
net
.
CompleteAddOp
();
auto
bwd
=
f
::
Backward
(
net
,
{});
auto
bwd
=
f
::
Backward
(
net
,
{});
...
@@ -280,31 +297,37 @@ TEST(Backward, net_shared_weight) {
...
@@ -280,31 +297,37 @@ TEST(Backward, net_shared_weight) {
}
}
TEST
(
Backward
,
op_register_grad_not_for_network
)
{
TEST
(
Backward
,
op_register_grad_not_for_network
)
{
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
auto
fwd
=
"fc"
,
{
"X"
,
"W"
,
"b"
},
{
"mul_out"
,
"add_out"
,
"out1"
},
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"x"
}},
{
"W"
,
{
"w"
}},
{
"b"
,
{
"b"
}}},
{{
"temporary_index"
,
std
::
vector
<
int
>
{
0
,
1
}}});
{{
"mul_result"
,
{
"mul_out"
}},
{
"add_result"
,
{
"add_out"
}},
{
"Out"
,
{
"out1"
}}},
{{
"temporary_index"
,
std
::
vector
<
int
>
{
0
,
1
}}});
ASSERT_THROW
(
f
::
OpRegistry
::
CreateGradOp
(
*
fwd
),
EnforceNotMet
);
ASSERT_THROW
(
f
::
OpRegistry
::
CreateGradOp
(
*
fwd
),
EnforceNotMet
);
}
}
TEST
(
Backward
,
op_all_input_are_not_need
)
{
TEST
(
Backward
,
op_all_input_are_not_need
)
{
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
"X"
,
"b"
},
{
"Out"
},
{});
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"X"
,
"b"
});
"rowwise_add"
,
{{
"X"
,
{
"x"
}},
{
"b"
,
{
"b"
}}},
{{
"Out"
,
{
"out"
}}},
{});
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"x"
,
"b"
});
ASSERT_TRUE
(
backward
->
IsNetOp
());
ASSERT_TRUE
(
backward
->
IsNetOp
());
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
ASSERT_TRUE
(
net
->
ops_
.
empty
());
ASSERT_TRUE
(
net
->
ops_
.
empty
());
}
}
TEST
(
Backward
,
op_all_output_are_not_need
)
{
TEST
(
Backward
,
op_all_output_are_not_need
)
{
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"rowwise_add"
,
{
"X"
,
"b"
},
{
"Out"
},
{});
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"Out"
});
"rowwise_add"
,
{{
"X"
,
{
"x"
}},
{
"b"
,
{
"b"
}}},
{{
"Out"
,
{
"out"
}}},
{});
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"out"
});
ASSERT_TRUE
(
backward
->
IsNetOp
());
ASSERT_TRUE
(
backward
->
IsNetOp
());
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
ASSERT_TRUE
(
net
->
ops_
.
empty
());
ASSERT_TRUE
(
net
->
ops_
.
empty
());
}
}
TEST
(
Backward
,
op_part_of_output_are_not_need
)
{
TEST
(
Backward
,
op_part_of_output_are_not_need
)
{
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"many_output_op"
,
{
"X"
},
{
"Y"
,
"Z"
},
{});
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"many_output_op"
,
{{
"x"
,
{
"X"
}}},
{{
"y"
,
{
"Y"
}},
{
"z"
,
{
"Z"
}}},
{});
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"Z"
});
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"Z"
});
ASSERT_TRUE
(
backward
->
IsNetOp
());
ASSERT_TRUE
(
backward
->
IsNetOp
());
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
auto
net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
...
@@ -312,10 +335,10 @@ TEST(Backward, op_part_of_output_are_not_need) {
...
@@ -312,10 +335,10 @@ TEST(Backward, op_part_of_output_are_not_need) {
auto
&
fill_zero
=
*
net
->
ops_
[
0
];
auto
&
fill_zero
=
*
net
->
ops_
[
0
];
ASSERT_EQ
(
"fill_zeros_like"
,
fill_zero
.
type_
);
ASSERT_EQ
(
"fill_zeros_like"
,
fill_zero
.
type_
);
ASSERT_EQ
(
1UL
,
fill_zero
.
inputs_
.
size
());
ASSERT_EQ
(
1UL
,
fill_zero
.
Inputs
(
"Src"
)
.
size
());
ASSERT_EQ
(
"Z"
,
fill_zero
.
inputs_
[
0
]
);
ASSERT_EQ
(
"Z"
,
fill_zero
.
Input
(
"Src"
)
);
ASSERT_EQ
(
1UL
,
fill_zero
.
outputs_
.
size
());
ASSERT_EQ
(
1UL
,
fill_zero
.
Outputs
(
"Dst"
)
.
size
());
ASSERT_EQ
(
std
::
string
(
"Z"
)
+
f
::
kZeroVarSuffix
,
fill_zero
.
outputs_
[
0
]
);
ASSERT_EQ
(
std
::
string
(
"Z"
)
+
f
::
kZeroVarSuffix
,
fill_zero
.
Output
(
"Dst"
)
);
auto
&
d_many_out
=
*
net
->
ops_
[
1
];
auto
&
d_many_out
=
*
net
->
ops_
[
1
];
ASSERT_EQ
(
"many_output_op_grad"
,
d_many_out
.
type_
);
ASSERT_EQ
(
"many_output_op_grad"
,
d_many_out
.
type_
);
...
@@ -327,44 +350,62 @@ TEST(Backward, op_part_of_output_are_not_need) {
...
@@ -327,44 +350,62 @@ TEST(Backward, op_part_of_output_are_not_need) {
}
}
TEST
(
Backward
,
op_part_of_input_are_not_need
)
{
TEST
(
Backward
,
op_part_of_input_are_not_need
)
{
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{
"a"
,
"b"
},
{
"out"
},
{});
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
"a"
}},
{
"Y"
,
{
"b"
}}},
{{
"Out"
,
{
"out"
}}},
{});
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"a"
});
auto
backward
=
f
::
Backward
(
*
fwd
,
{
"a"
});
auto
&
grad_mul
=
*
backward
;
auto
&
grad_mul
=
*
backward
;
ASSERT_EQ
(
grad_mul
.
type_
,
"mul_grad"
);
ASSERT_EQ
(
grad_mul
.
type_
,
"mul_grad"
);
ASSERT_EQ
(
grad_mul
.
inputs_
.
size
(),
2UL
+
1UL
+
1UL
);
ASSERT_EQ
(
grad_mul
.
inputs_
.
size
(),
2UL
+
1UL
+
1UL
);
ASSERT_EQ
(
grad_mul
.
outputs_
.
size
(),
2UL
);
ASSERT_EQ
(
grad_mul
.
outputs_
.
size
(),
2UL
);
ASSERT_EQ
(
grad_mul
.
Output
(
f
::
GradVarName
(
"
A
"
)),
f
::
kEmptyVarName
);
ASSERT_EQ
(
grad_mul
.
Output
(
f
::
GradVarName
(
"
X
"
)),
f
::
kEmptyVarName
);
ASSERT_EQ
(
grad_mul
.
Output
(
f
::
GradVarName
(
"
B
"
)),
f
::
GradVarName
(
"b"
));
ASSERT_EQ
(
grad_mul
.
Output
(
f
::
GradVarName
(
"
Y
"
)),
f
::
GradVarName
(
"b"
));
ASSERT_EQ
(
grad_mul
.
Input
(
f
::
GradVarName
(
"Out"
)),
f
::
GradVarName
(
"out"
));
ASSERT_EQ
(
grad_mul
.
Input
(
f
::
GradVarName
(
"Out"
)),
f
::
GradVarName
(
"out"
));
ASSERT_EQ
(
grad_mul
.
Input
(
"
A
"
),
"a"
);
ASSERT_EQ
(
grad_mul
.
Input
(
"
X
"
),
"a"
);
ASSERT_EQ
(
grad_mul
.
Input
(
"
B
"
),
"b"
);
ASSERT_EQ
(
grad_mul
.
Input
(
"
Y
"
),
"b"
);
ASSERT_EQ
(
grad_mul
.
Input
(
"Out"
),
"out"
);
ASSERT_EQ
(
grad_mul
.
Input
(
"Out"
),
"out"
);
}
}
TEST
(
Backward
,
linear_net_intermediate_variable_has_no_grad
)
{
TEST
(
Backward
,
linear_net_intermediate_variable_has_no_grad
)
{
ops
::
NetOp
net
;
ops
::
NetOp
net
;
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"x1"
,
"w1"
,
"b1"
},
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
{
"mul_out1"
,
"add_out1"
,
"out1"
},
{}));
"fc"
,
{{
"X"
,
{
"x1"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}},
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"out1"
,
"w2"
,
"b2"
},
{{
"mul_result"
,
{
"mul_out1"
}},
{
"mul_out2"
,
"tmp_out2"
,
"out2"
},
{}));
{
"add_result"
,
{
"add_out1"
}},
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{
"out2"
,
"w3"
,
"b3"
},
{
"Out"
,
{
"out1"
}}},
{
"mul_out3"
,
"tmp_out3"
,
"out3"
},
{}));
{}));
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"out1"
}},
{
"W"
,
{
"w2"
}},
{
"b"
,
{
"b2"
}}},
{{
"mul_result"
,
{
"mul_out2"
}},
{
"add_result"
,
{
"tmp_out2"
}},
{
"Out"
,
{
"out2"
}}},
{}));
net
.
AddOp
(
f
::
OpRegistry
::
CreateOp
(
"fc"
,
{{
"X"
,
{
"out2"
}},
{
"W"
,
{
"w3"
}},
{
"b"
,
{
"b3"
}}},
{{
"mul_result"
,
{
"mul_out3"
}},
{
"add_result"
,
{
"tmp_out3"
}},
{
"Out"
,
{
"out3"
}}},
{}));
net
.
CompleteAddOp
();
net
.
CompleteAddOp
();
auto
backward
=
f
::
Backward
(
net
,
{
"mul_out2"
,
"tmp_out2"
,
"out2"
});
auto
backward
=
f
::
Backward
(
net
,
{
"mul_out2"
,
"tmp_out2"
,
"out2"
});
ASSERT_TRUE
(
backward
->
IsNetOp
());
ASSERT_TRUE
(
backward
->
IsNetOp
());
auto
bwd_net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
auto
bwd_net
=
static_cast
<
ops
::
NetOp
*>
(
backward
.
get
());
ASSERT_EQ
(
bwd_net
->
ops_
.
size
(),
3UL
);
ASSERT_EQ
(
bwd_net
->
ops_
.
size
(),
3UL
);
auto
&
grad_fc
=
*
bwd_net
->
ops_
[
0
];
auto
&
grad_fc
=
*
bwd_net
->
ops_
[
0
];
EXPECT_EQ
(
grad_fc
.
inputs_
.
size
(),
3UL
/* external input number */
const
char
*
all
=
paddle
::
operators
::
NetOp
::
kAll
;
EXPECT_EQ
(
grad_fc
.
inputs_
[
all
].
size
(),
2UL
/* external input number */
+
1UL
/* external output number*/
+
1UL
/* external output number*/
+
1UL
/* number of gradient of external output*/
+
1UL
/* number of gradient of external output*/
+
2U
/* internal variable number*/
);
+
2U
/* internal variable number*/
);
EXPECT_EQ
(
grad_fc
.
outputs_
.
size
(),
2UL
/* input number of mul*/
EXPECT_EQ
(
grad_fc
.
outputs_
[
all
].
size
(),
+
2UL
/* input number of rowwise_add */
2UL
/* input number of mul*/
+
1UL
/* input number of sigmod */
);
+
2UL
/* input number of rowwise_add
EXPECT_EQ
(
bwd_net
->
ops_
[
1
]
->
inputs_
.
size
(),
0UL
);
*/
EXPECT_EQ
(
bwd_net
->
ops_
[
1
]
->
outputs_
.
size
(),
0UL
);
+
1UL
/* input number of sigmod */
);
EXPECT_EQ
(
bwd_net
->
ops_
[
2
]
->
inputs_
.
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
1
]
->
inputs_
[
all
].
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
2
]
->
outputs_
.
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
1
]
->
outputs_
[
all
].
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
2
]
->
inputs_
[
all
].
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
2
]
->
outputs_
[
all
].
size
(),
0UL
);
}
}
paddle/framework/ddim.cc
浏览文件 @
81f5f861
...
@@ -283,6 +283,5 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
...
@@ -283,6 +283,5 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDim
::
DDim
(
std
::
initializer_list
<
int
>
init_list
)
{
DDim
::
DDim
(
std
::
initializer_list
<
int
>
init_list
)
{
*
this
=
make_ddim
(
init_list
);
*
this
=
make_ddim
(
init_list
);
}
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/
attribute
.proto
→
paddle/framework/
framework
.proto
浏览文件 @
81f5f861
...
@@ -15,9 +15,6 @@ limitations under the License. */
...
@@ -15,9 +15,6 @@ limitations under the License. */
syntax
=
"proto2"
;
syntax
=
"proto2"
;
package
paddle
.
framework
;
package
paddle
.
framework
;
// Attribute Type for paddle's Op.
// Op contains many attributes. Each type of attributes could be different.
// The AttrType will be shared between AttrDesc and AttrProto.
enum
AttrType
{
enum
AttrType
{
INT
=
0
;
INT
=
0
;
FLOAT
=
1
;
FLOAT
=
1
;
...
@@ -25,4 +22,61 @@ enum AttrType {
...
@@ -25,4 +22,61 @@ enum AttrType {
INTS
=
3
;
INTS
=
3
;
FLOATS
=
4
;
FLOATS
=
4
;
STRINGS
=
5
;
STRINGS
=
5
;
}
}
\ No newline at end of file
// OpDesc describes an instance of a C++ framework::OperatorBase
// derived class type.
message
OpDesc
{
message
Attr
{
required
string
name
=
1
;
required
AttrType
type
=
2
;
optional
int32
i
=
3
;
optional
float
f
=
4
;
optional
string
s
=
5
;
repeated
int32
ints
=
6
;
repeated
float
floats
=
7
;
repeated
string
strings
=
8
;
};
message
Var
{
required
string
parameter
=
1
;
repeated
string
arguments
=
2
;
};
required
string
type
=
3
;
repeated
Var
inputs
=
1
;
repeated
Var
outputs
=
2
;
repeated
Attr
attrs
=
4
;
};
// OpProto describes a C++ framework::OperatorBase derived class.
message
OpProto
{
// VarProto describes the C++ type framework::Variable.
message
Var
{
required
string
name
=
1
;
required
string
comment
=
2
;
optional
bool
duplicable
=
3
[
default
=
false
];
optional
bool
intermediate
=
4
[
default
=
false
];
optional
bool
no_gradient
=
5
[
default
=
false
];
}
// AttrProto describes the C++ type Attribute.
message
Attr
{
required
string
name
=
1
;
required
AttrType
type
=
2
;
required
string
comment
=
3
;
// If that attribute is generated, it means the Paddle third
// language binding has responsibility to fill that
// attribute. End-User should not set that attribute.
optional
bool
generated
=
4
[
default
=
false
];
}
required
string
type
=
1
;
repeated
Var
inputs
=
2
;
repeated
Var
outputs
=
3
;
repeated
Attr
attrs
=
4
;
required
string
comment
=
5
;
}
paddle/framework/grad_op_builder.cc
浏览文件 @
81f5f861
...
@@ -13,104 +13,45 @@ express or implied. See the License for the specific language governing
...
@@ -13,104 +13,45 @@ express or implied. See the License for the specific language governing
permissions and limitations under the License. */
permissions and limitations under the License. */
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/
op_proto
.pb.h"
#include "paddle/framework/
framework
.pb.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
typedef
std
::
vector
<
int
>
Ints
;
enum
class
OpArgType
{
IN
,
OUT
};
enum
class
OpArgType
{
IN
,
OUT
};
const
Ints
*
AttrFormat
(
const
AttributeMap
&
attrs
,
const
std
::
string
&
key
)
{
static
void
TransOpArg
(
const
OperatorBase
*
src_op
,
OperatorBase
*
dst_op
,
return
(
attrs
.
count
(
key
)
>
0
)
?
&
boost
::
get
<
Ints
>
(
attrs
.
at
(
key
))
:
nullptr
;
const
OpArgType
&
src_type
,
const
OpArgType
&
dst_type
,
}
bool
is_grad
)
{
const
auto
&
src_inout
=
Ints
*
AttrFormat
(
AttributeMap
&
attrs
,
const
std
::
string
&
key
)
{
src_type
==
OpArgType
::
IN
?
src_op
->
inputs_
:
src_op
->
outputs_
;
return
(
attrs
.
count
(
key
)
>
0
)
?
&
boost
::
get
<
Ints
>
(
attrs
.
at
(
key
))
:
nullptr
;
auto
&
dst_inout
=
}
dst_type
==
OpArgType
::
IN
?
dst_op
->
inputs_
:
dst_op
->
outputs_
;
static
void
TransOpArg
(
const
OperatorBase
*
src_op
,
std
::
vector
<
std
::
string
>&
grad_inputs
,
std
::
vector
<
std
::
string
>&
grad_outputs
,
AttributeMap
&
grad_attrs
,
std
::
unordered_map
<
std
::
string
,
int
>&
grad_idxs
,
const
std
::
string
&
src_type
,
const
std
::
string
&
dst_type
,
int
&
idx
,
bool
is_grad
)
{
const
std
::
vector
<
std
::
string
>&
src_inout
=
(
src_type
==
"input_format"
)
?
src_op
->
inputs_
:
src_op
->
outputs_
;
const
std
::
vector
<
int
>*
src_format
=
AttrFormat
(
src_op
->
Attrs
(),
src_type
);
std
::
vector
<
std
::
string
>&
dst_inout
=
(
dst_type
==
"input_format"
)
?
grad_inputs
:
grad_outputs
;
std
::
vector
<
int
>*
dst_format
=
AttrFormat
(
grad_attrs
,
dst_type
);
const
OpProto
&
proto
=
OpRegistry
::
protos
().
at
(
src_op
->
type_
);
const
OpProto
&
proto
=
OpProtos
().
at
(
src_op
->
type_
);
const
auto
&
src_arg_list
=
const
auto
&
src_arg_list
=
(
src_type
==
"input_format"
)
?
proto
.
inputs
()
:
proto
.
outputs
();
src_type
==
OpArgType
::
IN
?
proto
.
inputs
()
:
proto
.
outputs
();
for
(
const
auto
&
arg
:
src_arg_list
)
{
for
(
const
auto
&
arg
:
src_arg_list
)
{
std
::
string
src_name
=
arg
.
name
();
if
(
arg
.
no_gradient
()
&&
!
is_grad
)
continue
;
std
::
string
dst_name
=
is_grad
?
src_name
+
kGradVarSuffix
:
src_name
;
const
std
::
string
src_name
=
arg
.
name
();
grad_idxs
[
dst_name
]
=
idx
++
;
std
::
string
dst_name
=
is_grad
?
GradVarName
(
src_name
)
:
src_name
;
int
src_arg_idx
=
src_op
->
in_out_idxs_
->
at
(
src_name
);
dst_inout
[
dst_name
].
reserve
(
src_inout
.
at
(
src_name
).
size
());
int
src_begin
=
for
(
auto
&
var_name
:
src_inout
.
at
(
src_name
))
{
src_format
==
nullptr
?
src_arg_idx
:
src_format
->
at
(
src_arg_idx
);
std
::
string
s
=
is_grad
?
GradVarName
(
var_name
)
:
var_name
;
int
src_end
=
src_format
==
nullptr
?
src_arg_idx
+
1
dst_inout
[
dst_name
].
emplace_back
(
s
);
:
src_format
->
at
(
src_arg_idx
+
1
);
for
(
int
i
=
src_begin
;
i
<
src_end
;
++
i
)
{
std
::
string
s
=
is_grad
?
src_inout
[
i
]
+
kGradVarSuffix
:
(
arg
.
ignore_gradient
()
?
kEmptyVarName
:
src_inout
[
i
]);
dst_inout
.
emplace_back
(
s
);
}
if
(
dst_format
!=
nullptr
)
{
dst_format
->
push_back
(
dst_inout
.
size
());
}
}
}
}
}
}
OperatorBase
*
BuildGradOp
(
const
OperatorBase
*
op
)
{
OperatorBase
*
BuildGradOp
(
const
OperatorBase
*
op
)
{
const
std
::
string
&
grad_op_type
=
OpRegistry
::
grad_ops
().
at
(
op
->
Type
());
std
::
string
grad_op_type
=
OpRegistry
::
grad_ops
().
at
(
op
->
type_
);
AttributeMap
grad_attrs
(
op
->
Attrs
());
grad_attrs
.
erase
(
"input_format"
);
grad_attrs
.
erase
(
"output_format"
);
if
(
op
->
Attrs
().
count
(
"input_format"
)
>
0
)
{
grad_attrs
[
"output_format"
]
=
std
::
vector
<
int
>
({
0
});
}
if
(
op
->
Attrs
().
count
(
"input_format"
)
>
0
||
op
->
Attrs
().
count
(
"output_format"
)
>
0
)
{
grad_attrs
[
"input_format"
]
=
std
::
vector
<
int
>
({
0
});
}
std
::
vector
<
std
::
string
>
grad_inputs
,
grad_outputs
;
using
VarIndexMap
=
std
::
unordered_map
<
std
::
string
,
int
>
;
VarIndexMap
*
grad_idxs
=
new
VarIndexMap
;
int
in_idx
=
0
;
int
out_idx
=
0
;
TransOpArg
(
op
,
grad_inputs
,
grad_outputs
,
grad_attrs
,
*
grad_idxs
,
"input_format"
,
"input_format"
,
in_idx
,
false
);
// I
TransOpArg
(
op
,
grad_inputs
,
grad_outputs
,
grad_attrs
,
*
grad_idxs
,
"output_format"
,
"input_format"
,
in_idx
,
false
);
// G
TransOpArg
(
op
,
grad_inputs
,
grad_outputs
,
grad_attrs
,
*
grad_idxs
,
"output_format"
,
"input_format"
,
in_idx
,
true
);
// OG
TransOpArg
(
op
,
grad_inputs
,
grad_outputs
,
grad_attrs
,
*
grad_idxs
,
"input_format"
,
"output_format"
,
out_idx
,
true
);
// IG
OperatorBase
*
grad_op
=
OpRegistry
::
op_creators
().
at
(
grad_op_type
)();
OperatorBase
*
grad_op
=
OpRegistry
::
op_creators
().
at
(
grad_op_type
)();
grad_op
->
type_
=
grad_op_type
;
grad_op
->
type_
=
grad_op_type
;
grad_op
->
inputs_
=
grad_inputs
;
grad_op
->
attrs_
=
op
->
attrs_
;
grad_op
->
outputs_
=
grad_outputs
;
TransOpArg
(
op
,
grad_op
,
OpArgType
::
IN
,
OpArgType
::
IN
,
false
);
// I
grad_op
->
attrs_
=
grad_attrs
;
TransOpArg
(
op
,
grad_op
,
OpArgType
::
OUT
,
OpArgType
::
IN
,
false
);
// O
grad_op
->
in_out_idxs_
.
reset
(
grad_idxs
);
TransOpArg
(
op
,
grad_op
,
OpArgType
::
OUT
,
OpArgType
::
IN
,
true
);
// OG
TransOpArg
(
op
,
grad_op
,
OpArgType
::
IN
,
OpArgType
::
OUT
,
true
);
// IG
return
grad_op
;
return
grad_op
;
}
}
...
...
paddle/framework/grad_op_builder_test.cc
浏览文件 @
81f5f861
...
@@ -10,8 +10,7 @@ namespace framework {
...
@@ -10,8 +10,7 @@ namespace framework {
class
NOP
:
public
OperatorBase
{
class
NOP
:
public
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
NOP
,
OperatorBase
)
DEFINE_OPERATOR_CTOR
(
NOP
,
OperatorBase
);
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
...
@@ -22,10 +21,10 @@ class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
...
@@ -22,10 +21,10 @@ class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
MutiInOutOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
MutiInOutOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"In1"
,
"a single input"
);
AddInput
(
"In1"
,
"a single input"
);
AddInput
(
"In2_mult"
,
"a multiple input"
).
SetMultip
le
();
AddInput
(
"In2_mult"
,
"a multiple input"
).
AsDuplicab
le
();
AddInput
(
"In3"
,
"another single input"
);
AddInput
(
"In3"
,
"another single input"
);
AddOutput
(
"Out1"
,
"a single output"
);
AddOutput
(
"Out1"
,
"a single output"
);
AddOutput
(
"Out2_mult"
,
"a multiple output"
).
SetMultip
le
();
AddOutput
(
"Out2_mult"
,
"a multiple output"
).
AsDuplicab
le
();
AddComment
(
"test op with multiple inputs and outputs"
);
AddComment
(
"test op with multiple inputs and outputs"
);
}
}
};
};
...
@@ -35,10 +34,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
...
@@ -35,10 +34,10 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
IOIgnoredOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
IOIgnoredOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"In1"
,
"a single input"
);
AddInput
(
"In1"
,
"a single input"
);
AddInput
(
"In2_mult"
,
"a multiple input"
).
SetMultiple
().
Ignore
Gradient
();
AddInput
(
"In2_mult"
,
"a multiple input"
).
AsDuplicable
().
AsNo
Gradient
();
AddInput
(
"In3_mult"
,
"another multiple input"
).
SetMultip
le
();
AddInput
(
"In3_mult"
,
"another multiple input"
).
AsDuplicab
le
();
AddOutput
(
"Out1_mult"
,
"a multiple output"
).
SetMultip
le
();
AddOutput
(
"Out1_mult"
,
"a multiple output"
).
AsDuplicab
le
();
AddOutput
(
"Out2"
,
"a single output"
).
Ignore
Gradient
();
AddOutput
(
"Out2"
,
"a single output"
).
AsNo
Gradient
();
AddComment
(
"op with inputs and outputs ignored in gradient calculating"
);
AddComment
(
"op with inputs and outputs ignored in gradient calculating"
);
}
}
};
};
...
@@ -49,18 +48,18 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
...
@@ -49,18 +48,18 @@ class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
namespace
f
=
paddle
::
framework
;
namespace
f
=
paddle
::
framework
;
TEST
(
GradOpBuilder
,
AddTwo
)
{
TEST
(
GradOpBuilder
,
AddTwo
)
{
std
::
shared_ptr
<
f
::
OperatorBase
>
add_op
(
std
::
shared_ptr
<
f
::
OperatorBase
>
add_op
(
f
::
OpRegistry
::
CreateOp
(
f
::
OpRegistry
::
CreateOp
(
"add_two"
,
{
"x"
,
"y"
},
{
"out"
},
{}));
"add_two"
,
{{
"X"
,
{
"x"
}},
{
"Y"
,
{
"y"
}}},
{{
"Out"
,
{
"out"
}}
},
{}));
std
::
shared_ptr
<
f
::
OperatorBase
>
grad_add_op
=
std
::
shared_ptr
<
f
::
OperatorBase
>
grad_add_op
=
f
::
OpRegistry
::
CreateGradOp
(
*
add_op
);
f
::
OpRegistry
::
CreateGradOp
(
*
add_op
);
EXPECT_EQ
(
static_cast
<
int
>
(
grad_add_op
->
inputs_
.
size
()),
4
);
EXPECT_EQ
(
grad_add_op
->
inputs_
.
size
(),
4UL
);
EXPECT_EQ
(
static_cast
<
int
>
(
grad_add_op
->
outputs_
.
size
()),
2
);
EXPECT_EQ
(
grad_add_op
->
outputs_
.
size
(),
2UL
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"X"
),
"x"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"X"
),
"x"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"Y"
),
"y"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"Y"
),
"y"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"Out"
),
"out"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"Out"
),
"out"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
"Out@GRAD"
),
"out@GRAD"
);
EXPECT_EQ
(
grad_add_op
->
Input
(
f
::
GradVarName
(
"Out"
)),
f
::
GradVarName
(
"out"
)
);
EXPECT_EQ
(
grad_add_op
->
Output
(
"X@GRAD"
),
"x@GRAD"
);
EXPECT_EQ
(
grad_add_op
->
Output
(
f
::
GradVarName
(
"X"
)),
f
::
GradVarName
(
"x"
)
);
EXPECT_EQ
(
grad_add_op
->
Output
(
"Y@GRAD"
),
"y@GRAD"
);
EXPECT_EQ
(
grad_add_op
->
Output
(
f
::
GradVarName
(
"Y"
)),
f
::
GradVarName
(
"y"
)
);
}
}
REGISTER_OP
(
mult_io
,
f
::
NOP
,
f
::
MutiInOutOpMaker
);
REGISTER_OP
(
mult_io
,
f
::
NOP
,
f
::
MutiInOutOpMaker
);
...
@@ -69,15 +68,15 @@ REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker);
...
@@ -69,15 +68,15 @@ REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker);
REGISTER_GRADIENT_OP
(
io_ignored
,
io_ignored_grad
,
f
::
NOP
);
REGISTER_GRADIENT_OP
(
io_ignored
,
io_ignored_grad
,
f
::
NOP
);
TEST
(
GradOpBuilder
,
MutiInOut
)
{
TEST
(
GradOpBuilder
,
MutiInOut
)
{
f
::
AttributeMap
attrs
{{
"input_format"
,
std
::
vector
<
int
>
{
0
,
1
,
4
,
5
}},
{
"output_format"
,
std
::
vector
<
int
>
{
0
,
1
,
3
}}};
std
::
shared_ptr
<
f
::
OperatorBase
>
test_op
(
f
::
OpRegistry
::
CreateOp
(
std
::
shared_ptr
<
f
::
OperatorBase
>
test_op
(
f
::
OpRegistry
::
CreateOp
(
"mult_io"
,
{
"in1"
,
"in2_1"
,
"in2_2"
,
"in2_3"
,
"in3"
},
"mult_io"
,
{{
"In1"
,
{
"in1"
}},
{
"out1"
,
"out2_1"
,
"out2_2"
},
attrs
));
{
"In2_mult"
,
{
"in2_1"
,
"in2_2"
,
"in2_3"
}},
{
"In3"
,
{
"in3"
}}},
{{
"Out1"
,
{
"out1"
}},
{
"Out2_mult"
,
{
"out2_1"
,
"out2_2"
}}},
{}));
std
::
shared_ptr
<
f
::
OperatorBase
>
grad_test_op
=
std
::
shared_ptr
<
f
::
OperatorBase
>
grad_test_op
=
f
::
OpRegistry
::
CreateGradOp
(
*
test_op
);
f
::
OpRegistry
::
CreateGradOp
(
*
test_op
);
ASSERT_EQ
(
grad_test_op
->
inputs_
.
size
(),
5UL
+
3UL
+
3
UL
);
ASSERT_EQ
(
grad_test_op
->
inputs_
.
size
(),
3UL
+
2UL
+
2
UL
);
EXPECT_EQ
(
grad_test_op
->
Input
(
"In1"
),
"in1"
);
EXPECT_EQ
(
grad_test_op
->
Input
(
"In1"
),
"in1"
);
EXPECT_EQ
(
grad_test_op
->
Inputs
(
"In2_mult"
),
EXPECT_EQ
(
grad_test_op
->
Inputs
(
"In2_mult"
),
std
::
vector
<
std
::
string
>
({
"in2_1"
,
"in2_2"
,
"in2_3"
}));
std
::
vector
<
std
::
string
>
({
"in2_1"
,
"in2_2"
,
"in2_3"
}));
...
@@ -91,7 +90,7 @@ TEST(GradOpBuilder, MutiInOut) {
...
@@ -91,7 +90,7 @@ TEST(GradOpBuilder, MutiInOut) {
std
::
vector
<
std
::
string
>
(
std
::
vector
<
std
::
string
>
(
{
f
::
GradVarName
(
"out2_1"
),
f
::
GradVarName
(
"out2_2"
)}));
{
f
::
GradVarName
(
"out2_1"
),
f
::
GradVarName
(
"out2_2"
)}));
ASSERT_EQ
(
grad_test_op
->
outputs_
.
size
(),
5
UL
);
ASSERT_EQ
(
grad_test_op
->
outputs_
.
size
(),
3
UL
);
EXPECT_EQ
(
grad_test_op
->
Output
(
f
::
GradVarName
(
"In1"
)),
f
::
GradVarName
(
"in1"
));
EXPECT_EQ
(
grad_test_op
->
Output
(
f
::
GradVarName
(
"In1"
)),
f
::
GradVarName
(
"in1"
));
EXPECT_EQ
(
grad_test_op
->
Outputs
(
f
::
GradVarName
(
"In2_mult"
)),
EXPECT_EQ
(
grad_test_op
->
Outputs
(
f
::
GradVarName
(
"In2_mult"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"in2_1"
),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"in2_1"
),
...
@@ -101,31 +100,28 @@ TEST(GradOpBuilder, MutiInOut) {
...
@@ -101,31 +100,28 @@ TEST(GradOpBuilder, MutiInOut) {
}
}
TEST
(
GradOpBuilder
,
IOIgnoredInGradient
)
{
TEST
(
GradOpBuilder
,
IOIgnoredInGradient
)
{
f
::
AttributeMap
attrs
{{
"input_format"
,
std
::
vector
<
int
>
{
0
,
1
,
3
,
5
}},
{
"output_format"
,
std
::
vector
<
int
>
{
0
,
2
,
3
}}};
std
::
shared_ptr
<
f
::
OperatorBase
>
test_op
(
f
::
OpRegistry
::
CreateOp
(
std
::
shared_ptr
<
f
::
OperatorBase
>
test_op
(
f
::
OpRegistry
::
CreateOp
(
"io_ignored"
,
{
"in1"
,
"in2_1"
,
"in2_2"
,
"in3_1"
,
"in3_2"
},
"io_ignored"
,
{{
"In1"
,
{
"in1"
}},
{
"out1_1"
,
"out1_2"
,
"out2"
},
attrs
));
{
"In2_mult"
,
{
"in2_1"
,
"in2_2"
}},
{
"In3_mult"
,
{
"in3_1"
,
"in3_2"
}}},
{{
"Out1_mult"
,
{
"out1_1"
,
"out1_2"
}},
{
"Out2"
,
{
"out2"
}}},
{}));
std
::
shared_ptr
<
f
::
OperatorBase
>
grad_test_op
=
std
::
shared_ptr
<
f
::
OperatorBase
>
grad_test_op
=
f
::
OpRegistry
::
CreateGradOp
(
*
test_op
);
f
::
OpRegistry
::
CreateGradOp
(
*
test_op
);
// 'In2' and 'Out2' are ignored in gradient calculating
// 'In2' and 'Out2' are ignored in gradient calculating
ASSERT_EQ
(
grad_test_op
->
inputs_
.
size
(),
5UL
+
3UL
+
3
UL
);
ASSERT_EQ
(
grad_test_op
->
inputs_
.
size
(),
2UL
+
1UL
+
2
UL
);
EXPECT_EQ
(
grad_test_op
->
Input
(
"In1"
),
"in1"
);
EXPECT_EQ
(
grad_test_op
->
Input
(
"In1"
),
"in1"
);
EXPECT_EQ
(
grad_test_op
->
Inputs
(
"In2_mult"
),
std
::
vector
<
std
::
string
>
({
f
::
kEmptyVarName
,
f
::
kEmptyVarName
}));
EXPECT_EQ
(
grad_test_op
->
Inputs
(
"In3_mult"
),
EXPECT_EQ
(
grad_test_op
->
Inputs
(
"In3_mult"
),
std
::
vector
<
std
::
string
>
({
"in3_1"
,
"in3_2"
}));
std
::
vector
<
std
::
string
>
({
"in3_1"
,
"in3_2"
}));
EXPECT_EQ
(
grad_test_op
->
Inputs
(
"Out1_mult"
),
EXPECT_EQ
(
grad_test_op
->
Inputs
(
"Out1_mult"
),
std
::
vector
<
std
::
string
>
({
"out1_1"
,
"out1_2"
}));
std
::
vector
<
std
::
string
>
({
"out1_1"
,
"out1_2"
}));
EXPECT_EQ
(
grad_test_op
->
Input
(
"Out2"
),
f
::
kEmptyVarName
);
EXPECT_EQ
(
grad_test_op
->
Inputs
(
f
::
GradVarName
(
"Out1_mult"
)),
EXPECT_EQ
(
grad_test_op
->
Inputs
(
f
::
GradVarName
(
"Out1_mult"
)),
std
::
vector
<
std
::
string
>
(
std
::
vector
<
std
::
string
>
(
{
f
::
GradVarName
(
"out1_1"
),
f
::
GradVarName
(
"out1_2"
)}));
{
f
::
GradVarName
(
"out1_1"
),
f
::
GradVarName
(
"out1_2"
)}));
EXPECT_EQ
(
grad_test_op
->
Input
(
f
::
GradVarName
(
"Out2"
)),
EXPECT_EQ
(
grad_test_op
->
Input
(
f
::
GradVarName
(
"Out2"
)),
f
::
GradVarName
(
"out2"
));
f
::
GradVarName
(
"out2"
));
ASSERT_EQ
(
grad_test_op
->
outputs_
.
size
(),
5
UL
);
ASSERT_EQ
(
grad_test_op
->
outputs_
.
size
(),
3
UL
);
EXPECT_EQ
(
grad_test_op
->
Output
(
f
::
GradVarName
(
"In1"
)),
f
::
GradVarName
(
"in1"
));
EXPECT_EQ
(
grad_test_op
->
Output
(
f
::
GradVarName
(
"In1"
)),
f
::
GradVarName
(
"in1"
));
EXPECT_EQ
(
grad_test_op
->
Outputs
(
f
::
GradVarName
(
"In2_mult"
)),
EXPECT_EQ
(
grad_test_op
->
Outputs
(
f
::
GradVarName
(
"In2_mult"
)),
std
::
vector
<
std
::
string
>
(
std
::
vector
<
std
::
string
>
(
...
...
paddle/framework/op_desc.proto
已删除
100644 → 0
浏览文件 @
8747d60d
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax
=
"proto2"
;
package
paddle
.
framework
;
import
"attribute.proto"
;
// AttrDesc is used to describe Attributes of an Operator. It contain's
// name, type, and value of Attribute.
//
// e.g, for scale=3.0: name=scala, type=AttrType.FLOAT, value=3.0
message
AttrDesc
{
required
string
name
=
1
;
required
AttrType
type
=
2
;
optional
int32
i
=
3
;
optional
float
f
=
4
;
optional
string
s
=
5
;
repeated
int32
ints
=
6
;
repeated
float
floats
=
7
;
repeated
string
strings
=
8
;
};
// Protocol Message to describe an Operator.
//
// In PaddlePaddle, Operator is used to do a certain computation such
// as "add", "sub", "cosine", etc.
// (1) Operator needs to know the input and output variable names.
// (2) Some ops may have special attributes such as "scale" in "CosineOp".
//
// 3rd-party language can build this proto message and call
// AddOp(const OpDesc& op_desc) of Paddle core to create an Operator.
message
OpDesc
{
// input names of this Operator.
repeated
string
inputs
=
1
;
// output names of this Operator.
repeated
string
outputs
=
2
;
// type of this Operator, such as "add", "sub", "fc".
required
string
type
=
3
;
// Attributes of this Operator. e.g., scale=3.0 in cosine op.
repeated
AttrDesc
attrs
=
4
;
};
\ No newline at end of file
paddle/framework/op_desc_test.cc
已删除
100644 → 0
浏览文件 @
8747d60d
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <paddle/framework/op_desc.pb.h>
TEST
(
OpDesc
,
Create
)
{
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"add"
);
op_desc
.
add_inputs
(
"X"
);
op_desc
.
add_inputs
(
"Y"
);
op_desc
.
add_outputs
(
"Z"
);
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_f
(
3.14
);
// required field name is not set, so IsInitialized should be false.
ASSERT_FALSE
(
op_desc
.
IsInitialized
());
attr
->
set_name
(
"add"
);
// after all required fields are set, IsInitialized should be true now.
ASSERT_TRUE
(
op_desc
.
IsInitialized
());
}
\ No newline at end of file
paddle/framework/op_proto.proto
已删除
100644 → 0
浏览文件 @
8747d60d
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Protocol Message for 3rd-party language binding.
//
// Paddle Python package will use `OpProto` to generate op creation methods.
// The op creation methods take user's input and generate `OpDesc` proto
// message,
// then pass `OpDesc` to C++ side and create Op pointer.
//
syntax
=
"proto2"
;
package
paddle
.
framework
;
import
"attribute.proto"
;
// Attribute protocol message for 3rd-party language binding.
// It will store the Op support what attribute and what type.
message
AttrProto
{
// Supported attribute name. e.g. `scale` for cosine op.
required
string
name
=
1
;
// Supported attribute type.
required
AttrType
type
=
2
;
// Supported attribute comments. It helps 3rd-party language generate
// doc-string.
required
string
comment
=
3
;
// If that attribute is generated, it means the Paddle third language
// binding has responsibility to fill that attribute. End-User should
// not set that attribute.
optional
bool
generated
=
4
[
default
=
false
];
}
// Input or output message for 3rd-party language binding.
// It contains parameter name and its comments.
message
VarProto
{
// Input or output name in that op creation function.
// e.g. `cos(a, b, output, ...)`, "a", "b", "output" are names.
required
string
name
=
1
;
// The comment for that input. It helps 3rd-party language generate
// doc-string.
required
string
comment
=
2
;
// Is that input/output could be a list or not.
// If so, that Op should write a attributed named `input_format` or
// `output_format`.
//
// e.g.
// If the op is a fc op, the inputs are `X`, `W`, `b`. The `X` and `W`
// could be multiple, so the multiple of `X` and `W` is True, and OpDesc
// will hold a attribute of them.
//
// The Op desc of same fc could be
// {
// "type": "fc",
// "input": ["X1", "X2", "W1", "W2", "b"],
// "output": "fc.out",
// "attrs" : {
// "input_format": [0, 2, 4, 5]
// }
// }
//
optional
bool
multiple
=
3
[
default
=
false
];
// It marks that output is a temporary output. That output is not used by
// user, but used by other op internally as input. If other op is not use
// that output, it could be optimized early.
//
// Attribute temporary_index will be set in OpDesc if there is some
// outputs are temporary.
//
// output = [ "xxx.out1", "xxx.tmp", "xxx.out2"],
// attrs = {
// "temporary_index": [1]
// }
optional
bool
temporary
=
4
[
default
=
false
];
// The gradient of operator can be ignored immediately
// e.g. operator AddOp, y = x1 + x2, the gradient of dy/dx1, dy/dx2
// can be ignored for the future optimized on graph.
optional
bool
ignore_gradient
=
6
;
}
// Op protocol message for 3rd-party language binding.
// It contains all information for generating op creation method.
message
OpProto
{
// The input information to generate op creation method.
repeated
VarProto
inputs
=
1
;
// The output information to generate op creation method.
repeated
VarProto
outputs
=
2
;
// The attribute information to generate op creation method.
repeated
AttrProto
attrs
=
3
;
// The comments for that Op. It helps 3rd-party language generate
// doc-string. The whole documentation of that Op is generated by comment,
// inputs, outputs, attrs together.
required
string
comment
=
4
;
// The type of that Op.
required
string
type
=
5
;
}
paddle/framework/op_proto_test.cc
已删除
100644 → 0
浏览文件 @
8747d60d
#include <gtest/gtest.h>
#include <paddle/framework/op_proto.pb.h>
TEST
(
TestOpProto
,
ALL
)
{
paddle
::
framework
::
OpProto
proto
;
{
auto
ipt
=
proto
.
mutable_inputs
()
->
Add
();
*
ipt
->
mutable_name
()
=
"a"
;
*
ipt
->
mutable_comment
()
=
"the one input of cosine op"
;
}
{
auto
ipt
=
proto
.
mutable_inputs
()
->
Add
();
*
ipt
->
mutable_name
()
=
"b"
;
*
ipt
->
mutable_comment
()
=
"the other input of cosine op"
;
}
{
auto
opt
=
proto
.
mutable_outputs
()
->
Add
();
*
opt
->
mutable_name
()
=
"output"
;
*
opt
->
mutable_comment
()
=
"the output of cosine op"
;
}
{
auto
attr
=
proto
.
mutable_attrs
()
->
Add
();
*
attr
->
mutable_name
()
=
"scale"
;
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
*
attr
->
mutable_comment
()
=
"the scale attribute of cosine op"
;
}
proto
.
set_type
(
"cos"
);
*
proto
.
mutable_comment
()
=
"cosine op, output = scale * cos(a, b)"
;
ASSERT_TRUE
(
proto
.
IsInitialized
());
}
\ No newline at end of file
paddle/framework/op_registry.h
浏览文件 @
81f5f861
...
@@ -20,8 +20,9 @@ limitations under the License. */
...
@@ -20,8 +20,9 @@ limitations under the License. */
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include "paddle/framework/attribute.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op
_desc.pb
.h"
#include "paddle/framework/op
erator
.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/scope.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -44,52 +45,48 @@ class OpProtoAndCheckerMaker {
...
@@ -44,52 +45,48 @@ class OpProtoAndCheckerMaker {
protected:
protected:
struct
VariableBuilder
{
struct
VariableBuilder
{
VarProto
*
var_
;
OpProto
::
Var
*
var_
;
std
::
function
<
void
()
>
on_multiple_
;
std
::
function
<
void
()
>
on_temporary_
;
VariableBuilder
&
SetMultiple
()
{
VariableBuilder
&
AsDuplicable
()
{
var_
->
set_multiple
(
true
);
var_
->
set_duplicable
(
true
);
on_multiple_
();
return
*
this
;
return
*
this
;
}
}
VariableBuilder
&
SetTemporary
()
{
VariableBuilder
&
AsIntermediate
()
{
PADDLE_ENFORCE
(
bool
(
on_temporary_
),
"Cannot set temporary"
);
var_
->
set_intermediate
(
true
);
var_
->
set_temporary
(
true
);
on_temporary_
();
return
*
this
;
return
*
this
;
}
}
VariableBuilder
&
IgnoreGradient
()
{
// TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it
var_
->
set_ignore_gradient
(
true
);
// means that input/output is not needed when calculate gradient. It does
// not mean no gradient when backward. It should be changed soon.
VariableBuilder
&
AsNoGradient
()
{
var_
->
set_no_gradient
(
true
);
return
*
this
;
return
*
this
;
}
}
};
};
VariableBuilder
AddInput
(
const
std
::
string
&
name
,
VariableBuilder
AddInput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
)
{
const
std
::
string
&
comment
)
{
VarPro
to
*
input
=
proto_
->
add_inputs
();
au
to
*
input
=
proto_
->
add_inputs
();
input
->
set_name
(
name
);
input
->
set_name
(
name
);
input
->
set_comment
(
comment
);
input
->
set_comment
(
comment
);
return
VariableBuilder
{
input
,
[
=
]
{
this
->
SetHasMultipleInput
();
},
return
VariableBuilder
{
input
};
nullptr
};
}
}
VariableBuilder
AddOutput
(
const
std
::
string
&
name
,
VariableBuilder
AddOutput
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
)
{
const
std
::
string
&
comment
)
{
VarPro
to
*
output
=
proto_
->
add_outputs
();
au
to
*
output
=
proto_
->
add_outputs
();
output
->
set_name
(
name
);
output
->
set_name
(
name
);
output
->
set_comment
(
comment
);
output
->
set_comment
(
comment
);
return
VariableBuilder
{
output
,
[
=
]
{
this
->
SetHasMultipleOutput
();
},
return
VariableBuilder
{
output
};
[
=
]
{
this
->
SetHasTemporaryOutput
();
}};
}
}
template
<
typename
T
>
template
<
typename
T
>
TypedAttrChecker
<
T
>&
AddAttr
(
const
std
::
string
&
name
,
TypedAttrChecker
<
T
>&
AddAttr
(
const
std
::
string
&
name
,
const
std
::
string
&
comment
,
const
std
::
string
&
comment
,
bool
generated
=
false
)
{
bool
generated
=
false
)
{
AttrPro
to
*
attr
=
proto_
->
add_attrs
();
au
to
*
attr
=
proto_
->
add_attrs
();
attr
->
set_name
(
name
);
attr
->
set_name
(
name
);
attr
->
set_comment
(
comment
);
attr
->
set_comment
(
comment
);
attr
->
set_generated
(
generated
);
attr
->
set_generated
(
generated
);
...
@@ -100,53 +97,6 @@ class OpProtoAndCheckerMaker {
...
@@ -100,53 +97,6 @@ class OpProtoAndCheckerMaker {
void
AddComment
(
const
std
::
string
&
comment
)
{
proto_
->
set_comment
(
comment
);
}
void
AddComment
(
const
std
::
string
&
comment
)
{
proto_
->
set_comment
(
comment
);
}
private:
private:
void
SetHasMultiple
(
const
std
::
string
&
in_out
,
bool
*
flag
)
{
if
(
!*
flag
)
{
AddAttr
<
std
::
vector
<
int
>>
(
in_out
+
"_format"
,
"The multiple index of "
+
in_out
+
"
\n
"
R
"DOC(
This attribute is used by Paddle core framework. Paddle's Op support each input
or output could be a list of variable. This attribute is used to show how that
list organized.
e.g.
input = ["
a
", "
b
", "
c
", "
d
", "
e
", "
f
"]
input_format = [0, 4, 5, 6]
means
The number of all input variables this op is six, and they are segmented into
three inputs.
The first input is input[0:4], second is input[4:5], third is input[5:6].
)DOC"
,
/*generated*/
true
);
*
flag
=
true
;
}
}
void
SetHasMultipleInput
()
{
SetHasMultiple
(
"input"
,
&
has_multiple_input_
);
}
void
SetHasMultipleOutput
()
{
SetHasMultiple
(
"output"
,
&
has_multiple_output_
);
}
void
SetHasTemporaryOutput
()
{
if
(
!
has_temporary_output_
)
{
AddAttr
<
std
::
vector
<
int
>>
(
"temporary_index"
,
R
"DOC(The temporary index of output.
Not all output of Paddle Op is used by user. For faster computation, each op
could output some its internal state to other op, other op could take that
output to make compute faster.
Add a mark to which output is temporary is helpful for future optimization.
)DOC"
,
/*generated*/
true
)
.
SetDefault
(
std
::
vector
<
int
>
());
has_temporary_output_
=
true
;
}
}
void
CheckNoDuplicatedInOutAttrs
()
{
void
CheckNoDuplicatedInOutAttrs
()
{
std
::
unordered_set
<
std
::
string
>
names
;
std
::
unordered_set
<
std
::
string
>
names
;
auto
checker
=
[
&
](
const
std
::
string
&
name
)
{
auto
checker
=
[
&
](
const
std
::
string
&
name
)
{
...
@@ -167,22 +117,18 @@ Add a mark to which output is temporary is helpful for future optimization.
...
@@ -167,22 +117,18 @@ Add a mark to which output is temporary is helpful for future optimization.
OpProto
*
proto_
;
OpProto
*
proto_
;
OpAttrChecker
*
op_checker_
;
OpAttrChecker
*
op_checker_
;
bool
validated_
{
false
};
bool
validated_
{
false
};
bool
has_multiple_input_
{
false
};
bool
has_multiple_output_
{
false
};
bool
has_temporary_output_
{
false
};
};
};
class
OpRegistry
{
class
OpRegistry
{
using
OpCreator
=
std
::
function
<
OperatorBase
*
()
>
;
using
OpCreator
=
std
::
function
<
OperatorBase
*
()
>
;
using
VarIndexMap
=
std
::
unordered_map
<
std
::
string
,
int
>
;
using
VarNameMap
=
OperatorBase
::
VarNameMap
;
using
VarNameList
=
std
::
vector
<
std
::
string
>
;
public:
public:
template
<
typename
OpType
,
typename
ProtoMakerType
>
template
<
typename
OpType
,
typename
ProtoMakerType
>
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
op_creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
op_creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpProto
&
op_proto
=
p
rotos
()[
op_type
];
OpProto
&
op_proto
=
OpP
rotos
()[
op_type
];
auto
maker
=
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
auto
maker
=
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
maker
.
Validate
();
maker
.
Validate
();
op_proto
.
set_type
(
op_type
);
op_proto
.
set_type
(
op_type
);
...
@@ -190,17 +136,6 @@ class OpRegistry {
...
@@ -190,17 +136,6 @@ class OpRegistry {
op_proto
.
IsInitialized
(),
op_proto
.
IsInitialized
(),
"Fail to initialize %s's OpProto, because %s is not initialized"
,
"Fail to initialize %s's OpProto, because %s is not initialized"
,
op_type
,
op_proto
.
InitializationErrorString
());
op_type
,
op_proto
.
InitializationErrorString
());
VarIndexMaps
()[
op_type
].
reset
(
new
VarIndexMap
());
auto
&
varmap
=
*
VarIndexMaps
()[
op_type
];
int
idx
=
0
;
for
(
auto
&
var
:
op_proto
.
inputs
())
{
varmap
[
var
.
name
()]
=
idx
++
;
}
idx
=
0
;
for
(
auto
&
var
:
op_proto
.
outputs
())
{
varmap
[
var
.
name
()]
=
idx
++
;
}
}
}
template
<
typename
GradOpType
>
template
<
typename
GradOpType
>
...
@@ -211,8 +146,8 @@ class OpRegistry {
...
@@ -211,8 +146,8 @@ class OpRegistry {
}
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
const
VarName
List
&
inputs
,
const
VarName
Map
&
inputs
,
const
VarName
List
&
outputs
,
const
VarName
Map
&
outputs
,
const
AttributeMap
&
attrs
)
{
const
AttributeMap
&
attrs
)
{
auto
op_create_it
=
op_creators
().
find
(
type
);
auto
op_create_it
=
op_creators
().
find
(
type
);
PADDLE_ENFORCE
(
op_create_it
!=
op_creators
().
end
(),
PADDLE_ENFORCE
(
op_create_it
!=
op_creators
().
end
(),
...
@@ -228,28 +163,26 @@ class OpRegistry {
...
@@ -228,28 +163,26 @@ class OpRegistry {
GenerateTempVariableName
(
op
);
GenerateTempVariableName
(
op
);
{
auto
var_index_it
=
VarIndexMaps
().
find
(
type
);
if
(
var_index_it
!=
VarIndexMaps
().
end
())
{
op
->
in_out_idxs_
=
var_index_it
->
second
;
}
}
op
->
Init
();
op
->
Init
();
return
std
::
shared_ptr
<
OperatorBase
>
(
op
);
return
std
::
shared_ptr
<
OperatorBase
>
(
op
);
}
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
)
{
static
VarNameMap
ConvertOpDescVarsToVarNameMap
(
std
::
vector
<
std
::
string
>
inputs
;
const
google
::
protobuf
::
RepeatedPtrField
<
OpDesc
::
Var
>&
op_desc_vars
)
{
inputs
.
reserve
((
size_t
)
op_desc
.
inputs_size
());
VarNameMap
ret_val
;
std
::
copy
(
op_desc
.
inputs
().
begin
(),
op_desc
.
inputs
().
end
(),
for
(
auto
&
var
:
op_desc_vars
)
{
std
::
back_inserter
(
inputs
));
auto
&
var_names
=
ret_val
[
var
.
parameter
()];
auto
&
var_names_in_proto
=
var
.
arguments
();
std
::
vector
<
std
::
string
>
outputs
;
var_names
.
reserve
(
static_cast
<
size_t
>
(
var_names_in_proto
.
size
()));
outputs
.
reserve
((
size_t
)
op_desc
.
outputs_size
());
std
::
copy
(
var_names_in_proto
.
begin
(),
var_names_in_proto
.
end
(),
std
::
copy
(
op_desc
.
outputs
().
begin
(),
op_desc
.
outputs
().
end
(),
std
::
back_inserter
(
var_names
));
std
::
back_inserter
(
outputs
));
}
return
ret_val
;
}
static
std
::
shared_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
)
{
VarNameMap
inputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
inputs
());
VarNameMap
outputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
outputs
());
AttributeMap
attrs
;
AttributeMap
attrs
;
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
attrs
[
attr
.
name
()]
=
GetAttrValue
(
attr
);
attrs
[
attr
.
name
()]
=
GetAttrValue
(
attr
);
...
@@ -266,22 +199,11 @@ class OpRegistry {
...
@@ -266,22 +199,11 @@ class OpRegistry {
return
grad_op
;
return
grad_op
;
}
}
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
return
protos_
;
}
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
grad_ops
()
{
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
grad_ops
()
{
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
grad_ops_
;
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
grad_ops_
;
return
grad_ops_
;
return
grad_ops_
;
}
}
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>&
VarIndexMaps
()
{
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
VarIndexMap
>>
maps_
;
return
maps_
;
}
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
op_creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
op_creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
op_creators_
;
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
op_creators_
;
return
op_creators_
;
return
op_creators_
;
...
@@ -295,11 +217,13 @@ class OpRegistry {
...
@@ -295,11 +217,13 @@ class OpRegistry {
static
void
GenerateTempVariableName
(
OperatorBase
*
op
)
{
static
void
GenerateTempVariableName
(
OperatorBase
*
op
)
{
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
for
(
auto
&
outname
:
op
->
outputs_
)
{
for
(
auto
&
output
:
op
->
outputs_
)
{
if
(
outname
==
kTempVarName
)
{
for
(
auto
&
output_name
:
output
.
second
)
{
outname
+=
op
->
type_
;
if
(
output_name
==
kTempVarName
)
{
outname
+=
"@"
;
output_name
+=
op
->
type_
;
outname
+=
std
::
to_string
(
gUniqId
.
fetch_add
(
1
));
output_name
+=
"@"
;
output_name
+=
std
::
to_string
(
gUniqId
.
fetch_add
(
1
));
}
}
}
}
}
}
}
...
...
paddle/framework/op_registry_test.cc
浏览文件 @
81f5f861
...
@@ -7,8 +7,7 @@ namespace paddle {
...
@@ -7,8 +7,7 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
class
CosineOp
:
public
OperatorBase
{
class
CosineOp
:
public
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
CosineOp
,
OperatorBase
)
DEFINE_OPERATOR_CTOR
(
CosineOp
,
OperatorBase
);
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
...
@@ -29,8 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -29,8 +28,7 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class
MyTestOp
:
public
OperatorBase
{
class
MyTestOp
:
public
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
MyTestOp
,
OperatorBase
)
DEFINE_OPERATOR_CTOR
(
MyTestOp
,
OperatorBase
);
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
...
@@ -40,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -40,8 +38,8 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
public:
MyTestOpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
MyTestOpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"input"
,
"input of cosine op"
).
SetMultip
le
();
AddInput
(
"input"
,
"input of cosine op"
).
AsDuplicab
le
();
AddOutput
(
"output"
,
"output of cosine op"
).
SetTemporary
();
AddOutput
(
"output"
,
"output of cosine op"
).
AsIntermediate
();
auto
my_checker
=
[](
int
i
)
{
auto
my_checker
=
[](
int
i
)
{
PADDLE_ENFORCE
(
i
%
2
==
0
,
"'test_attr' must be even!"
);
PADDLE_ENFORCE
(
i
%
2
==
0
,
"'test_attr' must be even!"
);
};
};
...
@@ -53,6 +51,15 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -53,6 +51,15 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
static
void
BuildVar
(
const
std
::
string
&
param_name
,
std
::
initializer_list
<
const
char
*>
arguments
,
paddle
::
framework
::
OpDesc
::
Var
*
var
)
{
var
->
set_parameter
(
param_name
);
for
(
auto
&
arg_name
:
arguments
)
{
var
->
add_arguments
(
arg_name
);
}
}
REGISTER_OP
(
cos_sim
,
paddle
::
framework
::
CosineOp
,
REGISTER_OP
(
cos_sim
,
paddle
::
framework
::
CosineOp
,
paddle
::
framework
::
CosineOpProtoAndCheckerMaker
);
paddle
::
framework
::
CosineOpProtoAndCheckerMaker
);
REGISTER_OP
(
my_test_op
,
paddle
::
framework
::
MyTestOp
,
REGISTER_OP
(
my_test_op
,
paddle
::
framework
::
MyTestOp
,
...
@@ -61,8 +68,8 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
...
@@ -61,8 +68,8 @@ REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
TEST
(
OpRegistry
,
CreateOp
)
{
TEST
(
OpRegistry
,
CreateOp
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
add_inputs
(
"aa"
);
BuildVar
(
"input"
,
{
"aa"
},
op_desc
.
add_inputs
()
);
op_desc
.
add_outputs
(
"bb"
);
BuildVar
(
"output"
,
{
"bb"
},
op_desc
.
add_outputs
()
);
float
scale
=
3.3
;
float
scale
=
3.3
;
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
...
@@ -82,8 +89,8 @@ TEST(OpRegistry, CreateOp) {
...
@@ -82,8 +89,8 @@ TEST(OpRegistry, CreateOp) {
TEST
(
OpRegistry
,
IllegalAttr
)
{
TEST
(
OpRegistry
,
IllegalAttr
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
add_inputs
(
"aa"
);
BuildVar
(
"input"
,
{
"aa"
},
op_desc
.
add_inputs
()
);
op_desc
.
add_outputs
(
"bb"
);
BuildVar
(
"output"
,
{
"bb"
},
op_desc
.
add_outputs
()
);
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_name
(
"scale"
);
...
@@ -107,8 +114,8 @@ TEST(OpRegistry, IllegalAttr) {
...
@@ -107,8 +114,8 @@ TEST(OpRegistry, IllegalAttr) {
TEST
(
OpRegistry
,
DefaultValue
)
{
TEST
(
OpRegistry
,
DefaultValue
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
set_type
(
"cos_sim"
);
op_desc
.
add_inputs
(
"aa"
);
BuildVar
(
"input"
,
{
"aa"
},
op_desc
.
add_inputs
()
);
op_desc
.
add_outputs
(
"bb"
);
BuildVar
(
"output"
,
{
"bb"
},
op_desc
.
add_outputs
()
);
ASSERT_TRUE
(
op_desc
.
IsInitialized
());
ASSERT_TRUE
(
op_desc
.
IsInitialized
());
...
@@ -120,20 +127,11 @@ TEST(OpRegistry, DefaultValue) {
...
@@ -120,20 +127,11 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_EQ
(
op
->
GetAttr
<
float
>
(
"scale"
),
1.0
);
ASSERT_EQ
(
op
->
GetAttr
<
float
>
(
"scale"
),
1.0
);
}
}
static
void
SetInputFormat
(
paddle
::
framework
::
OpDesc
*
desc
)
{
auto
attr
=
desc
->
add_attrs
();
attr
->
set_name
(
"input_format"
);
attr
->
set_type
(
paddle
::
framework
::
INTS
);
attr
->
mutable_ints
()
->
Add
(
0
);
attr
->
mutable_ints
()
->
Add
(
1
);
}
TEST
(
OpRegistry
,
CustomChecker
)
{
TEST
(
OpRegistry
,
CustomChecker
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"my_test_op"
);
op_desc
.
set_type
(
"my_test_op"
);
op_desc
.
add_inputs
(
"ii"
);
BuildVar
(
"input"
,
{
"ii"
},
op_desc
.
add_inputs
());
op_desc
.
add_outputs
(
"oo"
);
BuildVar
(
"output"
,
{
"oo"
},
op_desc
.
add_outputs
());
SetInputFormat
(
&
op_desc
);
// attr 'test_attr' is not set
// attr 'test_attr' is not set
bool
caught
=
false
;
bool
caught
=
false
;
...
@@ -173,7 +171,6 @@ TEST(OpRegistry, CustomChecker) {
...
@@ -173,7 +171,6 @@ TEST(OpRegistry, CustomChecker) {
attr
->
set_name
(
"test_attr"
);
attr
->
set_name
(
"test_attr"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
INT
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
INT
);
attr
->
set_i
(
4
);
attr
->
set_i
(
4
);
SetInputFormat
(
&
op_desc
);
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
paddle
::
platform
::
CPUDeviceContext
dev_ctx
;
paddle
::
platform
::
CPUDeviceContext
dev_ctx
;
paddle
::
framework
::
Scope
scope
;
paddle
::
framework
::
Scope
scope
;
...
...
paddle/framework/operator.cc
浏览文件 @
81f5f861
...
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <algorithm>
#include "paddle/framework/operator.h"
#include "paddle/framework/operator.h"
#include <algorithm>
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -33,84 +33,122 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
...
@@ -33,84 +33,122 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
}
}
#endif
#endif
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>*
g_op_protos
=
nullptr
;
PADDLE_ENFORCE_NOT_NULL
(
in_out_idxs_
,
std
::
unordered_map
<
std
::
string
,
OpProto
>&
OpProtos
()
{
"Input Output Indices could not be nullptr"
);
if
(
g_op_protos
==
nullptr
)
{
auto
it
=
in_out_idxs_
->
find
(
name
);
g_op_protos
=
new
std
::
unordered_map
<
std
::
string
,
OpProto
>
();
PADDLE_ENFORCE
(
it
!=
in_out_idxs_
->
end
(),
"no key [%s] in in_out_idxs_"
,
name
);
if
(
attrs_
.
count
(
"input_format"
)
==
0
)
{
return
inputs_
.
at
((
size_t
)
it
->
second
);
}
else
{
const
auto
&
input_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
);
int
idx
=
input_format
[
it
->
second
];
return
inputs_
.
at
((
size_t
)
idx
);
}
}
return
*
g_op_protos
;
}
}
std
::
vector
<
std
::
string
>
OperatorBase
::
Inputs
(
const
std
::
string
&
name
)
const
{
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
in_out_idxs_
,
"IO Idx could not be nullptr"
);
auto
&
ins
=
Inputs
(
name
);
auto
input_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
);
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
1UL
,
auto
offset
=
in_out_idxs_
->
at
(
name
);
"Op %s input %s should contain only one variable"
,
type_
,
PADDLE_ENFORCE
(
input_format
.
at
(
static_cast
<
size_t
>
(
offset
)
+
1
)
<=
name
);
static_cast
<
int
>
(
inputs_
.
size
()),
return
ins
[
0
];
"Input Out Of Range"
);
return
std
::
vector
<
std
::
string
>
{
inputs_
.
begin
()
+
input_format
.
at
(
offset
),
inputs_
.
begin
()
+
input_format
.
at
(
offset
+
1
)};
}
}
const
std
::
string
&
OperatorBase
::
Output
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>&
OperatorBase
::
Inputs
(
PADDLE_ENFORCE_NOT_NULL
(
in_out_idxs_
,
"InOut Indice could not be nullptr"
);
const
std
::
string
&
name
)
const
{
auto
it
=
in
_out_idxs_
->
find
(
name
);
auto
it
=
in
puts_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
in
_out_idxs_
->
end
(),
"no key [%s] in in_out_idxs_"
,
PADDLE_ENFORCE
(
it
!=
in
puts_
.
end
(),
"Op %s do not have input %s"
,
type_
,
name
);
name
);
if
(
attrs_
.
count
(
"output_format"
)
==
0
)
{
return
it
->
second
;
return
outputs_
.
at
((
size_t
)
it
->
second
);
}
}
else
{
const
auto
&
output_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
);
const
std
::
string
&
OperatorBase
::
Output
(
const
std
::
string
&
name
)
const
{
int
idx
=
output_format
[
it
->
second
];
auto
&
outs
=
Outputs
(
name
);
return
outputs_
.
at
((
size_t
)
idx
);
PADDLE_ENFORCE_EQ
(
outs
.
size
(),
1UL
,
}
"Op %s output %s should contain only one variable"
,
type_
,
name
);
return
outs
[
0
];
}
}
std
::
vector
<
std
::
string
>
OperatorBase
::
Outputs
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>&
OperatorBase
::
Outputs
(
PADDLE_ENFORCE_NOT_NULL
(
in_out_idxs_
,
"InOut Indice could not be nullptr"
);
const
std
::
string
&
name
)
const
{
auto
output_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
);
auto
it
=
outputs_
.
find
(
name
);
auto
offset
=
in_out_idxs_
->
at
(
name
);
PADDLE_ENFORCE
(
it
!=
outputs_
.
end
(),
"Op %s does not have output %s"
,
type_
,
PADDLE_ENFORCE
(
output_format
.
at
(
static_cast
<
size_t
>
(
offset
)
+
1
)
<=
name
);
static_cast
<
int
>
(
outputs_
.
size
()),
return
it
->
second
;
"Output Out of Range"
);
return
std
::
vector
<
std
::
string
>
{
outputs_
.
begin
()
+
output_format
.
at
(
offset
),
outputs_
.
begin
()
+
output_format
.
at
(
offset
+
1
)};
}
}
std
::
string
OperatorBase
::
DebugString
()
const
{
std
::
string
OperatorBase
::
DebugString
()
const
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
ss
<<
"Op("
<<
type_
<<
"), inputs:("
;
ss
<<
"Op("
<<
type_
<<
"), inputs:{"
;
for
(
size_t
i
=
0
;
i
<
inputs_
.
size
();
++
i
)
{
for
(
auto
it
=
inputs_
.
begin
();
it
!=
inputs_
.
end
();)
{
ss
<<
inputs_
[
i
];
auto
&
input
=
*
it
;
if
(
i
!=
inputs_
.
size
()
-
1
)
{
ss
<<
input
.
first
<<
"["
;
for
(
size_t
i
=
0
;
i
<
input
.
second
.
size
();
++
i
)
{
ss
<<
input
.
second
[
i
];
if
(
i
!=
input
.
second
.
size
()
-
1
)
{
ss
<<
", "
;
}
}
ss
<<
"]"
;
++
it
;
if
(
it
!=
inputs_
.
end
())
{
ss
<<
", "
;
ss
<<
", "
;
}
}
}
}
ss
<<
"), outputs:("
;
ss
<<
"}, outputs:{"
;
for
(
size_t
i
=
0
;
i
<
outputs_
.
size
();
++
i
)
{
for
(
auto
it
=
outputs_
.
begin
();
it
!=
outputs_
.
end
();)
{
ss
<<
outputs_
[
i
];
auto
&
output
=
*
it
;
if
(
i
!=
outputs_
.
size
()
-
1
)
{
ss
<<
output
.
first
<<
"["
;
for
(
size_t
i
=
0
;
i
<
output
.
second
.
size
();
++
i
)
{
ss
<<
output
.
second
[
i
];
if
(
i
!=
output
.
second
.
size
()
-
1
)
{
ss
<<
", "
;
}
}
ss
<<
"]"
;
++
it
;
if
(
it
!=
outputs_
.
end
())
{
ss
<<
", "
;
ss
<<
", "
;
}
}
}
}
ss
<<
"
)
."
;
ss
<<
"
}
."
;
return
ss
.
str
();
return
ss
.
str
();
}
}
void
OperatorBase
::
Rename
(
const
std
::
string
&
old_name
,
void
OperatorBase
::
Rename
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
)
{
const
std
::
string
&
new_name
)
{
std
::
replace
(
inputs_
.
begin
(),
inputs_
.
end
(),
old_name
,
new_name
);
for
(
auto
&
input
:
inputs_
)
{
std
::
replace
(
outputs_
.
begin
(),
outputs_
.
end
(),
old_name
,
new_name
);
std
::
replace
(
input
.
second
.
begin
(),
input
.
second
.
end
(),
old_name
,
new_name
);
}
for
(
auto
&
output
:
outputs_
)
{
std
::
replace
(
output
.
second
.
begin
(),
output
.
second
.
end
(),
old_name
,
new_name
);
}
}
std
::
vector
<
std
::
string
>
OperatorBase
::
OutputVars
(
bool
has_intermediate
)
const
{
std
::
vector
<
std
::
string
>
ret_val
;
if
(
has_intermediate
)
{
// push all outputs into ret_val
for
(
auto
&
o
:
outputs_
)
{
ret_val
.
reserve
(
ret_val
.
size
()
+
o
.
second
.
size
());
ret_val
.
insert
(
ret_val
.
end
(),
o
.
second
.
begin
(),
o
.
second
.
end
());
}
return
ret_val
;
}
auto
it
=
OpProtos
().
find
(
type_
);
PADDLE_ENFORCE
(
it
!=
OpProtos
().
end
(),
"Operator %s not registered, cannot figure out intermediate outputs"
,
type_
);
// get all OpProto::Var for outputs
for
(
auto
&
o
:
it
->
second
.
outputs
())
{
// ignore all intermediate output
if
(
o
.
intermediate
())
continue
;
auto
out
=
outputs_
.
find
(
o
.
name
());
if
(
out
!=
outputs_
.
end
())
{
ret_val
.
reserve
(
ret_val
.
size
()
+
out
->
second
.
size
());
ret_val
.
insert
(
ret_val
.
end
(),
out
->
second
.
begin
(),
out
->
second
.
end
());
}
}
return
ret_val
;
}
}
}
// namespace framework
}
// namespace framework
...
...
paddle/framework/operator.h
浏览文件 @
81f5f861
...
@@ -20,8 +20,7 @@ limitations under the License. */
...
@@ -20,8 +20,7 @@ limitations under the License. */
#include <vector>
#include <vector>
#include "paddle/framework/attribute.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/device_context.h"
...
@@ -51,6 +50,8 @@ inline std::string GradVarName(const std::string& var_name) {
...
@@ -51,6 +50,8 @@ inline std::string GradVarName(const std::string& var_name) {
return
var_name
+
kGradVarSuffix
;
return
var_name
+
kGradVarSuffix
;
}
}
extern
std
::
unordered_map
<
std
::
string
,
OpProto
>&
OpProtos
();
class
OperatorBase
;
class
OperatorBase
;
class
InferShapeContext
;
class
InferShapeContext
;
class
ExecutionContext
;
class
ExecutionContext
;
...
@@ -63,16 +64,16 @@ class ExecutionContext;
...
@@ -63,16 +64,16 @@ class ExecutionContext;
*/
*/
class
OperatorBase
{
class
OperatorBase
{
public:
public:
OperatorBase
()
{}
// TODO(yi): This constructor is to be removed.
using
VarNameMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
OperatorBase
(
const
std
::
string
&
type
,
const
std
::
vector
<
std
::
string
>&
inputs
,
const
std
::
vector
<
std
::
string
>&
outputs
,
OperatorBase
()
=
default
;
const
AttributeMap
&
attr
s
,
OperatorBase
(
const
std
::
string
&
type
,
const
VarNameMap
&
input
s
,
std
::
unordered_map
<
std
::
string
,
int
>*
in_out_idx
s
)
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attr
s
)
:
type_
(
type
),
:
type_
(
type
),
inputs_
(
inputs
),
outputs_
(
outputs
),
attrs_
(
attrs
)
{}
inputs_
(
inputs
),
outputs_
(
outputs
),
OperatorBase
(
const
OperatorBase
&
o
)
=
delete
;
attrs_
(
attrs
),
OperatorBase
&
operator
=
(
const
OperatorBase
&
o
)
=
delete
;
in_out_idxs_
(
in_out_idxs
)
{}
OperatorBase
(
OperatorBase
&&
o
)
=
delete
;
virtual
~
OperatorBase
()
{}
virtual
~
OperatorBase
()
{}
...
@@ -107,22 +108,18 @@ class OperatorBase {
...
@@ -107,22 +108,18 @@ class OperatorBase {
//! Get a input with argument's name described in `op_proto`
//! Get a input with argument's name described in `op_proto`
const
std
::
string
&
Input
(
const
std
::
string
&
name
)
const
;
const
std
::
string
&
Input
(
const
std
::
string
&
name
)
const
;
//! Get a input which has multiple variables.
//! Get a input which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
const
std
::
vector
<
std
::
string
>&
Inputs
(
const
std
::
string
&
name
)
const
;
std
::
vector
<
std
::
string
>
Inputs
(
const
std
::
string
&
name
)
const
;
//! Get a output with argument's name described in `op_proto`
//! Get a output with argument's name described in `op_proto`
const
std
::
string
&
Output
(
const
std
::
string
&
name
)
const
;
const
std
::
string
&
Output
(
const
std
::
string
&
name
)
const
;
//! Get an output which has multiple variables.
//! Get an output which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
//! TODO add a vector_view to prevent memory copy.
std
::
vector
<
std
::
string
>
Outputs
(
const
std
::
string
&
name
)
const
;
const
std
::
vector
<
std
::
string
>&
Outputs
(
const
std
::
string
&
name
)
const
;
virtual
std
::
vector
<
std
::
string
>
OutputVars
(
bool
has_intermediate
)
const
;
const
std
::
string
Type
()
const
{
return
type_
;
}
std
::
string
Type
()
const
{
return
type_
;
}
const
std
::
vector
<
std
::
string
>
Inputs
()
const
{
return
inputs_
;
}
const
std
::
vector
<
std
::
string
>
Outputs
()
const
{
return
outputs_
;
}
const
AttributeMap
&
Attrs
()
const
{
return
attrs_
;
}
const
AttributeMap
&
Attrs
()
const
{
return
attrs_
;
}
const
std
::
unordered_map
<
std
::
string
,
int
>*
InOutIdx
()
const
{
return
in_out_idxs_
.
get
();
}
public:
public:
std
::
string
type_
;
std
::
string
type_
;
...
@@ -130,30 +127,34 @@ class OperatorBase {
...
@@ -130,30 +127,34 @@ class OperatorBase {
// I (Inputs)
// I (Inputs)
// O (Outputs)
// O (Outputs)
// OG (Output Gradients)
// OG (Output Gradients)
std
::
vector
<
std
::
string
>
inputs_
;
VarNameMap
inputs_
;
// NOTE: in case of OpGrad, outputs_ contains
// NOTE: in case of OpGrad, outputs_ contains
// IG (Inputs Gradients)
// IG (Inputs Gradients)
std
::
vector
<
std
::
string
>
outputs_
;
VarNameMap
outputs_
;
AttributeMap
attrs_
;
AttributeMap
attrs_
;
// store the arguments' offset described in op_desc.
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
int
>>
in_out_idxs_
;
};
};
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
public: \
Class() : ParentClass() {
/* TODO(yi): This constructor is to be removed. */
\
} \
Class(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, \
const paddle::framework::AttributeMap& attrs) \
: ParentClass(type, inputs, outputs, attrs) {}
class
InferShapeContext
{
class
InferShapeContext
{
public:
public:
InferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
InferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
:
op_
(
op
),
scope_
(
scope
)
{}
:
op_
(
op
),
scope_
(
scope
)
{}
size_t
InputSize
()
const
{
return
op_
.
inputs_
.
size
();
}
size_t
InputSize
(
const
std
::
string
&
name
)
const
{
return
op_
.
Inputs
(
name
).
size
();
size_t
OutputSize
()
const
{
return
op_
.
outputs_
.
size
();
}
const
Variable
*
InputVar
(
const
size_t
index
)
const
{
return
scope_
.
FindVar
(
op_
.
inputs_
.
at
(
index
));
}
}
Variable
*
OutputVar
(
const
size_t
index
)
const
{
size_t
OutputSize
(
const
std
::
string
&
name
)
const
{
return
scope_
.
FindVar
(
op_
.
outputs_
.
at
(
index
)
);
return
op_
.
Outputs
(
name
).
size
(
);
}
}
const
Variable
*
InputVar
(
const
std
::
string
&
name
)
const
{
const
Variable
*
InputVar
(
const
std
::
string
&
name
)
const
{
...
@@ -185,27 +186,9 @@ class InferShapeContext {
...
@@ -185,27 +186,9 @@ class InferShapeContext {
return
res
;
return
res
;
}
}
template
<
typename
T
>
const
T
*
Input
(
const
size_t
index
)
const
{
auto
var
=
InputVar
(
index
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Input(%d) should not be nullptr"
,
index
);
return
&
var
->
Get
<
T
>
();
}
template
<
typename
T
>
T
*
Output
(
const
size_t
index
)
const
{
auto
var
=
OutputVar
(
index
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Output(%d) not be nullptr, which means variable [%s] does not "
"exist in scope"
,
index
,
op_
.
outputs_
[
index
]);
return
var
->
GetMutable
<
T
>
();
}
template
<
typename
T
>
template
<
typename
T
>
const
T
*
Input
(
const
std
::
string
&
name
)
const
{
const
T
*
Input
(
const
std
::
string
&
name
)
const
{
auto
var
=
InputVar
(
name
);
auto
*
var
=
InputVar
(
name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Input(%s) should not be nullptr"
,
name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Input(%s) should not be nullptr"
,
name
);
return
&
var
->
Get
<
T
>
();
return
&
var
->
Get
<
T
>
();
}
}
...
@@ -300,13 +283,7 @@ class OpKernel {
...
@@ -300,13 +283,7 @@ class OpKernel {
class
OperatorWithKernel
:
public
OperatorBase
{
class
OperatorWithKernel
:
public
OperatorBase
{
public:
public:
OperatorWithKernel
()
{}
// TODO(yi): This constructor is to be removed.
DEFINE_OPERATOR_CTOR
(
OperatorWithKernel
,
OperatorBase
)
OperatorWithKernel
(
const
std
::
string
&
type
,
const
std
::
vector
<
std
::
string
>&
inputs
,
const
std
::
vector
<
std
::
string
>&
outputs
,
const
AttributeMap
&
attrs
,
std
::
unordered_map
<
std
::
string
,
int
>*
in_out_idxs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
,
in_out_idxs
)
{}
struct
OpKernelKey
{
struct
OpKernelKey
{
platform
::
Place
place_
;
platform
::
Place
place_
;
...
@@ -357,15 +334,5 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -357,15 +334,5 @@ class OperatorWithKernel : public OperatorBase {
virtual
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
=
0
;
virtual
void
InferShape
(
const
InferShapeContext
&
ctx
)
const
=
0
;
};
};
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
public: \
Class() {
/* TODO(yi): This constructor is to be removed. */
\
} \
Class(const std::string& type, const std::vector<std::string>& inputs, \
const std::vector<std::string>& outputs, \
const ::paddle::framework::AttributeMap& attrs, \
std::unordered_map<std::string, int>* in_out_idxs) \
: ParentClass(type, inputs, outputs, attrs, in_out_idxs) {}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/operator_test.cc
浏览文件 @
81f5f861
...
@@ -22,19 +22,19 @@ namespace framework {
...
@@ -22,19 +22,19 @@ namespace framework {
static
int
op_run_num
=
0
;
static
int
op_run_num
=
0
;
class
OpWithoutKernelTest
:
public
OperatorBase
{
class
OpWithoutKernelTest
:
public
OperatorBase
{
public:
DEFINE_OPERATOR_CTOR
(
OpWithoutKernelTest
,
framework
::
OperatorBase
)
DEFINE_OPERATOR_CTOR
(
OpWithoutKernelTest
,
OperatorBase
)
public:
void
Init
()
override
{
x
=
1
;
}
void
Init
()
override
{
x
=
1
;
}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
op_run_num
++
;
++
op_run_num
;
ASSERT_EQ
(
(
int
)
inputs_
.
size
(
),
1
);
ASSERT_EQ
(
static_cast
<
int
>
(
inputs_
.
size
()
),
1
);
ASSERT_EQ
(
(
int
)
outputs_
.
size
(
),
1
);
ASSERT_EQ
(
static_cast
<
int
>
(
outputs_
.
size
()
),
1
);
ASSERT_EQ
(
scope
.
FindVar
(
inputs_
[
0
]),
nullptr
);
ASSERT_EQ
(
scope
.
FindVar
(
inputs_
.
at
(
"input"
)
[
0
]),
nullptr
);
ASSERT_EQ
(
x
,
1
);
ASSERT_EQ
(
x
,
1
);
ASSERT_NE
(
scope
.
FindVar
(
outputs_
[
0
]),
nullptr
);
ASSERT_NE
(
scope
.
FindVar
(
outputs_
.
at
(
"output"
)
[
0
]),
nullptr
);
}
}
public:
public:
...
@@ -56,14 +56,24 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -56,14 +56,24 @@ class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
static
void
BuildVar
(
const
std
::
string
&
param_name
,
std
::
initializer_list
<
const
char
*>
arguments
,
paddle
::
framework
::
OpDesc
::
Var
*
var
)
{
var
->
set_parameter
(
param_name
);
for
(
auto
&
arg_name
:
arguments
)
{
*
var
->
mutable_arguments
()
->
Add
()
=
arg_name
;
}
}
REGISTER_OP
(
test_operator
,
paddle
::
framework
::
OpWithoutKernelTest
,
REGISTER_OP
(
test_operator
,
paddle
::
framework
::
OpWithoutKernelTest
,
paddle
::
framework
::
OpeWithoutKernelTestProtoAndCheckerMaker
);
paddle
::
framework
::
OpeWithoutKernelTestProtoAndCheckerMaker
);
TEST
(
OperatorBase
,
all
)
{
TEST
(
OperatorBase
,
all
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"test_operator"
);
op_desc
.
set_type
(
"test_operator"
);
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"IN1"
;
BuildVar
(
"input"
,
{
"IN1"
},
op_desc
.
add_inputs
());
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"OUT1"
;
BuildVar
(
"output"
,
{
"OUT1"
},
op_desc
.
add_outputs
());
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
...
@@ -99,8 +109,7 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -99,8 +109,7 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static
int
cpu_kernel_run_num
=
0
;
static
int
cpu_kernel_run_num
=
0
;
class
OpWithKernelTest
:
public
OperatorWithKernel
{
class
OpWithKernelTest
:
public
OperatorWithKernel
{
public:
DEFINE_OPERATOR_CTOR
(
OpWithKernelTest
,
framework
::
OperatorWithKernel
)
DEFINE_OPERATOR_CTOR
(
OpWithKernelTest
,
OperatorWithKernel
)
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{}
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{}
};
};
...
@@ -117,35 +126,15 @@ class CPUKernelTest : public OpKernel {
...
@@ -117,35 +126,15 @@ class CPUKernelTest : public OpKernel {
}
}
};
};
// multiple inputs test
class
OperatorMultiInputsTest
:
public
OperatorBase
{
public:
DEFINE_OPERATOR_CTOR
(
OperatorMultiInputsTest
,
OperatorBase
)
void
Init
()
override
{
x
=
1
;
}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
ASSERT_EQ
(
scope
.
FindVar
(
inputs_
[
0
]),
nullptr
);
ASSERT_EQ
(
x
,
1
);
ASSERT_NE
(
scope
.
FindVar
(
outputs_
[
0
]),
nullptr
);
ASSERT_EQ
(
Input
(
"x"
),
"IN1"
);
ASSERT_EQ
(
Input
(
"y"
),
"OUT1"
);
}
public:
float
x
=
0
;
};
class
OpKernelTestMultiInputsProtoAndCheckerMaker
class
OpKernelTestMultiInputsProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
:
public
OpProtoAndCheckerMaker
{
public:
public:
OpKernelTestMultiInputsProtoAndCheckerMaker
(
OpProto
*
proto
,
OpKernelTestMultiInputsProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"xs"
,
"inputs of test op"
).
SetMultip
le
();
AddInput
(
"xs"
,
"inputs of test op"
).
AsDuplicab
le
();
AddInput
(
"k"
,
"input of test op"
);
AddInput
(
"k"
,
"input of test op"
);
AddOutput
(
"ys"
,
"outputs of test op"
).
SetMultip
le
();
AddOutput
(
"ys"
,
"outputs of test op"
).
AsDuplicab
le
();
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
.
SetDefault
(
1.0
)
.
SetDefault
(
1.0
)
.
LargerThan
(
0.0
);
.
LargerThan
(
0.0
);
...
@@ -202,8 +191,9 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
...
@@ -202,8 +191,9 @@ REGISTER_OP_CPU_KERNEL(op_with_kernel,
TEST
(
OpKernel
,
all
)
{
TEST
(
OpKernel
,
all
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"op_with_kernel"
);
op_desc
.
set_type
(
"op_with_kernel"
);
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"IN1"
;
BuildVar
(
"x"
,
{
"IN1"
},
op_desc
.
add_inputs
());
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"OUT1"
;
BuildVar
(
"y"
,
{
"OUT1"
},
op_desc
.
add_outputs
());
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
...
@@ -229,32 +219,15 @@ TEST(OpKernel, multi_inputs) {
...
@@ -229,32 +219,15 @@ TEST(OpKernel, multi_inputs) {
OpDesc
op_desc
;
OpDesc
op_desc
;
op_desc
.
set_type
(
"op_multi_inputs_with_kernel"
);
op_desc
.
set_type
(
"op_multi_inputs_with_kernel"
);
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"x0"
;
BuildVar
(
"xs"
,
{
"x0"
,
"x1"
,
"x2"
},
op_desc
.
add_inputs
());
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"x1"
;
BuildVar
(
"k"
,
{
"k0"
},
op_desc
.
add_inputs
());
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"x2"
;
BuildVar
(
"ys"
,
{
"y0"
,
"y1"
},
op_desc
.
add_outputs
());
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"k0"
;
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"y0"
;
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"y1"
;
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_f
(
3.14
);
attr
->
set_f
(
3.14
);
auto
attr0
=
op_desc
.
mutable_attrs
()
->
Add
();
attr0
->
set_name
(
"input_format"
);
attr0
->
set_type
(
paddle
::
framework
::
AttrType
::
INTS
);
auto
input_format
=
attr0
->
mutable_ints
();
input_format
->
Add
(
0
);
// x0
input_format
->
Add
(
3
);
// k
input_format
->
Add
(
4
);
// end
auto
attr1
=
op_desc
.
mutable_attrs
()
->
Add
();
attr1
->
set_name
(
"output_format"
);
attr1
->
set_type
(
paddle
::
framework
::
AttrType
::
INTS
);
auto
output_format
=
attr1
->
mutable_ints
();
output_format
->
Add
(
0
);
// y0
output_format
->
Add
(
2
);
// y1
paddle
::
platform
::
CPUDeviceContext
cpu_device_context
;
paddle
::
platform
::
CPUDeviceContext
cpu_device_context
;
paddle
::
framework
::
Scope
scope
;
paddle
::
framework
::
Scope
scope
;
scope
.
NewVar
(
"x0"
)
->
GetMutable
<
Tensor
>
();
scope
.
NewVar
(
"x0"
)
->
GetMutable
<
Tensor
>
();
...
...
paddle/framework/pybind.cc
浏览文件 @
81f5f861
...
@@ -56,30 +56,18 @@ void ExposeOperator(ClassType &m) {
...
@@ -56,30 +56,18 @@ void ExposeOperator(ClassType &m) {
return
op
.
type_
;
return
op
.
type_
;
})
})
.
def
(
"outputs"
,
.
def
(
"outputs"
,
[](
const
typename
ClassType
::
type
&
op
)
->
std
::
vector
<
std
::
string
>
{
[](
const
typename
ClassType
::
type
&
op
)
return
op
.
outputs_
;
->
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
{
})
return
op
.
outputs_
;
})
.
def
(
"inputs"
,
.
def
(
"inputs"
,
[](
const
typename
ClassType
::
type
&
op
)
->
std
::
vector
<
std
::
string
>
{
[](
const
typename
ClassType
::
type
&
op
)
{
return
op
.
inputs_
;
})
return
op
.
inputs_
;
.
def
(
"__str__"
,
&
ClassType
::
type
::
DebugString
)
})
.
def
(
"no_intermediate_outputs"
,
.
def
(
"support_gpu"
,
&
ClassType
::
type
::
SupportGPU
)
[](
const
typename
ClassType
::
type
&
op
)
{
.
def
(
"temp_outputs"
,
return
op
.
OutputVars
(
false
);
[](
const
typename
ClassType
::
type
&
op
)
->
std
::
vector
<
std
::
string
>
{
auto
iter
=
op
.
attrs_
.
find
(
"temporary_index"
);
std
::
vector
<
std
::
string
>
ret
;
if
(
iter
==
op
.
attrs_
.
end
())
{
return
ret
;
}
else
{
auto
tmp_idx
=
boost
::
get
<
std
::
vector
<
int
>>
(
iter
->
second
);
ret
.
reserve
(
tmp_idx
.
size
());
for
(
auto
&
index
:
tmp_idx
)
{
ret
.
push_back
(
op
.
outputs_
.
at
(
index
));
}
return
ret
;
}
})
})
.
def
(
"
__str__"
,
&
ClassType
::
type
::
DebugString
);
.
def
(
"
support_gpu"
,
&
ClassType
::
type
::
SupportGPU
);
}
}
static
size_t
UniqueIntegerGenerator
()
{
static
size_t
UniqueIntegerGenerator
()
{
...
@@ -172,7 +160,7 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -172,7 +160,7 @@ All parameter, weight, gradient are variables in Paddle.
//! @note: Be careful! PyBind will return std::string as an unicode, not
//! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python.
//! Python str. If you want a str object, you should cast them in Python.
m
.
def
(
"get_all_op_protos"
,
[]()
->
std
::
vector
<
py
::
bytes
>
{
m
.
def
(
"get_all_op_protos"
,
[]()
->
std
::
vector
<
py
::
bytes
>
{
auto
&
protos
=
Op
Registry
::
p
rotos
();
auto
&
protos
=
Op
P
rotos
();
std
::
vector
<
py
::
bytes
>
ret_values
;
std
::
vector
<
py
::
bytes
>
ret_values
;
for
(
auto
it
=
protos
.
begin
();
it
!=
protos
.
end
();
++
it
)
{
for
(
auto
it
=
protos
.
begin
();
it
!=
protos
.
end
();
++
it
)
{
PADDLE_ENFORCE
(
it
->
second
.
IsInitialized
(),
PADDLE_ENFORCE
(
it
->
second
.
IsInitialized
(),
...
...
paddle/operators/CMakeLists.txt
浏览文件 @
81f5f861
...
@@ -62,7 +62,7 @@ op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
...
@@ -62,7 +62,7 @@ op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
op_library
(
sgd_op SRCS sgd_op.cc sgd_op.cu
)
op_library
(
sgd_op SRCS sgd_op.cc sgd_op.cu
)
op_library
(
recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
op_library
(
recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS
op_desc
tensor op_registry operator net_op
)
DEPS
framework_proto
tensor op_registry operator net_op
)
cc_test
(
recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op
)
cc_test
(
recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op
)
op_library
(
uniform_random_op
op_library
(
uniform_random_op
SRCS uniform_random_op.cc uniform_random_op.cu
)
SRCS uniform_random_op.cc uniform_random_op.cu
)
paddle/operators/add_op.cc
浏览文件 @
81f5f861
...
@@ -19,16 +19,13 @@ namespace operators {
...
@@ -19,16 +19,13 @@ namespace operators {
class
AddOp
:
public
framework
::
OperatorWithKernel
{
class
AddOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
AddOp
,
framework
::
OperatorWithKernel
)
DEFINE_OPERATOR_CTOR
(
AddOp
,
framework
::
OperatorWithKernel
)
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
.
InputSize
(),
2
);
PADDLE_ENFORCE_EQ
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
(),
PADDLE_ENFORCE_EQ
(
ctx
.
OutputSize
(),
1
);
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
dims
(),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
0
),
"Inputs of AddOp must all be set"
);
"Two input of Add Op's dimension must be same."
);
PADDLE_ENFORCE
(
ctx
.
OutputVar
(
0
)
!=
nullptr
,
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
());
"Outputs of AddOp must all be set"
);
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
()
==
ctx
.
Input
<
Tensor
>
(
1
)
->
dims
(),
"Two input of Add Op's dimension must be same."
);
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
());
}
}
};
};
...
...
paddle/operators/add_op.h
浏览文件 @
81f5f861
...
@@ -28,9 +28,9 @@ template <typename Place, typename T>
...
@@ -28,9 +28,9 @@ template <typename Place, typename T>
class
AddKernel
:
public
framework
::
OpKernel
{
class
AddKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
input0
=
context
.
Input
<
Tensor
>
(
0
);
auto
*
input0
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
input1
=
context
.
Input
<
Tensor
>
(
1
);
auto
*
input1
=
context
.
Input
<
Tensor
>
(
"Y"
);
auto
output
=
context
.
Output
<
Tensor
>
(
0
);
auto
*
output
=
context
.
Output
<
Tensor
>
(
"Out"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
paddle/operators/cross_entropy_op.cc
浏览文件 @
81f5f861
...
@@ -21,20 +21,13 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
...
@@ -21,20 +21,13 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR
(
OnehotCrossEntropyOp
,
framework
::
OperatorWithKernel
)
DEFINE_OPERATOR_CTOR
(
OnehotCrossEntropyOp
,
framework
::
OperatorWithKernel
)
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
.
InputSize
(),
2
,
auto
*
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
"Input size of OnehotCrossEntropyOp must be two"
);
auto
*
label
=
ctx
.
Input
<
Tensor
>
(
"label"
);
PADDLE_ENFORCE_EQ
(
ctx
.
OutputSize
(),
1
,
"Output size of OnehotCrossEntropyOp must be one"
);
PADDLE_ENFORCE_EQ
(
X
->
dims
().
size
(),
2
,
"X's dimension must be 2."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
0
),
PADDLE_ENFORCE_EQ
(
label
->
dims
().
size
(),
1
,
"label's dimension must be 1."
);
"0-th input of OnehotCrossEntropyOp should be set"
);
PADDLE_ENFORCE_EQ
(
X
->
dims
()[
0
],
label
->
dims
()[
0
]);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
1
),
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
({
X
->
dims
()[
0
]});
"1-th input of OnehotCrossEntropyOp should be set"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
0
),
"Outputs of OnehotCrossEntropyOp must all be set"
);
PADDLE_ENFORCE_EQ
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
().
size
(),
2
);
PADDLE_ENFORCE_EQ
(
ctx
.
Output
<
Tensor
>
(
0
)
->
dims
().
size
(),
1
,
"label's dimension must be 1."
);
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
({
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
()[
0
]});
}
}
};
};
...
...
paddle/operators/cross_entropy_op.h
浏览文件 @
81f5f861
...
@@ -45,7 +45,7 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel {
...
@@ -45,7 +45,7 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel {
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
X
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
T
*
Xdata
=
X
->
data
<
T
>
();
const
T
*
Xdata
=
X
->
data
<
T
>
();
const
int
*
label_data
=
ctx
.
Input
<
Tensor
>
(
1
)
->
data
<
int
>
();
const
int
*
label_data
=
ctx
.
Input
<
Tensor
>
(
"label"
)
->
data
<
int
>
();
auto
Y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
Y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
Y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
Y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
...
paddle/operators/fill_zeros_like_op.cc
浏览文件 @
81f5f861
...
@@ -18,19 +18,12 @@ namespace paddle {
...
@@ -18,19 +18,12 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
FillZerosLikeOp
:
public
framework
::
OperatorWithKernel
{
class
FillZerosLikeOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
FillZerosLikeOp
,
framework
::
OperatorWithKernel
)
DEFINE_OPERATOR_CTOR
(
FillZerosLikeOp
,
framework
::
OperatorWithKernel
);
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
.
InputSize
(),
1UL
,
ctx
.
Output
<
framework
::
Tensor
>
(
"Dst"
)
->
Resize
(
"Input size of FillZerosLikeOp must be one."
);
ctx
.
Input
<
framework
::
Tensor
>
(
"Src"
)
->
dims
());
PADDLE_ENFORCE_EQ
(
ctx
.
OutputSize
(),
1UL
,
"Output size of AddOp must be one."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
0
),
"Input of FillZerosLikeOp must be set."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
0
),
"Output of FillZerosLikeOp must be set."
);
ctx
.
Output
<
framework
::
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
framework
::
Tensor
>
(
0
)
->
dims
());
}
}
};
};
...
...
paddle/operators/fill_zeros_like_op.h
浏览文件 @
81f5f861
...
@@ -23,7 +23,7 @@ template <typename Place, typename T>
...
@@ -23,7 +23,7 @@ template <typename Place, typename T>
class
FillZerosLikeKernel
:
public
framework
::
OpKernel
{
class
FillZerosLikeKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
"Dst"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
);
auto
t
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
);
t
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
t
.
constant
(
T
(
0
));
t
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
t
.
constant
(
T
(
0
));
...
...
paddle/operators/gaussian_random_op.cc
浏览文件 @
81f5f861
...
@@ -43,7 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel {
...
@@ -43,7 +43,8 @@ class GaussianRandomKernel : public framework::OpKernel {
};
};
class
GaussianRandomOp
:
public
framework
::
OperatorWithKernel
{
class
GaussianRandomOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
GaussianRandomOp
,
framework
::
OperatorWithKernel
)
DEFINE_OPERATOR_CTOR
(
GaussianRandomOp
,
framework
::
OperatorWithKernel
);
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
...
...
paddle/operators/mean_op.cc
浏览文件 @
81f5f861
...
@@ -21,11 +21,9 @@ class MeanOp : public framework::OperatorWithKernel {
...
@@ -21,11 +21,9 @@ class MeanOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR
(
MeanOp
,
framework
::
OperatorWithKernel
)
DEFINE_OPERATOR_CTOR
(
MeanOp
,
framework
::
OperatorWithKernel
)
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
.
InputSize
(),
1
,
"Input size of AddOp must be one"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
PADDLE_ENFORCE_EQ
(
ctx
.
OutputSize
(),
1
,
"Output size of AddOp must be one"
);
"Input of MeanOp must be initialized."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
0
),
"input should be set"
);
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
Resize
({
1
});
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
0
),
"output should be set"
);
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
(
framework
::
make_ddim
({
1
}));
}
}
};
};
...
@@ -34,7 +32,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -34,7 +32,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
MeanOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
MeanOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input of mean op"
);
AddInput
(
"X"
,
"The input of mean op"
);
AddOutput
(
"Out"
,
"The output of mean op"
).
Ignore
Gradient
();
AddOutput
(
"Out"
,
"The output of mean op"
).
AsNo
Gradient
();
AddComment
(
"Mean Operator"
);
AddComment
(
"Mean Operator"
);
}
}
};
};
...
...
paddle/operators/mean_op.h
浏览文件 @
81f5f861
...
@@ -31,14 +31,14 @@ template <typename Place, typename T>
...
@@ -31,14 +31,14 @@ template <typename Place, typename T>
class
MeanKernel
:
public
framework
::
OpKernel
{
class
MeanKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
input
=
context
.
Input
<
Tensor
>
(
0
);
auto
*
input
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
output
=
context
.
Output
<
Tensor
>
(
0
);
auto
*
output
=
context
.
Output
<
Tensor
>
(
"Out"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
X
=
EigenVector
<
T
>::
Flatten
(
*
input
);
auto
X
=
EigenVector
<
T
>::
Flatten
(
*
input
);
auto
y
=
EigenScalar
<
T
>::
From
(
*
output
);
auto
y
=
EigenScalar
<
T
>::
From
(
*
output
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
auto
&
place
=
context
.
GetEigenDevice
<
Place
>
();
y
.
device
(
place
)
=
X
.
mean
();
y
.
device
(
place
)
=
X
.
mean
();
}
}
...
...
paddle/operators/mul_op.cc
浏览文件 @
81f5f861
...
@@ -18,12 +18,12 @@ namespace paddle {
...
@@ -18,12 +18,12 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
class
MulOp
:
public
framework
::
OperatorWithKernel
{
class
MulOp
:
public
framework
::
OperatorWithKernel
{
DEFINE_OPERATOR_CTOR
(
MulOp
,
framework
::
OperatorWithKernel
)
DEFINE_OPERATOR_CTOR
(
MulOp
,
framework
::
OperatorWithKernel
);
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2
,
"The mul op must take two inputs"
);
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
();
auto
dim1
=
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
dims
();
auto
dim1
=
ctx
.
Input
<
Tensor
>
(
1
)
->
dims
();
PADDLE_ENFORCE_EQ
(
dim0
.
size
(),
2
,
PADDLE_ENFORCE_EQ
(
dim0
.
size
(),
2
,
"input X(%s) should be a tensor with 2 dims, a matrix"
,
"input X(%s) should be a tensor with 2 dims, a matrix"
,
ctx
.
op_
.
Input
(
"X"
));
ctx
.
op_
.
Input
(
"X"
));
...
@@ -33,8 +33,7 @@ class MulOp : public framework::OperatorWithKernel {
...
@@ -33,8 +33,7 @@ class MulOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
dim0
[
1
],
dim1
[
0
],
dim0
[
1
],
dim1
[
0
],
"First matrix's width must be equal with second matrix's height."
);
"First matrix's width must be equal with second matrix's height."
);
PADDLE_ENFORCE_EQ
(
ctx
.
OutputSize
(),
1
,
"The mul op takes only one output"
);
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
Resize
({
dim0
[
0
],
dim1
[
1
]});
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
({
dim0
[
0
],
dim1
[
1
]});
}
}
};
};
...
...
paddle/operators/mul_op.h
浏览文件 @
81f5f861
...
@@ -30,17 +30,14 @@ class MulKernel : public framework::OpKernel {
...
@@ -30,17 +30,14 @@ class MulKernel : public framework::OpKernel {
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
Eigen
::
array
<
Eigen
::
IndexPair
<
Eigen
::
DenseIndex
>
,
1
>
dim_pair
=
{
Eigen
::
array
<
Eigen
::
IndexPair
<
Eigen
::
DenseIndex
>
,
1
>
dim_pair
=
{
{
Eigen
::
IndexPair
<
Eigen
::
DenseIndex
>
(
1
,
0
)}};
{
Eigen
::
IndexPair
<
Eigen
::
DenseIndex
>
(
1
,
0
)}};
auto
*
input0
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
input0
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
input1
=
context
.
Input
<
Tensor
>
(
"Y"
);
auto
input1
=
context
.
Input
<
Tensor
>
(
"Y"
);
auto
*
output
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
output
=
context
.
Output
<
Tensor
>
(
0
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
X
=
EigenMatrix
<
T
>::
From
(
*
input0
);
auto
X
=
EigenMatrix
<
T
>::
From
(
*
input0
);
auto
Y
=
EigenMatrix
<
T
>::
From
(
*
input1
);
auto
Y
=
EigenMatrix
<
T
>::
From
(
*
input1
);
auto
Z
=
EigenMatrix
<
T
>::
From
(
*
output
);
auto
Z
=
EigenMatrix
<
T
>::
From
(
*
output
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
auto
&
place
=
context
.
GetEigenDevice
<
Place
>
();
Z
.
device
(
place
)
=
X
.
contract
(
Y
,
dim_pair
);
Z
.
device
(
place
)
=
X
.
contract
(
Y
,
dim_pair
);
}
}
...
...
paddle/operators/net_op.cc
浏览文件 @
81f5f861
...
@@ -15,48 +15,42 @@
...
@@ -15,48 +15,42 @@
*/
*/
#include "paddle/operators/net_op.h"
#include "paddle/operators/net_op.h"
#include <set>
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
const
char
NetOp
::
kAll
[]
=
"all"
;
void
NetOp
::
CompleteAddOp
(
bool
calc
)
{
void
NetOp
::
CompleteAddOp
(
bool
calc
)
{
add_op_done_
=
true
;
add_op_done_
=
true
;
if
(
!
calc
)
return
;
if
(
!
calc
)
return
;
std
::
unordered_set
<
std
::
string
>
input_set
;
std
::
set
<
std
::
string
>
input_set
;
std
::
unordered_set
<
std
::
string
>
output_set
;
std
::
set
<
std
::
string
>
output_set
;
std
::
unordered_set
<
std
::
string
>
temp_output
;
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
ipt
:
op
->
inputs_
)
{
for
(
auto
&
ipt
:
op
->
inputs_
)
{
if
(
!
Contains
(
output_set
,
ipt
))
{
// Not other op's output
for
(
auto
&
var_name
:
ipt
.
second
)
{
input_set
.
insert
(
ipt
);
if
(
!
Contains
(
output_set
,
var_name
))
{
// Not other op's output
}
else
{
input_set
.
insert
(
var_name
);
temp_output
.
insert
(
ipt
);
}
else
{
intermediate_outputs_
.
insert
(
var_name
);
}
}
}
}
}
for
(
auto
&
opt
:
op
->
outputs_
)
{
for
(
auto
&
opt
:
op
->
outputs_
)
{
output_set
.
insert
(
opt
);
for
(
auto
&
var_name
:
opt
.
second
)
{
}
output_set
.
insert
(
var_name
);
}
}
inputs_
.
reserve
(
input_set
.
size
());
std
::
copy
(
input_set
.
begin
(),
input_set
.
end
(),
std
::
back_inserter
(
inputs_
));
std
::
sort
(
inputs_
.
begin
(),
inputs_
.
end
());
outputs_
.
reserve
(
output_set
.
size
());
std
::
copy
(
output_set
.
begin
(),
output_set
.
end
(),
std
::
back_inserter
(
outputs_
));
std
::
sort
(
outputs_
.
begin
(),
outputs_
.
end
());
std
::
vector
<
int
>
tmp_index
;
tmp_index
.
reserve
(
temp_output
.
size
());
int
output_len
=
static_cast
<
int
>
(
outputs_
.
size
());
for
(
int
i
=
0
;
i
<
output_len
;
++
i
)
{
if
(
Contains
(
temp_output
,
outputs_
[
i
]))
{
tmp_index
.
push_back
(
i
);
}
}
}
}
auto
&
inputs
=
inputs_
[
kAll
];
attrs_
[
"temporary_index"
]
=
tmp_index
;
inputs
.
reserve
(
input_set
.
size
());
std
::
copy
(
input_set
.
begin
(),
input_set
.
end
(),
std
::
back_inserter
(
inputs
));
auto
&
outputs
=
outputs_
[
kAll
];
outputs
.
reserve
(
output_set
.
size
());
std
::
copy
(
output_set
.
begin
(),
output_set
.
end
(),
std
::
back_inserter
(
outputs
));
}
}
std
::
string
NetOp
::
DebugString
()
const
{
std
::
string
NetOp
::
DebugString
()
const
{
...
@@ -73,5 +67,19 @@ std::string NetOp::DebugString() const {
...
@@ -73,5 +67,19 @@ std::string NetOp::DebugString() const {
bool
NetOp
::
IsNetOp
()
const
{
return
true
;
}
bool
NetOp
::
IsNetOp
()
const
{
return
true
;
}
std
::
vector
<
std
::
string
>
NetOp
::
OutputVars
(
bool
has_intermediate
)
const
{
if
(
has_intermediate
)
{
return
this
->
outputs_
.
at
(
kAll
);
}
auto
&
all
=
this
->
outputs_
.
at
(
kAll
);
std
::
vector
<
std
::
string
>
ret_val
;
for
(
auto
&
each
:
all
)
{
if
(
!
Contains
(
intermediate_outputs_
,
each
))
{
ret_val
.
push_back
(
each
);
}
}
return
ret_val
;
}
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/operators/net_op.h
浏览文件 @
81f5f861
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -35,7 +36,8 @@ namespace operators {
...
@@ -35,7 +36,8 @@ namespace operators {
*/
*/
class
NetOp
:
public
framework
::
OperatorBase
{
class
NetOp
:
public
framework
::
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
NetOp
,
framework
::
OperatorBase
)
static
const
char
kAll
[];
DEFINE_OPERATOR_CTOR
(
NetOp
,
framework
::
OperatorBase
);
/**
/**
* Infer all the operators' input and output variables' shapes, will be called
* Infer all the operators' input and output variables' shapes, will be called
...
@@ -92,11 +94,13 @@ class NetOp : public framework::OperatorBase {
...
@@ -92,11 +94,13 @@ class NetOp : public framework::OperatorBase {
std
::
string
DebugString
()
const
override
;
std
::
string
DebugString
()
const
override
;
bool
IsNetOp
()
const
override
;
bool
IsNetOp
()
const
override
;
std
::
vector
<
std
::
string
>
OutputVars
(
bool
has_intermediate
)
const
override
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
ops_
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
ops_
;
private:
private:
bool
add_op_done_
{
false
};
bool
add_op_done_
{
false
};
std
::
set
<
std
::
string
>
intermediate_outputs_
;
template
<
typename
T
,
typename
KeyType
>
template
<
typename
T
,
typename
KeyType
>
static
bool
Contains
(
T
container
,
KeyType
key
)
{
static
bool
Contains
(
T
container
,
KeyType
key
)
{
...
...
paddle/operators/net_op_test.cc
浏览文件 @
81f5f861
...
@@ -12,8 +12,7 @@ static int run_cnt = 0;
...
@@ -12,8 +12,7 @@ static int run_cnt = 0;
class
TestOp
:
public
framework
::
OperatorBase
{
class
TestOp
:
public
framework
::
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
TestOp
,
framework
::
OperatorBase
)
DEFINE_OPERATOR_CTOR
(
TestOp
,
framework
::
OperatorBase
);
void
InferShape
(
const
Scope
&
scope
)
const
override
{
++
infer_shape_cnt
;
}
void
InferShape
(
const
Scope
&
scope
)
const
override
{
++
infer_shape_cnt
;
}
void
Run
(
const
Scope
&
scope
,
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
...
@@ -23,8 +22,7 @@ class TestOp : public framework::OperatorBase {
...
@@ -23,8 +22,7 @@ class TestOp : public framework::OperatorBase {
class
EmptyOp
:
public
framework
::
OperatorBase
{
class
EmptyOp
:
public
framework
::
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
EmptyOp
,
framework
::
OperatorBase
)
DEFINE_OPERATOR_CTOR
(
EmptyOp
,
framework
::
OperatorBase
);
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
const
DeviceContext
&
dev_ctx
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
const
DeviceContext
&
dev_ctx
)
const
override
{}
};
};
...
@@ -47,39 +45,31 @@ TEST(OpKernel, all) {
...
@@ -47,39 +45,31 @@ TEST(OpKernel, all) {
ASSERT_NE
(
net
,
nullptr
);
ASSERT_NE
(
net
,
nullptr
);
auto
op1
=
std
::
make_shared
<
TestOp
>
();
auto
op1
=
std
::
make_shared
<
TestOp
>
();
op1
->
inputs_
=
{
"x"
,
"w1"
,
"b1"
};
op1
->
inputs_
=
{
{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}
};
op1
->
outputs_
=
{
"y"
};
op1
->
outputs_
=
{
{
"Out"
,
{
"y"
}}
};
net
->
AddOp
(
op1
);
net
->
AddOp
(
op1
);
auto
op2
=
std
::
make_shared
<
TestOp
>
();
auto
op2
=
std
::
make_shared
<
TestOp
>
();
op2
->
inputs_
=
{
"y"
,
"w2"
,
"b2"
};
op2
->
inputs_
=
{
{
"X"
,
{
"y"
}},
{
"W"
,
{
"w2"
}},
{
"b"
,
{
"b2"
}}
};
op2
->
outputs_
=
{
"z"
};
op2
->
outputs_
=
{
{
"Out"
,
{
"z"
}}
};
net
->
AddOp
(
op2
);
net
->
AddOp
(
op2
);
net
->
CompleteAddOp
();
net
->
CompleteAddOp
();
AssertSameVectorWithoutOrder
({
"x"
,
"w1"
,
"b1"
,
"w2"
,
"b2"
},
net
->
inputs_
);
AssertSameVectorWithoutOrder
({
"x"
,
"w1"
,
"b1"
,
"w2"
,
"b2"
},
AssertSameVectorWithoutOrder
({
"y"
,
"z"
},
net
->
outputs_
);
net
->
inputs_
.
at
(
NetOp
::
kAll
));
auto
tmp_idx_iter
=
net
->
attrs_
.
find
(
"temporary_index"
);
AssertSameVectorWithoutOrder
({
"y"
,
"z"
},
net
->
outputs_
.
at
(
NetOp
::
kAll
));
ASSERT_NE
(
net
->
attrs_
.
end
(),
tmp_idx_iter
);
auto
&
tmp_idx
=
boost
::
get
<
std
::
vector
<
int
>>
(
tmp_idx_iter
->
second
);
ASSERT_EQ
(
1UL
,
tmp_idx
.
size
());
ASSERT_EQ
(
"y"
,
net
->
outputs_
[
tmp_idx
[
0
]]);
Scope
scope
;
auto
final_outs
=
net
->
OutputVars
(
false
);
platform
::
CPUDeviceContext
dev_ctx
;
net
->
InferShape
(
scope
);
ASSERT_EQ
(
final_outs
.
size
(),
1UL
);
net
->
Run
(
scope
,
dev_ctx
);
ASSERT_EQ
(
final_outs
[
0
],
"z"
);
ASSERT_EQ
(
2
,
infer_shape_cnt
);
ASSERT_EQ
(
2
,
run_cnt
);
ASSERT_THROW
(
net
->
AddOp
(
op2
),
platform
::
EnforceNotMet
);
}
}
TEST
(
NetOp
,
insert_op
)
{
TEST
(
NetOp
,
insert_op
)
{
NetOp
net
;
NetOp
net
;
auto
op1
=
std
::
make_shared
<
EmptyOp
>
();
auto
op1
=
std
::
make_shared
<
EmptyOp
>
();
op1
->
inputs_
=
{
"x"
,
"w1"
,
"b1"
};
op1
->
inputs_
=
{
{
"X"
,
{
"x"
}},
{
"W"
,
{
"w1"
}},
{
"b"
,
{
"b1"
}}
};
op1
->
outputs_
=
{
"y"
};
op1
->
outputs_
=
{
{
"Out"
,
{
"y"
}}
};
net
.
AddOp
(
op1
);
net
.
AddOp
(
op1
);
net
.
InsertOp
(
0
,
op1
);
net
.
InsertOp
(
0
,
op1
);
ASSERT_EQ
(
2UL
,
net
.
ops_
.
size
());
ASSERT_EQ
(
2UL
,
net
.
ops_
.
size
());
...
...
paddle/operators/recurrent_op.cc
浏览文件 @
81f5f861
...
@@ -91,12 +91,17 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
...
@@ -91,12 +91,17 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
// create step net's temp inputs
// create step net's temp inputs
for
(
auto
&
input
:
net_op
->
inputs_
)
{
for
(
auto
&
input
:
net_op
->
inputs_
)
{
// the weight are located in parent scope
// the weight are located in parent scope
if
(
!
step_scope
.
FindVar
(
input
))
for
(
auto
&
var_name
:
input
.
second
)
{
step_scope
.
NewVar
(
input
)
->
GetMutable
<
Tensor
>
();
if
(
!
step_scope
.
FindVar
(
var_name
))
{
step_scope
.
NewVar
(
var_name
)
->
GetMutable
<
Tensor
>
();
}
}
}
}
// create stepnet's outputs
// create stepnet's outputs
for
(
const
auto
&
output
:
net_op
->
outputs_
)
{
for
(
const
auto
&
output
:
net_op
->
outputs_
)
{
step_scope
.
NewVar
(
output
);
for
(
auto
&
var_name
:
output
.
second
)
{
step_scope
.
NewVar
(
var_name
);
}
}
}
step_scopes
->
emplace_back
(
&
step_scope
);
step_scopes
->
emplace_back
(
&
step_scope
);
}
}
...
@@ -147,13 +152,13 @@ class RecurrentAlgorithmProtoAndCheckerMaker
...
@@ -147,13 +152,13 @@ class RecurrentAlgorithmProtoAndCheckerMaker
// inputs and outputs stored in proto
// inputs and outputs stored in proto
AddInput
(
name
.
inlinks
,
AddInput
(
name
.
inlinks
,
"the inputs that need to be segmented for each step."
)
"the inputs that need to be segmented for each step."
)
.
SetMultip
le
();
.
AsDuplicab
le
();
AddInput
(
name
.
boot_memories
,
"variables to initialize memories."
)
AddInput
(
name
.
boot_memories
,
"variables to initialize memories."
)
.
SetMultip
le
();
.
AsDuplicab
le
();
AddInput
(
name
.
step_net
,
"network shared by all steps."
);
AddInput
(
name
.
step_net
,
"network shared by all steps."
);
AddOutput
(
name
.
outlinks
,
"the outputs that need to concated for all steps."
)
AddOutput
(
name
.
outlinks
,
"the outputs that need to concated for all steps."
)
.
SetMultip
le
();
.
AsDuplicab
le
();
AddOutput
(
name
.
step_scopes
,
"step scopes"
);
AddOutput
(
name
.
step_scopes
,
"step scopes"
);
// Attributes stored in AttributeMap
// Attributes stored in AttributeMap
...
...
paddle/operators/recurrent_op.h
浏览文件 @
81f5f861
...
@@ -100,8 +100,9 @@ class RecurrentGradientAlgorithm {
...
@@ -100,8 +100,9 @@ class RecurrentGradientAlgorithm {
};
};
class
RecurrentOp
final
:
public
framework
::
OperatorBase
{
class
RecurrentOp
final
:
public
framework
::
OperatorBase
{
DEFINE_OPERATOR_CTOR
(
RecurrentOp
,
framework
::
OperatorBase
)
public:
public:
DEFINE_OPERATOR_CTOR
(
RecurrentOp
,
framework
::
OperatorBase
);
void
Init
()
override
;
void
Init
()
override
;
/**
/**
...
@@ -124,6 +125,7 @@ class RecurrentOp final : public framework::OperatorBase {
...
@@ -124,6 +125,7 @@ class RecurrentOp final : public framework::OperatorBase {
class
RecurrentGradientOp
final
:
public
framework
::
OperatorBase
{
class
RecurrentGradientOp
final
:
public
framework
::
OperatorBase
{
public:
public:
DEFINE_OPERATOR_CTOR
(
RecurrentGradientOp
,
framework
::
OperatorBase
)
void
Init
()
override
;
void
Init
()
override
;
/**
/**
...
...
paddle/operators/recurrent_op_test.cc
浏览文件 @
81f5f861
...
@@ -25,157 +25,7 @@
...
@@ -25,157 +25,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
framework
::
make_ddim
;
using
namespace
paddle
::
framework
;
using
framework
::
DDim
;
using
framework
::
Tensor
;
using
framework
::
Variable
;
using
framework
::
Scope
;
using
framework
::
OpRegistry
;
class
RecurrentOpTest
:
public
::
testing
::
Test
{
protected:
virtual
void
SetUp
()
override
{
CreateGlobalVariables
();
CreateStepNet
();
CreateRNNOp
();
}
virtual
void
TearDown
()
override
{}
void
CreateGlobalVariables
()
{
// create input, and init content
LOG
(
INFO
)
<<
"create global variable x"
;
for
(
auto
inlink
:
std
::
vector
<
std
::
string
>
{
"x"
,
"x0"
,
"x1"
,
"h"
})
{
Variable
*
x
=
scope_
.
NewVar
(
inlink
);
DDim
dims
=
make_ddim
(
std
::
vector
<
int
>
{
10
/*sent size*/
,
20
/*batch size*/
,
30
/*input dim*/
});
x
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
dims
,
platform
::
CPUPlace
());
}
// create output alias just for test
for
(
auto
inlink
:
std
::
vector
<
std
::
string
>
{
"h@alias"
})
{
Variable
*
x
=
scope_
.
NewVar
(
inlink
);
DDim
dims
=
make_ddim
(
std
::
vector
<
int
>
{
20
/*batch size*/
,
30
/*input dim*/
});
x
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
dims
,
platform
::
CPUPlace
());
}
LOG
(
INFO
)
<<
"create global variable w"
;
Variable
*
w
=
scope_
.
NewVar
(
"rnn/w"
);
w
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
make_ddim
(
std
::
vector
<
int
>
{
30
,
30
}),
platform
::
CPUPlace
());
for
(
auto
boot
:
std
::
vector
<
std
::
string
>
{
"h_boot"
})
{
LOG
(
INFO
)
<<
"create global variable "
<<
boot
;
Variable
*
h_boot
=
scope_
.
NewVar
(
boot
);
h_boot
->
GetMutable
<
Tensor
>
()
->
mutable_data
<
float
>
(
make_ddim
(
std
::
vector
<
int
>
{
20
/*batch size*/
,
30
/*input dim*/
}),
platform
::
CPUPlace
());
}
LOG
(
INFO
)
<<
"create variable step_scopes"
;
scope_
.
NewVar
(
"step_scopes"
);
LOG
(
INFO
)
<<
"create variable h"
;
scope_
.
NewVar
(
"h"
);
}
void
CreateRNNOp
()
{
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"recurrent_op"
);
// inlinks 0
op_desc
.
add_inputs
(
"x"
);
op_desc
.
add_inputs
(
"x0"
);
op_desc
.
add_inputs
(
"x1"
);
// boot_memories 3
op_desc
.
add_inputs
(
"h_boot"
);
// step net 5
op_desc
.
add_inputs
(
"step_net"
);
// outlinks 6
op_desc
.
add_outputs
(
"h"
);
// step scopes 7
op_desc
.
add_outputs
(
"step_scopes"
);
auto
_input_format
=
std
::
vector
<
int
>
{
0
,
// in_link
3
,
// memories
4
// step_net
};
auto
input_format
=
op_desc
.
add_attrs
();
input_format
->
set_name
(
"input_format"
);
input_format
->
set_type
(
paddle
::
framework
::
AttrType
::
INTS
);
for
(
auto
i
:
_input_format
)
{
input_format
->
add_ints
(
i
);
}
auto
output_format
=
op_desc
.
add_attrs
();
output_format
->
set_name
(
"output_format"
);
output_format
->
set_type
(
paddle
::
framework
::
AttrType
::
INTS
);
for
(
auto
i
:
std
::
vector
<
int
>
{
0
,
1
,
2
})
{
output_format
->
add_ints
(
i
);
}
auto
inlink_alias
=
op_desc
.
add_attrs
();
inlink_alias
->
set_name
(
"inlink_alias"
);
inlink_alias
->
set_type
(
paddle
::
framework
::
AttrType
::
STRINGS
);
auto
outlink_alias
=
op_desc
.
add_attrs
();
outlink_alias
->
set_name
(
"outlink_alias"
);
outlink_alias
->
set_type
(
paddle
::
framework
::
AttrType
::
STRINGS
);
auto
pre_memories
=
op_desc
.
add_attrs
();
pre_memories
->
set_name
(
"pre_memories"
);
pre_memories
->
set_type
(
paddle
::
framework
::
AttrType
::
STRINGS
);
auto
memories
=
op_desc
.
add_attrs
();
memories
->
set_name
(
"memories"
);
memories
->
set_type
(
paddle
::
framework
::
AttrType
::
STRINGS
);
// create inlink_alias
for
(
const
auto
&
item
:
std
::
vector
<
std
::
string
>
{
"x@alias"
,
"x0@alias"
,
"x1@alias"
})
{
inlink_alias
->
add_strings
(
item
);
}
// pre memories
for
(
const
auto
&
item
:
std
::
vector
<
std
::
string
>
{
"rnn/h@pre"
})
{
pre_memories
->
add_strings
(
item
);
}
// memories
for
(
const
auto
&
item
:
std
::
vector
<
std
::
string
>
{
"rnn/h"
})
{
memories
->
add_strings
(
item
);
}
// output alias
for
(
const
auto
&
item
:
std
::
vector
<
std
::
string
>
{
"h@alias"
})
{
outlink_alias
->
add_strings
(
item
);
}
rnn_op_
=
OpRegistry
::
CreateOp
(
op_desc
);
LOG
(
INFO
)
<<
"rnn_op finish init"
;
}
void
CreateStepNet
()
{
LOG
(
INFO
)
<<
"create variable step_net"
;
Variable
*
var
=
scope_
.
NewVar
(
"step_net"
);
auto
net
=
var
->
GetMutable
<
NetOp
>
();
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{
"rnn/h@pre"
,
"rnn/w"
},
{
"rnn/s"
},
{}));
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"add_two"
,
{
"x@alias"
,
"rnn/s"
},
{
"rnn/h"
},
{}));
net
->
CompleteAddOp
();
}
// father scope
Scope
scope_
;
std
::
shared_ptr
<
framework
::
OperatorBase
>
rnn_op_
;
};
TEST_F
(
RecurrentOpTest
,
Run
)
{
platform
::
CPUDeviceContext
ctx
;
rnn_op_
->
InferShape
(
scope_
);
rnn_op_
->
Run
(
scope_
,
ctx
);
}
class
RecurrentGradientAlgorithmTest
:
public
::
testing
::
Test
{
class
RecurrentGradientAlgorithmTest
:
public
::
testing
::
Test
{
protected:
protected:
...
@@ -281,11 +131,13 @@ class RecurrentGradientAlgorithmTest : public ::testing::Test {
...
@@ -281,11 +131,13 @@ class RecurrentGradientAlgorithmTest : public ::testing::Test {
LOG
(
INFO
)
<<
"create variable step_net"
;
LOG
(
INFO
)
<<
"create variable step_net"
;
Variable
*
var
=
scope_
.
NewVar
(
"step_net"
);
Variable
*
var
=
scope_
.
NewVar
(
"step_net"
);
auto
net
=
var
->
GetMutable
<
NetOp
>
();
auto
net
=
var
->
GetMutable
<
NetOp
>
();
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{
"rnn/h_pre"
,
"rnn/w"
,
"rnn/s_grad"
},
// TODO(qingqing) modify backward op create for RNNOp unit test
{
"rnn/h_pre_grad"
,
"rnn/w_grad"
},
{}));
// and the unit test will be removed to Python.
// net->AddOp(OpRegistry::CreateOp("mul", {"X", {"rnn/h_pre", "rnn/w",
// "rnn/s_grad"}}, {"Y", {"rnn/h_pre_grad", "rnn/w_grad"}}, {}));
net
->
AddOp
(
OpRegistry
::
CreateOp
(
"add_two"
,
{
"rnn/h_grad"
},
// net->AddOp(OpRegistry::CreateOp("add_two", {"X", {"rnn/h_grad"}
},
{
"rnn/x_grad"
,
"rnn/s_grad"
},
{}));
// {"Y", {"rnn/x_grad"}}, {"Out", "rnn/s_grad"}
}, {}));
net
->
CompleteAddOp
();
net
->
CompleteAddOp
();
}
}
...
@@ -359,7 +211,8 @@ TEST(RecurrentOp, LinkMemories) {
...
@@ -359,7 +211,8 @@ TEST(RecurrentOp, LinkMemories) {
memories
.
push_back
(
mem_attr
);
memories
.
push_back
(
mem_attr
);
for
(
size_t
i
=
1
;
i
<
len
;
++
i
)
{
for
(
size_t
i
=
1
;
i
<
len
;
++
i
)
{
rnn
::
LinkMemories
(
step_scopes
,
memories
,
i
,
-
1
,
false
/*infer_shape_mode*/
);
rnn
::
LinkMemories
(
step_scopes
,
memories
,
i
,
-
1
,
false
/*infer_shape_mode*/
);
}
}
// check
// check
for
(
size_t
i
=
0
;
i
<
len
-
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
len
-
1
;
++
i
)
{
...
@@ -375,7 +228,8 @@ TEST(RecurrentOp, LinkMemories) {
...
@@ -375,7 +228,8 @@ TEST(RecurrentOp, LinkMemories) {
}
}
for
(
int
i
=
len
-
2
;
i
>=
0
;
--
i
)
{
for
(
int
i
=
len
-
2
;
i
>=
0
;
--
i
)
{
rnn
::
LinkMemories
(
step_scopes
,
memories
,
i
,
1
,
false
/*infer_shape_mode*/
);
rnn
::
LinkMemories
(
step_scopes
,
memories
,
i
,
1
,
false
/*infer_shape_mode*/
);
}
}
// check
// check
for
(
int
i
=
len
-
2
;
i
>=
0
;
--
i
)
{
for
(
int
i
=
len
-
2
;
i
>=
0
;
--
i
)
{
...
...
paddle/operators/rowwise_add_op.cc
浏览文件 @
81f5f861
...
@@ -21,16 +21,14 @@ class RowWiseAddOp : public framework::OperatorWithKernel {
...
@@ -21,16 +21,14 @@ class RowWiseAddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR
(
RowWiseAddOp
,
framework
::
OperatorWithKernel
)
DEFINE_OPERATOR_CTOR
(
RowWiseAddOp
,
framework
::
OperatorWithKernel
)
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
2UL
,
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
"Two inputs is needed by rowwise add"
);
auto
dim1
=
ctx
.
Input
<
Tensor
>
(
"b"
)
->
dims
();
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
();
auto
dim1
=
ctx
.
Input
<
Tensor
>
(
1
)
->
dims
();
PADDLE_ENFORCE
(
dim0
.
size
()
==
2
,
"Input 0 must be matrix"
);
PADDLE_ENFORCE
(
dim0
.
size
()
==
2
,
"Input 0 must be matrix"
);
PADDLE_ENFORCE
(
dim1
.
size
()
==
1
,
"The second input must be vector"
);
PADDLE_ENFORCE
(
dim1
.
size
()
==
1
,
"The second input must be vector"
);
PADDLE_ENFORCE
(
dim0
[
1
]
==
dim1
[
0
],
"The width of two input must be same"
);
PADDLE_ENFORCE
(
dim0
[
1
]
==
dim1
[
0
],
"The width of two input must be same"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"The output size must be 1"
);
PADDLE_ENFORCE
(
ctx
.
OutputSize
(
"Out"
)
==
1
,
"The output size must be 1"
);
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
());
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
());
}
}
};
};
...
...
paddle/operators/rowwise_add_op.h
浏览文件 @
81f5f861
...
@@ -31,11 +31,11 @@ template <typename Place, typename T>
...
@@ -31,11 +31,11 @@ template <typename Place, typename T>
class
RowWiseAddKernel
:
public
framework
::
OpKernel
{
class
RowWiseAddKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
out
=
context
.
Output
<
Tensor
>
(
0
);
auto
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
input
=
EigenMatrix
<
T
>::
From
(
*
context
.
Input
<
Tensor
>
(
0
));
auto
input
=
EigenMatrix
<
T
>::
From
(
*
context
.
Input
<
Tensor
>
(
"X"
));
auto
bias
=
EigenVector
<
T
>::
From
(
*
context
.
Input
<
Tensor
>
(
1
));
auto
bias
=
EigenVector
<
T
>::
From
(
*
context
.
Input
<
Tensor
>
(
"b"
));
auto
output
=
EigenMatrix
<
T
>::
From
(
*
out
);
auto
output
=
EigenMatrix
<
T
>::
From
(
*
out
);
const
int
bias_size
=
bias
.
dimension
(
0
);
const
int
bias_size
=
bias
.
dimension
(
0
);
...
...
paddle/operators/sgd_op.cc
浏览文件 @
81f5f861
...
@@ -21,14 +21,10 @@ class SGDOp : public framework::OperatorWithKernel {
...
@@ -21,14 +21,10 @@ class SGDOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR
(
SGDOp
,
framework
::
OperatorWithKernel
)
DEFINE_OPERATOR_CTOR
(
SGDOp
,
framework
::
OperatorWithKernel
)
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
.
InputSize
(),
2
,
"Input size of SGDOp must be two"
);
PADDLE_ENFORCE
(
PADDLE_ENFORCE_EQ
(
ctx
.
OutputSize
(),
1
,
"Output size of SGDOp must be one"
);
ctx
.
Input
<
Tensor
>
(
"param"
)
->
dims
()
==
ctx
.
Input
<
Tensor
>
(
"grad"
)
->
dims
(),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
0
),
"inputs[0] mast be set"
);
"Two input of SGD Op's dimension must be same."
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
1
),
"inputs[1] mast be set"
);
ctx
.
Output
<
Tensor
>
(
"param_out"
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"param"
)
->
dims
());
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
OutputVar
(
0
),
"outputs[0] mast be set"
);
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
()
==
ctx
.
Input
<
Tensor
>
(
1
)
->
dims
(),
"Two input of SGD Op's dimension must be same."
);
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
());
}
}
};
};
...
...
paddle/operators/sigmoid_op.cc
浏览文件 @
81f5f861
...
@@ -21,9 +21,7 @@ class SigmoidOp : public framework::OperatorWithKernel {
...
@@ -21,9 +21,7 @@ class SigmoidOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR
(
SigmoidOp
,
framework
::
OperatorWithKernel
)
DEFINE_OPERATOR_CTOR
(
SigmoidOp
,
framework
::
OperatorWithKernel
)
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
.
InputSize
()
==
1
,
"Sigmoid Op only have one input"
);
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
());
PADDLE_ENFORCE
(
ctx
.
OutputSize
()
==
1
,
"Sigmoid Op only have one output"
);
ctx
.
Output
<
Tensor
>
(
0
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
());
}
}
};
};
...
...
paddle/operators/sigmoid_op.h
浏览文件 @
81f5f861
...
@@ -28,8 +28,8 @@ template <typename Place, typename T>
...
@@ -28,8 +28,8 @@ template <typename Place, typename T>
class
SigmoidKernel
:
public
framework
::
OpKernel
{
class
SigmoidKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
input
=
context
.
Input
<
Tensor
>
(
0
);
auto
input
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
output
=
context
.
Output
<
Tensor
>
(
0
);
auto
output
=
context
.
Output
<
Tensor
>
(
"Y"
);
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// The clipping is used in Paddle's raw implenmention
// The clipping is used in Paddle's raw implenmention
...
...
paddle/operators/softmax_op.cc
浏览文件 @
81f5f861
...
@@ -21,12 +21,8 @@ class SoftmaxOp : public framework::OperatorWithKernel {
...
@@ -21,12 +21,8 @@ class SoftmaxOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR
(
SoftmaxOp
,
framework
::
OperatorWithKernel
)
DEFINE_OPERATOR_CTOR
(
SoftmaxOp
,
framework
::
OperatorWithKernel
)
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
.
InputSize
(),
1UL
,
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
()
==
2UL
,
"Only one input is need for softmax"
);
"The input of softmax op must be matrix"
);
PADDLE_ENFORCE_EQ
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
(),
2UL
,
"The input of softmax op must be matrix"
);
PADDLE_ENFORCE_EQ
(
ctx
.
OutputSize
(),
1UL
,
"Only one output is need for softmax"
);
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
());
ctx
.
Output
<
Tensor
>
(
"Y"
)
->
Resize
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
());
}
}
};
};
...
@@ -46,11 +42,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
...
@@ -46,11 +42,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR
(
SoftmaxOpGrad
,
framework
::
OperatorWithKernel
)
DEFINE_OPERATOR_CTOR
(
SoftmaxOpGrad
,
framework
::
OperatorWithKernel
)
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
.
InputSize
(),
3UL
,
PADDLE_ENFORCE
(
ctx
.
InputVar
(
"Y"
)
!=
nullptr
,
"Input(Y) should not be null"
);
"Input of SoftmaxOpGrad should be 3, X, Y, YG"
);
PADDLE_ENFORCE_EQ
(
ctx
.
OutputSize
(),
1UL
,
"Output of SoftmaxOpGrad should be 1"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
)),
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
)),
"Input(Y@GRAD) should not be null"
);
"Input(Y@GRAD) should not be null"
);
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
dims
()
==
PADDLE_ENFORCE
(
ctx
.
Input
<
Tensor
>
(
"Y"
)
->
dims
()
==
...
...
paddle/operators/uniform_random_op.cc
浏览文件 @
81f5f861
...
@@ -27,7 +27,7 @@ template <typename T>
...
@@ -27,7 +27,7 @@ template <typename T>
class
CPUUniformRandomKernel
:
public
framework
::
OpKernel
{
class
CPUUniformRandomKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
op_
.
GetAttr
<
int
>
(
"seed"
));
static_cast
<
unsigned
int
>
(
context
.
op_
.
GetAttr
<
int
>
(
"seed"
));
...
@@ -51,7 +51,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
...
@@ -51,7 +51,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
GetAttr
<
float
>
(
"min"
)
<
GetAttr
<
float
>
(
"max"
),
PADDLE_ENFORCE
(
GetAttr
<
float
>
(
"min"
)
<
GetAttr
<
float
>
(
"max"
),
"uniform_random's min must less then max"
);
"uniform_random's min must less then max"
);
auto
*
tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
0
);
auto
*
tensor
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
dims
=
GetAttr
<
std
::
vector
<
int
>>
(
"dims"
);
auto
dims
=
GetAttr
<
std
::
vector
<
int
>>
(
"dims"
);
tensor
->
Resize
(
framework
::
make_ddim
(
dims
));
tensor
->
Resize
(
framework
::
make_ddim
(
dims
));
}
}
...
...
paddle/operators/uniform_random_op.cu
浏览文件 @
81f5f861
...
@@ -46,7 +46,7 @@ template <typename T>
...
@@ -46,7 +46,7 @@ template <typename T>
class
GPUUniformRandomKernel
:
public
framework
::
OpKernel
{
class
GPUUniformRandomKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
0
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
op_
.
GetAttr
<
int
>
(
"seed"
));
static_cast
<
unsigned
int
>
(
context
.
op_
.
GetAttr
<
int
>
(
"seed"
));
...
...
python/paddle/v2/framework/op.py
浏览文件 @
81f5f861
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.proto.op_proto_pb2
as
op_proto_pb2
import
paddle.v2.framework.proto.framework_pb2
as
framework_pb2
import
paddle.v2.framework.proto.op_desc_pb2
as
op_desc_pb2
import
paddle.v2.framework.proto.attribute_pb2
as
attribute_pb2
def
get_all_op_protos
():
def
get_all_op_protos
():
...
@@ -12,11 +10,15 @@ def get_all_op_protos():
...
@@ -12,11 +10,15 @@ def get_all_op_protos():
protostrs
=
core
.
get_all_op_protos
()
protostrs
=
core
.
get_all_op_protos
()
ret_values
=
[]
ret_values
=
[]
for
pbstr
in
protostrs
:
for
pbstr
in
protostrs
:
op_proto
=
op_proto
_pb2
.
OpProto
.
FromString
(
str
(
pbstr
))
op_proto
=
framework
_pb2
.
OpProto
.
FromString
(
str
(
pbstr
))
ret_values
.
append
(
op_proto
)
ret_values
.
append
(
op_proto
)
return
ret_values
return
ret_values
def
is_str
(
s
):
return
isinstance
(
s
,
str
)
or
isinstance
(
s
,
unicode
)
class
OpDescCreationMethod
(
object
):
class
OpDescCreationMethod
(
object
):
"""
"""
A Functor object to convert user input(use key word args) to OpDesc based on
A Functor object to convert user input(use key word args) to OpDesc based on
...
@@ -27,7 +29,7 @@ class OpDescCreationMethod(object):
...
@@ -27,7 +29,7 @@ class OpDescCreationMethod(object):
"""
"""
def
__init__
(
self
,
op_proto
):
def
__init__
(
self
,
op_proto
):
if
not
isinstance
(
op_proto
,
op_proto
_pb2
.
OpProto
):
if
not
isinstance
(
op_proto
,
framework
_pb2
.
OpProto
):
raise
TypeError
(
"Argument should be OpProto"
)
raise
TypeError
(
"Argument should be OpProto"
)
self
.
__op_proto__
=
op_proto
self
.
__op_proto__
=
op_proto
...
@@ -39,26 +41,34 @@ class OpDescCreationMethod(object):
...
@@ -39,26 +41,34 @@ class OpDescCreationMethod(object):
"""
"""
if
len
(
args
)
!=
0
:
if
len
(
args
)
!=
0
:
raise
ValueError
(
"Only keyword arguments is supported by Paddle"
)
raise
ValueError
(
"Only keyword arguments is supported by Paddle"
)
op_desc
=
op_desc_pb2
.
OpDesc
()
op_desc
=
framework_pb2
.
OpDesc
()
# Inputs
for
input_parameter
in
self
.
__op_proto__
.
inputs
:
ipts
,
ipt_format
,
_
=
OpDescCreationMethod
.
extract_input_or_output
(
input_arguments
=
kwargs
.
get
(
input_parameter
.
name
,
[])
"input"
,
kwargs
,
self
.
__op_proto__
.
inputs
)
if
is_str
(
input_arguments
):
op_desc
.
inputs
.
extend
(
ipts
)
input_arguments
=
[
input_arguments
]
if
ipt_format
is
not
None
:
op_desc
.
attrs
.
extend
([
ipt_format
])
if
not
input_parameter
.
duplicable
and
len
(
input_arguments
)
>
1
:
raise
ValueError
(
"Input %s only accepts one input, but give %d"
# Outputs
%
(
input_parameter
.
name
,
len
(
input_arguments
)))
outs
,
out_format
,
tmp_index
=
OpDescCreationMethod
.
extract_input_or_output
(
"output"
,
kwargs
,
self
.
__op_proto__
.
outputs
)
ipt
=
op_desc
.
inputs
.
add
()
op_desc
.
outputs
.
extend
(
outs
)
ipt
.
parameter
=
input_parameter
.
name
if
out_format
is
not
None
:
ipt
.
arguments
.
extend
(
input_arguments
)
op_desc
.
attrs
.
extend
([
out_format
])
if
len
(
tmp_index
)
!=
0
:
for
output_parameter
in
self
.
__op_proto__
.
outputs
:
tmp_index_attr
=
op_desc
.
attrs
.
add
()
output_arguments
=
kwargs
.
get
(
output_parameter
.
name
,
[])
tmp_index_attr
.
type
=
attribute_pb2
.
INTS
if
is_str
(
output_arguments
):
tmp_index_attr
.
name
=
"temporary_index"
output_arguments
=
[
output_arguments
]
tmp_index_attr
.
ints
.
extend
(
tmp_index
)
if
not
output_parameter
.
duplicable
and
len
(
output_arguments
)
>
1
:
raise
ValueError
(
"Output %s only accepts one output, but give %d"
%
(
output_parameter
.
name
,
len
(
output_arguments
)))
out
=
op_desc
.
outputs
.
add
()
out
.
parameter
=
output_parameter
.
name
out
.
arguments
.
extend
(
output_arguments
)
# Types
# Types
op_desc
.
type
=
self
.
__op_proto__
.
type
op_desc
.
type
=
self
.
__op_proto__
.
type
...
@@ -72,17 +82,17 @@ class OpDescCreationMethod(object):
...
@@ -72,17 +82,17 @@ class OpDescCreationMethod(object):
new_attr
=
op_desc
.
attrs
.
add
()
new_attr
=
op_desc
.
attrs
.
add
()
new_attr
.
name
=
attr
.
name
new_attr
.
name
=
attr
.
name
new_attr
.
type
=
attr
.
type
new_attr
.
type
=
attr
.
type
if
attr
.
type
==
attribute
_pb2
.
INT
:
if
attr
.
type
==
framework
_pb2
.
INT
:
new_attr
.
i
=
user_defined_attr
new_attr
.
i
=
user_defined_attr
elif
attr
.
type
==
attribute
_pb2
.
FLOAT
:
elif
attr
.
type
==
framework
_pb2
.
FLOAT
:
new_attr
.
f
=
user_defined_attr
new_attr
.
f
=
user_defined_attr
elif
attr
.
type
==
attribute
_pb2
.
STRING
:
elif
attr
.
type
==
framework
_pb2
.
STRING
:
new_attr
.
s
=
user_defined_attr
new_attr
.
s
=
user_defined_attr
elif
attr
.
type
==
attribute
_pb2
.
INTS
:
elif
attr
.
type
==
framework
_pb2
.
INTS
:
new_attr
.
ints
.
extend
(
user_defined_attr
)
new_attr
.
ints
.
extend
(
user_defined_attr
)
elif
attr
.
type
==
attribute
_pb2
.
FLOATS
:
elif
attr
.
type
==
framework
_pb2
.
FLOATS
:
new_attr
.
floats
.
extend
(
user_defined_attr
)
new_attr
.
floats
.
extend
(
user_defined_attr
)
elif
attr
.
type
==
attribute
_pb2
.
STRINGS
:
elif
attr
.
type
==
framework
_pb2
.
STRINGS
:
new_attr
.
strings
.
extend
(
user_defined_attr
)
new_attr
.
strings
.
extend
(
user_defined_attr
)
else
:
else
:
raise
NotImplementedError
(
"Not support attribute type "
+
raise
NotImplementedError
(
"Not support attribute type "
+
...
@@ -90,50 +100,6 @@ class OpDescCreationMethod(object):
...
@@ -90,50 +100,6 @@ class OpDescCreationMethod(object):
return
op_desc
return
op_desc
@
staticmethod
def
extract_input_or_output
(
in_out
,
kwargs
,
meta
):
"""
Extract input variable names or output variable names from key-word
arguments, which base on VarProtos.
:param in_out: "input" or "output"
:param kwargs: key-word arguments that user inputted.
:param meta: a list of VarProto
:return: The three object will be return. The variable names. The
input_format or output_format attribute(None if the input or output is
not multiple). The temporary variable index list.
"""
multiple
=
OpDescCreationMethod
.
any_is_true
((
m
.
multiple
for
m
in
meta
))
tmp_index
=
[]
retv
=
[]
if
multiple
:
var_format
=
op_desc_pb2
.
AttrDesc
()
var_format
.
type
=
attribute_pb2
.
INTS
var_format
.
name
=
"%s_format"
%
in_out
var_format
.
ints
.
append
(
0
)
for
var
in
meta
:
var_name
=
var
.
name
if
var
.
temporary
:
var_name
=
[
core
.
var_names
.
temp
()]
tmp_index
.
append
(
len
(
retv
))
else
:
var_name
=
kwargs
.
get
(
var_name
,
[])
if
not
isinstance
(
var_name
,
list
):
var_name
=
[
var_name
]
retv
.
extend
(
var_name
)
var_format
.
ints
.
append
(
len
(
var_name
)
+
var_format
.
ints
[
-
1
])
return
retv
,
var_format
,
tmp_index
else
:
for
var
in
meta
:
if
var
.
temporary
:
retv
.
append
(
kwargs
.
get
(
var
.
name
,
core
.
var_names
.
temp
()))
tmp_index
.
append
(
len
(
retv
))
else
:
retv
.
append
(
kwargs
.
get
(
var
.
name
,
core
.
var_names
.
empty
()))
return
retv
,
None
,
tmp_index
@
staticmethod
@
staticmethod
def
any_is_true
(
generator
):
def
any_is_true
(
generator
):
"""
"""
...
@@ -146,13 +112,12 @@ class OpDescCreationMethod(object):
...
@@ -146,13 +112,12 @@ class OpDescCreationMethod(object):
class
OpInfo
(
object
):
class
OpInfo
(
object
):
def
__init__
(
self
,
name
,
method
,
inputs
,
outputs
,
attrs
,
no_temp_outputs
):
def
__init__
(
self
,
name
,
method
,
inputs
,
outputs
,
attrs
):
self
.
name
=
name
self
.
name
=
name
self
.
method
=
method
self
.
method
=
method
self
.
inputs
=
inputs
self
.
inputs
=
inputs
self
.
outputs
=
outputs
self
.
outputs
=
outputs
self
.
attrs
=
attrs
self
.
attrs
=
attrs
self
.
no_temp_outputs
=
no_temp_outputs
def
create_op_creation_method
(
op_proto
):
def
create_op_creation_method
(
op_proto
):
...
@@ -170,10 +135,7 @@ def create_op_creation_method(op_proto):
...
@@ -170,10 +135,7 @@ def create_op_creation_method(op_proto):
name
=
op_proto
.
type
,
name
=
op_proto
.
type
,
inputs
=
[
var
.
name
for
var
in
op_proto
.
inputs
],
inputs
=
[
var
.
name
for
var
in
op_proto
.
inputs
],
outputs
=
[
var
.
name
for
var
in
op_proto
.
outputs
],
outputs
=
[
var
.
name
for
var
in
op_proto
.
outputs
],
attrs
=
[
attr
.
name
for
attr
in
op_proto
.
attrs
],
attrs
=
[
attr
.
name
for
attr
in
op_proto
.
attrs
])
no_temp_outputs
=
[
var
.
name
for
var
in
op_proto
.
outputs
if
not
var
.
temporary
])
class
OperatorFactory
(
object
):
class
OperatorFactory
(
object
):
...
@@ -214,8 +176,5 @@ class OperatorFactory(object):
...
@@ -214,8 +176,5 @@ class OperatorFactory(object):
def
get_op_attr_names
(
self
,
type
):
def
get_op_attr_names
(
self
,
type
):
return
self
.
get_op_info
(
type
).
attrs
return
self
.
get_op_info
(
type
).
attrs
def
get_op_no_temp_output_names
(
self
,
type
):
return
self
.
get_op_info
(
type
).
no_temp_outputs
Operator
=
OperatorFactory
()
# Default global factory
Operator
=
OperatorFactory
()
# Default global factory
python/paddle/v2/framework/tests/gradient_checker.py
浏览文件 @
81f5f861
...
@@ -53,15 +53,18 @@ def get_numeric_gradient(op,
...
@@ -53,15 +53,18 @@ def get_numeric_gradient(op,
tensor
.
set
(
input_values
[
var_name
],
core
.
CPUPlace
())
tensor
.
set
(
input_values
[
var_name
],
core
.
CPUPlace
())
# Create all output variable in local_scope
# Create all output variable in local_scope
for
output
in
op
.
outputs
():
opts
=
op
.
outputs
()
if
local_scope
.
find_var
(
output
)
is
None
:
for
key
in
opts
:
local_scope
.
new_var
(
output
).
get_tensor
()
for
output
in
opts
[
key
]:
if
local_scope
.
find_var
(
output
)
is
None
:
local_scope
.
new_var
(
output
).
get_tensor
()
op
.
infer_shape
(
local_scope
)
op
.
infer_shape
(
local_scope
)
# allocate output memory
# allocate output memory
for
output
in
op
.
outputs
():
for
key
in
opts
:
local_scope
.
find_var
(
output
).
get_tensor
().
alloc_float
(
core
.
CPUPlace
())
for
output
in
opts
[
key
]:
local_scope
.
find_var
(
output
).
get_tensor
().
alloc_float
(
core
.
CPUPlace
(
))
# TODO(yuyang18): Only CPU is support now.
# TODO(yuyang18): Only CPU is support now.
cpu_ctx
=
core
.
DeviceContext
.
create
(
core
.
CPUPlace
())
cpu_ctx
=
core
.
DeviceContext
.
create
(
core
.
CPUPlace
())
...
@@ -150,19 +153,24 @@ class GradientChecker(unittest.TestCase):
...
@@ -150,19 +153,24 @@ class GradientChecker(unittest.TestCase):
if
no_grad_set
is
None
:
if
no_grad_set
is
None
:
no_grad_set
=
set
()
no_grad_set
=
set
()
tmp_outs
=
forward_op
.
temp_outputs
()
no_tmp_out
=
forward_op
.
no_intermediate_outputs
()
no_tmp_out
=
filter
(
lambda
name
:
name
not
in
tmp_outs
,
forward_op
.
outputs
())
if
len
(
no_tmp_out
)
!=
1
:
if
len
(
no_tmp_out
)
!=
1
:
raise
ValueError
(
"non temp out_names should be 1"
)
raise
ValueError
(
"non temp out_names should be 1"
)
in_names
=
forward_op
.
inputs
()
inputs
=
forward_op
.
inputs
()
in_names
=
[
item
for
k
in
inputs
for
item
in
inputs
[
k
]]
outputs
=
forward_op
.
outputs
()
out_names
=
[
item
for
k
in
outputs
for
item
in
outputs
[
k
]]
for
no_grad
in
no_grad_set
:
for
no_grad
in
no_grad_set
:
if
no_grad
not
in
in_names
:
if
no_grad
not
in
in_names
:
raise
ValueError
(
"no_grad should be in in_names"
)
raise
ValueError
(
"no_grad should be in in_names"
)
backward_op
=
core
.
Operator
.
backward
(
forward_op
,
no_grad_set
)
backward_op
=
core
.
Operator
.
backward
(
forward_op
,
no_grad_set
)
bwd_outputs
=
backward_op
.
outputs
()
bwd_out_names
=
[
item
for
k
in
bwd_outputs
for
item
in
bwd_outputs
[
k
]]
places
=
[
core
.
CPUPlace
()]
places
=
[
core
.
CPUPlace
()]
if
not
only_cpu
and
core
.
is_compile_gpu
()
and
backward_op
.
support_gpu
():
if
not
only_cpu
and
core
.
is_compile_gpu
()
and
backward_op
.
support_gpu
():
places
.
append
(
core
.
GPUPlace
(
0
))
places
.
append
(
core
.
GPUPlace
(
0
))
...
@@ -188,7 +196,7 @@ class GradientChecker(unittest.TestCase):
...
@@ -188,7 +196,7 @@ class GradientChecker(unittest.TestCase):
var
.
set
(
value
,
place
)
var
.
set
(
value
,
place
)
# create output var
# create output var
for
out_name
in
forward_op
.
outputs
()
:
for
out_name
in
out_names
:
scope
.
new_var
(
out_name
).
get_tensor
()
scope
.
new_var
(
out_name
).
get_tensor
()
# infer the shape of output var and compute/set value of output var
# infer the shape of output var and compute/set value of output var
...
@@ -198,7 +206,7 @@ class GradientChecker(unittest.TestCase):
...
@@ -198,7 +206,7 @@ class GradientChecker(unittest.TestCase):
# create output grad var
# create output grad var
# set shape as the output var
# set shape as the output var
# set value of this grad to ones
# set value of this grad to ones
for
name
in
forward_op
.
outputs
()
:
for
name
in
out_names
:
out_tensor
=
scope
.
find_var
(
name
).
get_tensor
()
out_tensor
=
scope
.
find_var
(
name
).
get_tensor
()
grad_tensor
=
scope
.
new_var
(
grad_var_name
(
name
)).
get_tensor
()
grad_tensor
=
scope
.
new_var
(
grad_var_name
(
name
)).
get_tensor
()
grad_tensor
.
set_dims
(
out_tensor
.
shape
())
grad_tensor
.
set_dims
(
out_tensor
.
shape
())
...
@@ -206,7 +214,7 @@ class GradientChecker(unittest.TestCase):
...
@@ -206,7 +214,7 @@ class GradientChecker(unittest.TestCase):
grad_tensor
.
set
(
data
,
place
)
grad_tensor
.
set
(
data
,
place
)
# create input grad var
# create input grad var
for
name
in
b
ackward_op
.
outputs
()
:
for
name
in
b
wd_out_names
:
scope
.
new_var
(
name
).
get_tensor
()
scope
.
new_var
(
name
).
get_tensor
()
# infer the shape of input gradient var and compute/set it's value
# infer the shape of input gradient var and compute/set it's value
...
...
python/paddle/v2/framework/tests/test_add_two_op.py
浏览文件 @
81f5f861
...
@@ -19,14 +19,5 @@ class TestAddOp(unittest.TestCase):
...
@@ -19,14 +19,5 @@ class TestAddOp(unittest.TestCase):
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
+
self
.
inputs
[
'Y'
]}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
+
self
.
inputs
[
'Y'
]}
class
TestAddGradOp
(
unittest
.
TestCase
):
def
test_add_grad
(
self
):
op
=
Operator
(
'add_two'
,
X
=
"X"
,
Y
=
"Y"
,
Out
=
"Out"
)
backward_op
=
core
.
Operator
.
backward
(
op
,
set
())
self
.
assertEqual
(
backward_op
.
type
(),
"add_two_grad"
)
expected
=
'''Op(add_two_grad), inputs:(X, Y, Out, Out@GRAD), outputs:(X@GRAD, Y@GRAD).'''
self
.
assertEqual
(
expected
,
str
(
backward_op
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
python/paddle/v2/framework/tests/test_net.py
浏览文件 @
81f5f861
...
@@ -25,12 +25,12 @@ class TestNet(unittest.TestCase):
...
@@ -25,12 +25,12 @@ class TestNet(unittest.TestCase):
net
.
complete_add_op
(
True
)
net
.
complete_add_op
(
True
)
expected
=
'''
expected
=
'''
Op(plain_net), inputs:
(W, X, Y), outputs:(Out, fc.out, pre_activation)
.
Op(plain_net), inputs:
{all[W, X, Y]}, outputs:{all[Out, fc.out, pre_activation]}
.
Op(add_two), inputs:
(X, Y), outputs:(Out)
.
Op(add_two), inputs:
{X[X], Y[Y]}, outputs:{Out[Out]}
.
Op(plain_net), inputs:
(W, X), outputs:(fc.out, pre_activation)
.
Op(plain_net), inputs:
{all[W, X]}, outputs:{all[fc.out, pre_activation]}
.
Op(plain_net), inputs:
(W, X), outputs:(fc.out, pre_activation)
.
Op(plain_net), inputs:
{all[W, X]}, outputs:{all[fc.out, pre_activation]}
.
Op(mul), inputs:
(X, W), outputs:(pre_activation)
.
Op(mul), inputs:
{X[X], Y[W]}, outputs:{Out[pre_activation]}
.
Op(sigmoid), inputs:
(pre_activation), outputs:(fc.out)
.
Op(sigmoid), inputs:
{X[pre_activation]}, outputs:{Y[fc.out]}
.
'''
'''
self
.
assertEqual
(
expected
,
"
\n
"
+
str
(
net
))
self
.
assertEqual
(
expected
,
"
\n
"
+
str
(
net
))
...
...
python/paddle/v2/framework/tests/test_operator.py
浏览文件 @
81f5f861
import
unittest
import
unittest
import
paddle.v2.framework.op
as
op
import
paddle.v2.framework.op
as
op
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.proto.op_proto_pb2
as
op_proto_pb2
import
paddle.v2.framework.proto.framework_pb2
as
framework_pb2
import
paddle.v2.framework.proto.op_desc_pb2
as
op_desc_pb2
import
paddle.v2.framework.proto.attribute_pb2
as
attribute_pb2
class
TestGetAllProtos
(
unittest
.
TestCase
):
class
TestGetAllProtos
(
unittest
.
TestCase
):
...
@@ -17,7 +15,7 @@ class TestGetAllProtos(unittest.TestCase):
...
@@ -17,7 +15,7 @@ class TestGetAllProtos(unittest.TestCase):
class
TestOpDescCreationMethod
(
unittest
.
TestCase
):
class
TestOpDescCreationMethod
(
unittest
.
TestCase
):
def
test_plain_input_output
(
self
):
def
test_plain_input_output
(
self
):
op_proto
=
op_proto
_pb2
.
OpProto
()
op_proto
=
framework
_pb2
.
OpProto
()
op_proto
.
type
=
"test"
op_proto
.
type
=
"test"
ipt
=
op_proto
.
inputs
.
add
()
ipt
=
op_proto
.
inputs
.
add
()
ipt
.
name
=
"X"
ipt
.
name
=
"X"
...
@@ -37,25 +35,32 @@ class TestOpDescCreationMethod(unittest.TestCase):
...
@@ -37,25 +35,32 @@ class TestOpDescCreationMethod(unittest.TestCase):
method
=
op
.
OpDescCreationMethod
(
op_proto
)
method
=
op
.
OpDescCreationMethod
(
op_proto
)
output
=
method
(
X
=
"a"
,
Y
=
"b"
,
Z
=
"c"
)
output
=
method
(
X
=
"a"
,
Y
=
"b"
,
Z
=
"c"
)
expected
=
framework_pb2
.
OpDesc
()
expected
=
op_desc_pb2
.
OpDesc
()
expected
.
type
=
"test"
expected
.
type
=
"test"
expected
.
inputs
.
extend
([
"a"
,
"b"
])
ipt_0
=
expected
.
inputs
.
add
()
expected
.
outputs
.
append
(
"c"
)
ipt_0
.
parameter
=
"X"
ipt_0
.
arguments
.
extend
([
"a"
])
ipt_1
=
expected
.
inputs
.
add
()
ipt_1
.
parameter
=
'Y'
ipt_1
.
arguments
.
extend
([
'b'
])
opt
=
expected
.
outputs
.
add
()
opt
.
parameter
=
"Z"
opt
.
arguments
.
extend
([
"c"
])
self
.
assertEqual
(
expected
,
output
)
self
.
assertEqual
(
expected
,
output
)
def
test_multiple_input_plain_output
(
self
):
def
test_multiple_input_plain_output
(
self
):
op_proto
=
op_proto
_pb2
.
OpProto
()
op_proto
=
framework
_pb2
.
OpProto
()
op_proto
.
type
=
"fc"
op_proto
.
type
=
"fc"
ipt
=
op_proto
.
inputs
.
add
()
ipt
=
op_proto
.
inputs
.
add
()
ipt
.
name
=
"X"
ipt
.
name
=
"X"
ipt
.
comment
=
""
ipt
.
comment
=
""
ipt
.
multip
le
=
True
ipt
.
duplicab
le
=
True
ipt
=
op_proto
.
inputs
.
add
()
ipt
=
op_proto
.
inputs
.
add
()
ipt
.
name
=
"W"
ipt
.
name
=
"W"
ipt
.
comment
=
""
ipt
.
comment
=
""
ipt
.
multip
le
=
True
ipt
.
duplicab
le
=
True
ipt
=
op_proto
.
inputs
.
add
()
ipt
=
op_proto
.
inputs
.
add
()
ipt
.
name
=
"b"
ipt
.
name
=
"b"
...
@@ -70,30 +75,50 @@ class TestOpDescCreationMethod(unittest.TestCase):
...
@@ -70,30 +75,50 @@ class TestOpDescCreationMethod(unittest.TestCase):
method
=
op
.
OpDescCreationMethod
(
op_proto
)
method
=
op
.
OpDescCreationMethod
(
op_proto
)
generated1
=
method
(
X
=
"x"
,
W
=
"w"
,
b
=
"b"
,
Y
=
"y"
)
generated1
=
method
(
X
=
"x"
,
W
=
"w"
,
b
=
"b"
,
Y
=
"y"
)
expected1
=
op_desc_pb2
.
OpDesc
()
expected1
=
framework_pb2
.
OpDesc
()
expected1
.
inputs
.
extend
([
'x'
,
'w'
,
'b'
])
tmp
=
expected1
.
inputs
.
add
()
expected1
.
outputs
.
extend
([
'y'
])
tmp
.
parameter
=
"X"
tmp
.
arguments
.
extend
([
'x'
])
tmp
=
expected1
.
inputs
.
add
()
tmp
.
parameter
=
'W'
tmp
.
arguments
.
extend
([
'w'
])
tmp
=
expected1
.
inputs
.
add
()
tmp
.
parameter
=
'b'
tmp
.
arguments
.
extend
([
'b'
])
tmp
=
expected1
.
outputs
.
add
()
tmp
.
parameter
=
'Y'
tmp
.
arguments
.
extend
([
'y'
])
expected1
.
type
=
'fc'
expected1
.
type
=
'fc'
attr
=
expected1
.
attrs
.
add
()
attr
.
name
=
'input_format'
attr
.
type
=
attribute_pb2
.
INTS
attr
.
ints
.
extend
([
0
,
1
,
2
,
3
])
self
.
assertEqual
(
expected1
,
generated1
)
self
.
assertEqual
(
expected1
,
generated1
)
generated2
=
method
(
generated2
=
method
(
X
=
[
'x1'
,
'x2'
,
'x3'
],
b
=
'b'
,
W
=
[
'w1'
,
'w2'
,
'w3'
],
Y
=
'y'
)
X
=
[
'x1'
,
'x2'
,
'x3'
],
b
=
'b'
,
W
=
[
'w1'
,
'w2'
,
'w3'
],
Y
=
'y'
)
expected2
=
op_desc_pb2
.
OpDesc
()
expected2
=
framework_pb2
.
OpDesc
()
expected2
.
inputs
.
extend
([
'x1'
,
'x2'
,
'x3'
,
'w1'
,
'w2'
,
'w3'
,
'b'
])
expected2
.
outputs
.
extend
([
'y'
])
tmp
=
expected2
.
inputs
.
add
()
tmp
.
parameter
=
"X"
tmp
.
arguments
.
extend
([
'x1'
,
'x2'
,
'x3'
])
tmp
=
expected2
.
inputs
.
add
()
tmp
.
parameter
=
'W'
tmp
.
arguments
.
extend
([
'w1'
,
'w2'
,
'w3'
])
tmp
=
expected2
.
inputs
.
add
()
tmp
.
parameter
=
'b'
tmp
.
arguments
.
extend
([
'b'
])
tmp
=
expected2
.
outputs
.
add
()
tmp
.
parameter
=
'Y'
tmp
.
arguments
.
extend
([
'y'
])
expected2
.
type
=
'fc'
expected2
.
type
=
'fc'
attr
=
expected2
.
attrs
.
add
()
attr
.
name
=
'input_format'
attr
.
type
=
attribute_pb2
.
INTS
attr
.
ints
.
extend
([
0
,
3
,
6
,
7
])
self
.
assertEqual
(
expected2
,
generated2
)
self
.
assertEqual
(
expected2
,
generated2
)
def
test_attrs
(
self
):
def
test_attrs
(
self
):
op_proto
=
op_proto
_pb2
.
OpProto
()
op_proto
=
framework
_pb2
.
OpProto
()
op_proto
.
type
=
"test"
op_proto
.
type
=
"test"
ipt
=
op_proto
.
inputs
.
add
()
ipt
=
op_proto
.
inputs
.
add
()
ipt
.
name
=
'X'
ipt
.
name
=
'X'
...
@@ -105,12 +130,12 @@ class TestOpDescCreationMethod(unittest.TestCase):
...
@@ -105,12 +130,12 @@ class TestOpDescCreationMethod(unittest.TestCase):
attr
.
comment
=
""
attr
.
comment
=
""
attr
.
type
=
type
attr
.
type
=
type
__add_attr__
(
"int_attr"
,
attribute
_pb2
.
INT
)
__add_attr__
(
"int_attr"
,
framework
_pb2
.
INT
)
__add_attr__
(
"float_attr"
,
attribute
_pb2
.
FLOAT
)
__add_attr__
(
"float_attr"
,
framework
_pb2
.
FLOAT
)
__add_attr__
(
"string_attr"
,
attribute
_pb2
.
STRING
)
__add_attr__
(
"string_attr"
,
framework
_pb2
.
STRING
)
__add_attr__
(
"ints_attr"
,
attribute
_pb2
.
INTS
)
__add_attr__
(
"ints_attr"
,
framework
_pb2
.
INTS
)
__add_attr__
(
"floats_attr"
,
attribute
_pb2
.
FLOATS
)
__add_attr__
(
"floats_attr"
,
framework
_pb2
.
FLOATS
)
__add_attr__
(
"strings_attr"
,
attribute
_pb2
.
STRINGS
)
__add_attr__
(
"strings_attr"
,
framework
_pb2
.
STRINGS
)
op_proto
.
comment
=
""
op_proto
.
comment
=
""
self
.
assertTrue
(
op_proto
.
IsInitialized
())
self
.
assertTrue
(
op_proto
.
IsInitialized
())
...
@@ -126,76 +151,52 @@ class TestOpDescCreationMethod(unittest.TestCase):
...
@@ -126,76 +151,52 @@ class TestOpDescCreationMethod(unittest.TestCase):
floats_attr
=
[
0.2
,
3.2
,
4.5
],
floats_attr
=
[
0.2
,
3.2
,
4.5
],
strings_attr
=
[
"a"
,
"b"
,
"c"
])
strings_attr
=
[
"a"
,
"b"
,
"c"
])
expected
=
op_desc
_pb2
.
OpDesc
()
expected
=
framework
_pb2
.
OpDesc
()
expected
.
type
=
"test"
expected
.
type
=
"test"
expected
.
inputs
.
extend
([
'a'
])
ipt
=
expected
.
inputs
.
add
()
ipt
.
parameter
=
"X"
ipt
.
arguments
.
extend
([
'a'
])
attr
=
expected
.
attrs
.
add
()
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"int_attr"
attr
.
name
=
"int_attr"
attr
.
type
=
attribute
_pb2
.
INT
attr
.
type
=
framework
_pb2
.
INT
attr
.
i
=
10
attr
.
i
=
10
attr
=
expected
.
attrs
.
add
()
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"float_attr"
attr
.
name
=
"float_attr"
attr
.
type
=
attribute
_pb2
.
FLOAT
attr
.
type
=
framework
_pb2
.
FLOAT
attr
.
f
=
3.2
attr
.
f
=
3.2
attr
=
expected
.
attrs
.
add
()
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"string_attr"
attr
.
name
=
"string_attr"
attr
.
type
=
attribute
_pb2
.
STRING
attr
.
type
=
framework
_pb2
.
STRING
attr
.
s
=
"test_str"
attr
.
s
=
"test_str"
attr
=
expected
.
attrs
.
add
()
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"ints_attr"
attr
.
name
=
"ints_attr"
attr
.
type
=
attribute
_pb2
.
INTS
attr
.
type
=
framework
_pb2
.
INTS
attr
.
ints
.
extend
([
0
,
1
,
2
,
3
,
4
])
attr
.
ints
.
extend
([
0
,
1
,
2
,
3
,
4
])
attr
=
expected
.
attrs
.
add
()
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
'floats_attr'
attr
.
name
=
'floats_attr'
attr
.
type
=
attribute
_pb2
.
FLOATS
attr
.
type
=
framework
_pb2
.
FLOATS
attr
.
floats
.
extend
([
0.2
,
3.2
,
4.5
])
attr
.
floats
.
extend
([
0.2
,
3.2
,
4.5
])
attr
=
expected
.
attrs
.
add
()
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
'strings_attr'
attr
.
name
=
'strings_attr'
attr
.
type
=
attribute
_pb2
.
STRINGS
attr
.
type
=
framework
_pb2
.
STRINGS
attr
.
strings
.
extend
([
'a'
,
'b'
,
'c'
])
attr
.
strings
.
extend
([
'a'
,
'b'
,
'c'
])
self
.
assertEqual
(
expected
,
generated
)
self
.
assertEqual
(
expected
,
generated
)
def
test_input_temporary_output
(
self
):
op_proto
=
op_proto_pb2
.
OpProto
()
op_proto
.
type
=
"test"
out
=
op_proto
.
outputs
.
add
()
out
.
name
=
"OUT"
out
.
comment
=
""
out
=
op_proto
.
outputs
.
add
()
out
.
name
=
"TMP"
out
.
comment
=
""
out
.
temporary
=
True
out
=
op_proto
.
outputs
.
add
()
out
.
name
=
"OUT2"
out
.
comment
=
""
op_proto
.
comment
=
""
method
=
op
.
OpDescCreationMethod
(
op_proto
)
generated
=
method
(
OUT
=
"a"
,
OUT2
=
"b"
)
desc
=
op_desc_pb2
.
OpDesc
()
desc
.
outputs
.
extend
([
"a"
,
core
.
var_names
.
temp
(),
"b"
])
desc
.
type
=
"test"
attr
=
desc
.
attrs
.
add
()
attr
.
name
=
"temporary_index"
attr
.
type
=
attribute_pb2
.
INTS
attr
.
ints
.
append
(
2
)
self
.
assertEqual
(
generated
,
desc
)
class
TestOpCreations
(
unittest
.
TestCase
):
class
TestOpCreations
(
unittest
.
TestCase
):
def
test_all
(
self
):
def
test_all
(
self
):
add_op
=
op
.
Operator
(
"add_two"
,
X
=
"a"
,
Y
=
"b"
,
Out
=
"z"
)
add_op
=
op
.
Operator
(
"add_two"
,
X
=
"a"
,
Y
=
"b"
,
Out
=
"z"
)
self
.
assertIsNotNone
(
add_op
)
self
.
assertIsNotNone
(
add_op
)
# Invoke C++ DebugString()
# Invoke C++ DebugString()
self
.
assertEqual
(
'Op(add_two), inputs:
(a, b), outputs:(z)
.'
,
self
.
assertEqual
(
'Op(add_two), inputs:
{X[a], Y[b]}, outputs:{Out[z]}
.'
,
str
(
add_op
))
str
(
add_op
))
...
...
python/paddle/v2/framework/tests/test_protobuf.py
浏览文件 @
81f5f861
import
paddle.v2.framework.proto.op_proto_pb2
as
op_proto_lib
import
paddle.v2.framework.proto.framework_pb2
as
framework_pb2
import
paddle.v2.framework.proto.attribute_pb2
as
attr_type_lib
import
unittest
import
unittest
class
TestFrameworkProto
(
unittest
.
TestCase
):
class
TestFrameworkProto
(
unittest
.
TestCase
):
def
test_all
(
self
):
def
test_all
(
self
):
op_proto
=
op_proto_lib
.
OpProto
()
op_proto
=
framework_pb2
.
OpProto
()
ipt0
=
op_proto
.
inputs
.
add
()
ipt0
=
op_proto
.
inputs
.
add
()
ipt0
.
name
=
"a"
ipt0
.
name
=
"a"
ipt0
.
comment
=
"the input of cosine op"
ipt0
.
comment
=
"the input of cosine op"
...
@@ -19,7 +18,7 @@ class TestFrameworkProto(unittest.TestCase):
...
@@ -19,7 +18,7 @@ class TestFrameworkProto(unittest.TestCase):
attr
=
op_proto
.
attrs
.
add
()
attr
=
op_proto
.
attrs
.
add
()
attr
.
name
=
"scale"
attr
.
name
=
"scale"
attr
.
comment
=
"scale of cosine op"
attr
.
comment
=
"scale of cosine op"
attr
.
type
=
attr_type_lib
.
FLOAT
attr
.
type
=
framework_pb2
.
FLOAT
op_proto
.
type
=
"cos"
op_proto
.
type
=
"cos"
self
.
assertTrue
(
op_proto
.
IsInitialized
())
self
.
assertTrue
(
op_proto
.
IsInitialized
())
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录