Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
58f3de95
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
58f3de95
编写于
7月 14, 2017
作者:
Q
Qiao Longfei
提交者:
GitHub
7月 14, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize ptr (#2851)
* use OperatorPtr = std::shared_ptr<OperatorBase>; * use ScopePtr = std::share_ptr<Scope>;
上级
2462d0c5
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
82 addition
and
37 deletion
+82
-37
paddle/framework/net.cc
paddle/framework/net.cc
+2
-2
paddle/framework/net.h
paddle/framework/net.h
+7
-6
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+2
-2
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+10
-10
paddle/framework/operator.h
paddle/framework/operator.h
+6
-6
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+50
-9
paddle/framework/scope.h
paddle/framework/scope.h
+5
-2
未找到文件。
paddle/framework/net.cc
浏览文件 @
58f3de95
...
...
@@ -5,13 +5,13 @@ namespace framework {
PlainNet
::
PlainNet
(
const
NetDesc
&
def
)
{}
void
PlainNet
::
InferShape
(
Scope
*
scope
)
{
void
PlainNet
::
InferShape
(
const
ScopePtr
&
scope
)
const
{
for
(
auto
&
op
:
ops_
)
{
op
.
InferShape
();
}
}
void
PlainNet
::
Run
(
std
::
shared_ptr
<
Scope
>
scope
,
DeviceContext
*
ctx
)
{
void
PlainNet
::
Run
(
const
ScopePtr
&
scope
,
const
DeviceContext
&
ctx
)
const
{
for
(
auto
&
op
:
ops_
)
{
op
.
Run
(
ctx
);
}
...
...
paddle/framework/net.h
浏览文件 @
58f3de95
...
...
@@ -37,8 +37,8 @@ struct OpAttrs {};
class
Operator
{
public:
Operator
(
const
OpDesc
&
def
)
{}
void
InferShape
()
{}
void
Run
(
DeviceContext
*
ctx
)
{}
void
InferShape
()
const
{}
void
Run
(
const
DeviceContext
&
ctx
)
const
{}
};
/**
...
...
@@ -60,7 +60,7 @@ class Net {
/**
* @brief Infer shapes of all inputs and outputs of operators.
*/
virtual
void
InferShape
(
Scope
*
scope
)
=
0
;
virtual
void
InferShape
(
const
ScopePtr
&
scope
)
const
=
0
;
/**
* @brief Run the network.
*
...
...
@@ -69,7 +69,7 @@ class Net {
* environment for ops. `begin` and `end` specify the scope of `ops_` to run,
* If no positive indexes are provided, all operators in `ops_` will run.
*/
virtual
void
Run
(
std
::
shared_ptr
<
Scope
>
scope
,
DeviceContext
*
ctx
)
=
0
;
virtual
void
Run
(
const
ScopePtr
&
scope
,
const
DeviceContext
&
ctx
)
const
=
0
;
/**
* @brief Add an Operator according to `def`.
...
...
@@ -114,7 +114,7 @@ class PlainNet : public Net {
* Infer all the operators' input and output varialbes' shapes, will be called
* before every mini-batch
*/
virtual
void
InferShape
(
Scope
*
scope
)
override
;
virtual
void
InferShape
(
const
ScopePtr
&
scope
)
const
override
;
/**
* @brief Run the network.
...
...
@@ -123,7 +123,8 @@ class PlainNet : public Net {
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
virtual
void
Run
(
std
::
shared_ptr
<
Scope
>
scope
,
DeviceContext
*
ctx
)
override
;
virtual
void
Run
(
const
ScopePtr
&
scope
,
const
DeviceContext
&
ctx
)
const
override
;
/**
* @brief Add an operator to this network.
...
...
paddle/framework/op_registry.h
浏览文件 @
58f3de95
...
...
@@ -198,9 +198,9 @@ class OpRegistry {
op_type
,
op_proto
.
InitializationErrorString
());
}
static
Operator
Base
*
CreateOp
(
const
OpDesc
&
op_desc
)
{
static
Operator
Ptr
CreateOp
(
const
OpDesc
&
op_desc
)
{
std
::
string
op_type
=
op_desc
.
type
();
Operator
Base
*
op
=
creators
().
at
(
op_type
)(
);
Operator
Ptr
op
(
creators
().
at
(
op_type
)()
);
op
->
desc_
=
op_desc
;
op
->
inputs_
.
reserve
((
size_t
)
op_desc
.
inputs_size
());
std
::
copy
(
op_desc
.
inputs
().
begin
(),
op_desc
.
inputs
().
end
(),
...
...
paddle/framework/op_registry_test.cc
浏览文件 @
58f3de95
...
...
@@ -5,9 +5,9 @@ namespace paddle {
namespace
framework
{
class
CosineOp
:
public
OperatorBase
{
public:
void
Run
(
const
std
::
shared_ptr
<
Scope
>
&
scope
,
void
Run
(
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>
&
scope
)
const
override
{}
void
InferShape
(
const
ScopePtr
&
scope
)
const
override
{}
};
class
CosineOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
...
...
@@ -25,8 +25,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class
MyTestOp
:
public
OperatorBase
{
public:
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>
&
scope
)
const
override
{}
void
Run
(
const
std
::
shared_ptr
<
Scope
>
&
scope
,
void
InferShape
(
const
ScopePtr
&
scope
)
const
override
{}
void
Run
(
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
public:
...
...
@@ -67,7 +67,7 @@ TEST(OpRegistry, CreateOp) {
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_f
(
scale
);
paddle
::
framework
::
Operator
Base
*
op
=
paddle
::
framework
::
Operator
Ptr
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
auto
scope
=
std
::
make_shared
<
paddle
::
framework
::
Scope
>
();
paddle
::
platform
::
CPUDeviceContext
dev_ctx
;
...
...
@@ -89,7 +89,7 @@ TEST(OpRegistry, IllegalAttr) {
bool
caught
=
false
;
try
{
paddle
::
framework
::
Operator
Base
*
op
__attribute__
((
unused
))
=
paddle
::
framework
::
Operator
Ptr
op
__attribute__
((
unused
))
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
paddle
::
framework
::
EnforceNotMet
err
)
{
caught
=
true
;
...
...
@@ -110,7 +110,7 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_TRUE
(
op_desc
.
IsInitialized
());
paddle
::
framework
::
Operator
Base
*
op
=
paddle
::
framework
::
Operator
Ptr
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
auto
scope
=
std
::
make_shared
<
paddle
::
framework
::
Scope
>
();
paddle
::
platform
::
CPUDeviceContext
dev_ctx
;
...
...
@@ -136,7 +136,7 @@ TEST(OpRegistry, CustomChecker) {
// attr 'test_attr' is not set
bool
caught
=
false
;
try
{
paddle
::
framework
::
Operator
Base
*
op
__attribute__
((
unused
))
=
paddle
::
framework
::
Operator
Ptr
op
__attribute__
((
unused
))
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
paddle
::
framework
::
EnforceNotMet
err
)
{
caught
=
true
;
...
...
@@ -155,7 +155,7 @@ TEST(OpRegistry, CustomChecker) {
attr
->
set_i
(
3
);
caught
=
false
;
try
{
paddle
::
framework
::
Operator
Base
*
op
__attribute__
((
unused
))
=
paddle
::
framework
::
Operator
Ptr
op
__attribute__
((
unused
))
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
}
catch
(
paddle
::
framework
::
EnforceNotMet
err
)
{
caught
=
true
;
...
...
@@ -174,7 +174,7 @@ TEST(OpRegistry, CustomChecker) {
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
INT
);
attr
->
set_i
(
4
);
SetInputFormat
(
&
op_desc
);
paddle
::
framework
::
Operator
Base
*
op
=
paddle
::
framework
::
Operator
Ptr
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
paddle
::
platform
::
CPUDeviceContext
dev_ctx
;
auto
scope
=
std
::
make_shared
<
paddle
::
framework
::
Scope
>
();
...
...
paddle/framework/operator.h
浏览文件 @
58f3de95
...
...
@@ -30,7 +30,7 @@ namespace paddle {
namespace
framework
{
class
OperatorBase
;
using
OperatorPtr
=
std
::
shared_ptr
<
OperatorBase
>
;
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
...
...
@@ -56,10 +56,10 @@ class OperatorBase {
/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>
&
scope
)
const
=
0
;
virtual
void
InferShape
(
const
ScopePtr
&
scope
)
const
=
0
;
/// Net will call this function to Run an op.
virtual
void
Run
(
const
std
::
shared_ptr
<
Scope
>
&
scope
,
virtual
void
Run
(
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
=
0
;
protected:
...
...
@@ -82,7 +82,7 @@ class OpKernel {
*/
class
KernelContext
{
public:
KernelContext
(
const
OperatorBase
*
op
,
const
std
::
shared_ptr
<
Scope
>
&
scope
,
KernelContext
(
const
OperatorBase
*
op
,
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
device_context
)
:
op_
(
*
op
),
scope_
(
scope
),
device_context_
(
device_context
)
{}
...
...
@@ -95,7 +95,7 @@ class OpKernel {
}
const
OperatorBase
&
op_
;
const
std
::
shared_ptr
<
Scope
>
&
scope_
;
const
ScopePtr
&
scope_
;
const
platform
::
DeviceContext
&
device_context_
;
};
...
...
@@ -140,7 +140,7 @@ class OperatorWithKernel : public OperatorBase {
using
OpKernelMap
=
std
::
unordered_map
<
OpKernelKey
,
std
::
unique_ptr
<
OpKernel
>
,
OpKernelHash
>
;
void
Run
(
const
std
::
shared_ptr
<
Scope
>
&
scope
,
void
Run
(
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
final
{
auto
&
opKernel
=
AllOpKernels
().
at
(
Type
()).
at
(
OpKernelKey
(
dev_ctx
));
opKernel
->
Compute
(
OpKernel
::
KernelContext
(
this
,
scope
,
dev_ctx
));
...
...
paddle/framework/operator_test.cc
浏览文件 @
58f3de95
...
...
@@ -22,8 +22,8 @@ namespace framework {
class
OperatorTest
:
public
OperatorBase
{
public:
void
Init
()
override
{
x
=
1
;
}
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>
&
scope
)
const
override
{}
void
Run
(
const
std
::
shared_ptr
<
Scope
>
&
scope
,
void
InferShape
(
const
ScopePtr
&
scope
)
const
override
{}
void
Run
(
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
float
scale
=
GetAttr
<
float
>
(
"scale"
);
ASSERT_NEAR
(
scale
,
3.14
,
1e-5
);
...
...
@@ -36,6 +36,50 @@ class OperatorTest : public OperatorBase {
float
x
=
0
;
};
class
OperatorTestProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
public:
OperatorTestProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"input"
,
"input of test op"
);
AddOutput
(
"output"
,
"output of test op"
);
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
.
SetDefault
(
1.0
)
.
LargerThan
(
0.0
);
AddComment
(
"This is test op"
);
}
};
}
// namespace framework
}
// namespace paddle
REGISTER_OP
(
test_operator
,
paddle
::
framework
::
OperatorTest
,
paddle
::
framework
::
OperatorTestProtoAndCheckerMaker
);
TEST
(
OperatorBase
,
all
)
{
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"test_operator"
);
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"IN1"
;
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"OUT1"
;
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
float
scale
=
3.14
;
attr
->
set_f
(
scale
);
paddle
::
platform
::
CPUDeviceContext
device_context
;
auto
scope
=
std
::
make_shared
<
paddle
::
framework
::
Scope
>
();
paddle
::
framework
::
OperatorPtr
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
ASSERT_EQ
(
op
->
GetAttr
<
float
>
(
"scale"
),
scale
);
scope
->
CreateVariable
(
"OUT1"
);
op
->
Run
(
scope
,
device_context
);
std
::
cout
<<
op
->
DebugString
()
<<
std
::
endl
;
}
namespace
paddle
{
namespace
framework
{
class
OpKernelTestProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
public:
OpKernelTestProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
...
...
@@ -73,9 +117,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
REGISTER_OP_CPU_KERNEL
(
op_with_kernel
,
paddle
::
framework
::
CPUKernelTest
);
TEST
(
OpKernel
,
all
)
{
using
namespace
paddle
::
framework
;
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"op_with_kernel"
);
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"IN1"
;
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"OUT1"
;
...
...
@@ -85,10 +127,9 @@ TEST(OpKernel, all) {
attr
->
set_f
(
3.14
);
paddle
::
platform
::
CPUDeviceContext
cpu_device_context
;
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
scope
=
std
::
make_shared
<
paddle
::
framework
::
Scope
>
();
OperatorBase
*
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
paddle
::
framework
::
OperatorPtr
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
op
->
Run
(
scope
,
cpu_device_context
);
delete
op
;
}
paddle/framework/scope.h
浏览文件 @
58f3de95
...
...
@@ -23,6 +23,9 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
class
Scope
;
using
ScopePtr
=
std
::
shared_ptr
<
Scope
>
;
/**
* @brief Scope that manage all variables.
*
...
...
@@ -41,7 +44,7 @@ class Scope {
/**
* @brief Initialize a Scope with parent.
*/
explicit
Scope
(
const
std
::
shared_ptr
<
Scope
>
&
parent
)
:
parent_
(
parent
)
{}
explicit
Scope
(
const
ScopePtr
&
parent
)
:
parent_
(
parent
)
{}
/**
* @brief Create Variable
...
...
@@ -88,7 +91,7 @@ class Scope {
private:
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
Variable
>>
vars_
;
std
::
shared_ptr
<
Scope
>
parent_
{
nullptr
};
ScopePtr
parent_
{
nullptr
};
};
}
// namespace framework
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录