Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9806e7f2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9806e7f2
编写于
8月 16, 2017
作者:
Y
Yu Yang
提交者:
GitHub
8月 16, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3522 from reyoung/feature/clone_op
Feature/clone op
上级
ac61f784
c7f25325
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
116 addition
and
14 deletion
+116
-14
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+16
-4
paddle/framework/operator.h
paddle/framework/operator.h
+25
-5
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+18
-0
paddle/operators/net_op.cc
paddle/operators/net_op.cc
+8
-1
paddle/operators/net_op.h
paddle/operators/net_op.h
+14
-0
paddle/operators/net_op_test.cc
paddle/operators/net_op_test.cc
+17
-0
paddle/operators/recurrent_op.h
paddle/operators/recurrent_op.h
+18
-4
未找到文件。
paddle/framework/op_registry.h
浏览文件 @
9806e7f2
...
...
@@ -144,8 +144,18 @@ class OpKernelRegistrar : public Registrar {
grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
static ::paddle::framework::OpRegistrar<op_class, op_maker_class, \
grad_op_class> \
class _OpClass_##op_type##_ : public op_class { \
public: \
DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \
DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \
}; \
class _OpGradClass_##op_type##_ : public grad_op_class { \
public: \
DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \
DEFINE_OP_CONSTRUCTOR(_OpGradClass_##op_type##_, grad_op_class); \
}; \
static ::paddle::framework::OpRegistrar< \
_OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \
__op_registrar_##op_type##__(#op_type, #grad_op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
...
...
@@ -176,7 +186,8 @@ class OpKernelRegistrar : public Registrar {
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
/**
* Macro to mark what Operator and Kernel we will use and tell the compiler to
* Macro to mark what Operator and Kernel
* we will use and tell the compiler to
* link them into target.
*/
#define USE_OP_ITSELF(op_type) \
...
...
@@ -196,7 +207,8 @@ class OpKernelRegistrar : public Registrar {
__attribute__((unused)) = \
TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE()
// TODO(fengjiayi): The following macros seems ugly, do we have better method?
// TODO(fengjiayi): The following macros
// seems ugly, do we have better method?
#ifdef PADDLE_ONLY_CPU
#define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU)
...
...
paddle/framework/operator.h
浏览文件 @
9806e7f2
...
...
@@ -67,10 +67,6 @@ class OperatorBase {
OperatorBase
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
);
OperatorBase
(
const
OperatorBase
&
o
)
=
delete
;
OperatorBase
&
operator
=
(
const
OperatorBase
&
o
)
=
delete
;
OperatorBase
(
OperatorBase
&&
o
)
=
delete
;
virtual
~
OperatorBase
()
{}
template
<
typename
T
>
...
...
@@ -116,10 +112,14 @@ class OperatorBase {
void
SetType
(
const
std
::
string
&
type
)
{
type_
=
type
;
}
const
AttributeMap
&
Attrs
()
const
{
return
attrs_
;
}
// Return a new operator instance, which is as same as this.
// Use unique_ptr to prevent caller forget to delete this pointer.
virtual
std
::
unique_ptr
<
OperatorBase
>
Clone
()
const
=
0
;
protected:
std
::
string
type_
;
// NOTE: in case of OpGrad, inputs_ contains:
// I (Inputs)
// I (Inputs)
opear
// O (Outputs)
// OG (Output Gradients)
VarNameMap
inputs_
;
...
...
@@ -130,12 +130,32 @@ class OperatorBase {
AttributeMap
attrs_
;
};
// Macro for define a clone method.
// If you are writing an kernel operator, `Clone` will be defined when you
// register it. i.e. `Clone` method is not needed to define by yourself.
#define DEFINE_OP_CLONE_METHOD(CLS) \
std::unique_ptr<OperatorBase> Clone() const final { \
return std::unique_ptr<OperatorBase>(new CLS(*this)); \
}
// Macro for define a default constructor for Operator.
// You can also use
// using PARENT_CLASS::PARENT_CLASS;
// to use parent's constructor.
#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS) \
CLS(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \
: PARENT_CLS(type, inputs, outputs, attrs) {}
class
NOP
:
public
OperatorBase
{
public:
using
OperatorBase
::
OperatorBase
;
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
std
::
unique_ptr
<
OperatorBase
>
Clone
()
const
override
{
return
std
::
unique_ptr
<
OperatorBase
>
(
new
NOP
(
*
this
));
}
};
// this class not only make proto but also init attribute checkers.
...
...
paddle/framework/operator_test.cc
浏览文件 @
9806e7f2
...
...
@@ -245,3 +245,21 @@ TEST(OpKernel, multi_inputs) {
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
op
->
Run
(
scope
,
cpu_device_context
);
}
class
OperatorClone
:
public
paddle
::
framework
::
OperatorBase
{
public:
DEFINE_OP_CLONE_METHOD
(
OperatorClone
);
OperatorClone
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
const
paddle
::
framework
::
Scope
&
scope
)
const
override
{}
void
Run
(
const
paddle
::
framework
::
Scope
&
scope
,
const
paddle
::
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
};
TEST
(
Operator
,
Clone
)
{
OperatorClone
a
(
"ABC"
,
{},
{},
{});
auto
b
=
a
.
Clone
();
ASSERT_EQ
(
a
.
Type
(),
b
->
Type
());
}
\ No newline at end of file
paddle/operators/net_op.cc
浏览文件 @
9806e7f2
...
...
@@ -85,7 +85,14 @@ NetOp::NetOp(const std::string& type,
const
framework
::
OperatorBase
::
VarNameMap
&
inputs
,
const
framework
::
OperatorBase
::
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
std
::
unique_ptr
<
framework
::
OperatorBase
>
NetOp
::
Clone
()
const
{
PADDLE_ENFORCE
(
add_op_done_
,
"Must clone a sealed NetOp, invoke Net::CompleteAddOp before clone"
);
return
std
::
unique_ptr
<
OperatorBase
>
(
new
NetOp
(
*
this
));
}
}
// namespace operators
}
// namespace paddle
paddle/operators/net_op.h
浏览文件 @
9806e7f2
...
...
@@ -41,6 +41,18 @@ class NetOp : public framework::OperatorBase {
NetOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
NetOp
(
const
NetOp
&
o
)
:
framework
::
OperatorBase
(
static_cast
<
const
framework
::
OperatorBase
&>
(
o
))
{
this
->
ops_
.
reserve
(
o
.
ops_
.
size
());
std
::
transform
(
o
.
ops_
.
begin
(),
o
.
ops_
.
end
(),
std
::
back_inserter
(
this
->
ops_
),
[](
const
std
::
shared_ptr
<
OperatorBase
>&
op
)
->
std
::
shared_ptr
<
OperatorBase
>
{
return
std
::
shared_ptr
<
OperatorBase
>
(
op
->
Clone
());
});
this
->
CompleteAddOp
();
}
/**
* Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch
...
...
@@ -98,6 +110,8 @@ class NetOp : public framework::OperatorBase {
bool
IsNetOp
()
const
override
;
std
::
vector
<
std
::
string
>
OutputVars
(
bool
has_intermediate
)
const
override
;
std
::
unique_ptr
<
framework
::
OperatorBase
>
Clone
()
const
override
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
ops_
;
private:
...
...
paddle/operators/net_op_test.cc
浏览文件 @
9806e7f2
...
...
@@ -13,6 +13,7 @@ static int run_cnt = 0;
class
TestOp
:
public
framework
::
OperatorBase
{
public:
using
framework
::
OperatorBase
::
OperatorBase
;
DEFINE_OP_CLONE_METHOD
(
TestOp
);
void
InferShape
(
const
Scope
&
scope
)
const
override
{
++
infer_shape_cnt
;
}
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
...
...
@@ -70,5 +71,21 @@ TEST(NetOp, insert_op) {
ASSERT_EQ
(
3UL
,
net
.
ops_
.
size
());
}
TEST
(
NetOp
,
Clone
)
{
NetOp
net
;
net
.
AddOp
(
std
::
shared_ptr
<
framework
::
NOP
>
(
new
framework
::
NOP
{
"empty"
,
{},
{},
{}}));
net
.
AddOp
(
std
::
shared_ptr
<
framework
::
NOP
>
(
new
framework
::
NOP
{
"empty2"
,
{},
{},
{}}));
net
.
CompleteAddOp
(
true
);
auto
new_net_op
=
net
.
Clone
();
ASSERT_NE
(
new_net_op
,
nullptr
);
ASSERT_TRUE
(
new_net_op
->
IsNetOp
());
auto
*
new_net
=
static_cast
<
NetOp
*>
(
new_net_op
.
get
());
ASSERT_EQ
(
2
,
new_net
->
ops_
.
size
());
ASSERT_EQ
(
new_net
->
ops_
[
0
]
->
Type
(),
"empty"
);
ASSERT_EQ
(
new_net
->
ops_
[
1
]
->
Type
(),
"empty2"
);
}
}
// namespace operators
}
// namespace paddle
paddle/operators/recurrent_op.h
浏览文件 @
9806e7f2
...
...
@@ -110,13 +110,20 @@ class RecurrentGradientAlgorithm {
std
::
shared_ptr
<
NetOp
>*
stepnet_
;
};
class
RecurrentOp
final
:
public
framework
::
OperatorBase
{
class
RecurrentOp
:
public
framework
::
OperatorBase
{
public:
RecurrentOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
RecurrentOp
(
const
RecurrentOp
&
o
)
:
framework
::
OperatorBase
(
static_cast
<
const
framework
::
OperatorBase
&>
(
o
))
{
// TODO(yuyang18): Implement copy ctor well.
PADDLE_THROW
(
"Not implemented"
);
}
/**
* InferShape must be called before Run.
*/
* InferShape must be called before Run.
*/
void
InferShape
(
const
framework
::
Scope
&
scope
)
const
override
{
alg_
.
InferShape
(
scope
);
}
...
...
@@ -137,12 +144,19 @@ class RecurrentOp final : public framework::OperatorBase {
std
::
shared_ptr
<
NetOp
>
stepnet_
;
};
class
RecurrentGradientOp
final
:
public
framework
::
OperatorBase
{
class
RecurrentGradientOp
:
public
framework
::
OperatorBase
{
public:
RecurrentGradientOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
RecurrentGradientOp
(
const
RecurrentGradientOp
&
o
)
:
framework
::
OperatorBase
(
static_cast
<
const
framework
::
OperatorBase
&>
(
o
))
{
// TODO(yuyang18): Implement Copy ctor.
PADDLE_THROW
(
"Not Implemented"
);
}
/**
* InferShape must be called before Run.
*/
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录