Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
b1b43645
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b1b43645
编写于
7月 26, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename PlainNet --> NetOp
上级
ef7e76fc
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
44 addition
and
88 deletion
+44
-88
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+1
-3
paddle/framework/net.cc
paddle/framework/net.cc
+4
-12
paddle/framework/net.h
paddle/framework/net.h
+5
-19
paddle/framework/net_op_test.cc
paddle/framework/net_op_test.cc
+15
-22
paddle/framework/net_proto.proto
paddle/framework/net_proto.proto
+0
-15
paddle/framework/operator.h
paddle/framework/operator.h
+8
-6
paddle/operators/fc_op.cc
paddle/operators/fc_op.cc
+1
-1
paddle/operators/type_alias.h
paddle/operators/type_alias.h
+1
-1
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+9
-9
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
b1b43645
...
@@ -29,7 +29,5 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
...
@@ -29,7 +29,5 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
add_custom_target
(
framework_py_proto_init ALL COMMAND
${
CMAKE_COMMAND
}
-E touch __init__.py
)
add_custom_target
(
framework_py_proto_init ALL COMMAND
${
CMAKE_COMMAND
}
-E touch __init__.py
)
add_dependencies
(
framework_py_proto framework_py_proto_init
)
add_dependencies
(
framework_py_proto framework_py_proto_init
)
proto_library
(
net_proto SRCS net_proto.proto DEPS op_proto
)
cc_library
(
net SRCS net.cc DEPS op_registry
)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library
(
net SRCS net.cc DEPS operator net_proto op_registry
)
cc_test
(
net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op
)
cc_test
(
net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op
)
paddle/framework/net.cc
浏览文件 @
b1b43645
...
@@ -20,17 +20,7 @@
...
@@ -20,17 +20,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
std
::
shared_ptr
<
PlainNet
>
AddBackwardOp
(
std
::
shared_ptr
<
PlainNet
>
ForwardOps
)
{
void
NetOp
::
CompleteAddOp
(
bool
calc
)
{
auto
grad_ops
=
std
::
make_shared
<
PlainNet
>
();
for
(
auto
&
op
:
ForwardOps
->
ops_
)
{
auto
op_grad
=
OpRegistry
::
CreateGradOp
(
op
);
grad_ops
->
AddOp
(
op_grad
);
}
grad_ops
->
CompleteAddOp
();
return
grad_ops
;
}
void
PlainNet
::
CompleteAddOp
(
bool
calc
)
{
add_op_done_
=
true
;
add_op_done_
=
true
;
if
(
!
calc
)
return
;
if
(
!
calc
)
return
;
std
::
unordered_set
<
std
::
string
>
input_set
;
std
::
unordered_set
<
std
::
string
>
input_set
;
...
@@ -70,7 +60,7 @@ void PlainNet::CompleteAddOp(bool calc) {
...
@@ -70,7 +60,7 @@ void PlainNet::CompleteAddOp(bool calc) {
attrs_
[
"temporary_index"
]
=
tmp_index
;
attrs_
[
"temporary_index"
]
=
tmp_index
;
}
}
std
::
string
PlainNet
::
DebugString
()
const
{
std
::
string
NetOp
::
DebugString
()
const
{
std
::
ostringstream
os
;
std
::
ostringstream
os
;
os
<<
OperatorBase
::
DebugString
()
<<
std
::
endl
;
os
<<
OperatorBase
::
DebugString
()
<<
std
::
endl
;
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
...
@@ -82,5 +72,7 @@ std::string PlainNet::DebugString() const {
...
@@ -82,5 +72,7 @@ std::string PlainNet::DebugString() const {
return
os
.
str
();
return
os
.
str
();
}
}
bool
NetOp
::
IsNetOp
()
const
{
return
true
;
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/net.h
浏览文件 @
b1b43645
...
@@ -37,21 +37,7 @@ namespace framework {
...
@@ -37,21 +37,7 @@ namespace framework {
* This is the base class of network, all the networks should implement the APIs
* This is the base class of network, all the networks should implement the APIs
* it defines.
* it defines.
*/
*/
class
Net
:
public
OperatorBase
{
class
NetOp
:
public
OperatorBase
{
public:
virtual
void
AddOp
(
const
std
::
shared_ptr
<
OperatorBase
>&
op
)
=
0
;
virtual
void
CompleteAddOp
(
bool
calc
)
=
0
;
};
using
NetPtr
=
std
::
shared_ptr
<
Net
>
;
/**
* @brief a basic implementation of Net.
*
* PlainNet is a very simple Net, it create a list of operators, and run them
* sequentially following the order they added.
*/
class
PlainNet
:
public
Net
{
public:
public:
/**
/**
* Infer all the operators' input and output variables' shapes, will be called
* Infer all the operators' input and output variables' shapes, will be called
...
@@ -80,15 +66,17 @@ class PlainNet : public Net {
...
@@ -80,15 +66,17 @@ class PlainNet : public Net {
/**
/**
* @brief Add an operator by ptr
* @brief Add an operator by ptr
*/
*/
void
AddOp
(
const
std
::
shared_ptr
<
OperatorBase
>&
op
)
override
{
void
AddOp
(
const
std
::
shared_ptr
<
OperatorBase
>&
op
)
{
PADDLE_ENFORCE
(
!
add_op_done_
,
"Cannot AddOp when this network is sealed"
);
PADDLE_ENFORCE
(
!
add_op_done_
,
"Cannot AddOp when this network is sealed"
);
ops_
.
push_back
(
op
);
ops_
.
push_back
(
op
);
}
}
void
CompleteAddOp
(
bool
calculate
=
true
)
override
;
void
CompleteAddOp
(
bool
calculate
=
true
);
std
::
string
DebugString
()
const
override
;
std
::
string
DebugString
()
const
override
;
bool
IsNetOp
()
const
override
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
ops_
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
ops_
;
private:
private:
...
@@ -100,7 +88,5 @@ class PlainNet : public Net {
...
@@ -100,7 +88,5 @@ class PlainNet : public Net {
}
}
};
};
std
::
shared_ptr
<
PlainNet
>
AddBackwardOp
(
std
::
shared_ptr
<
PlainNet
>
ForwardOps
);
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/net_op_test.cc
浏览文件 @
b1b43645
...
@@ -40,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
...
@@ -40,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
}
}
TEST
(
OpKernel
,
all
)
{
TEST
(
OpKernel
,
all
)
{
auto
net
=
std
::
make_shared
<
PlainNet
>
();
auto
net
=
std
::
make_shared
<
NetOp
>
();
ASSERT_NE
(
net
,
nullptr
);
ASSERT_NE
(
net
,
nullptr
);
auto
op1
=
std
::
make_shared
<
TestOp
>
();
auto
op1
=
std
::
make_shared
<
TestOp
>
();
...
@@ -71,28 +71,21 @@ TEST(OpKernel, all) {
...
@@ -71,28 +71,21 @@ TEST(OpKernel, all) {
ASSERT_EQ
(
2
,
run_cnt
);
ASSERT_EQ
(
2
,
run_cnt
);
ASSERT_THROW
(
net
->
AddOp
(
op2
),
paddle
::
platform
::
EnforceNotMet
);
ASSERT_THROW
(
net
->
AddOp
(
op2
),
paddle
::
platform
::
EnforceNotMet
);
}
}
TEST
(
AddBackwardOp
,
TestGradOp
)
{
auto
net
=
std
::
make_shared
<
PlainNet
>
();
ASSERT_NE
(
net
,
nullptr
);
net
->
AddOp
(
framework
::
OpRegistry
::
CreateOp
(
"mul"
,
{
"X"
,
"Y"
},
{
"Out"
},
{}));
net
->
AddOp
(
framework
::
OpRegistry
::
CreateOp
(
"add_two"
,
{
"X"
,
"Y"
},
{
"Out"
},
{}));
net
->
AddOp
(
framework
::
OpRegistry
::
CreateOp
(
"add_two"
,
{
"X"
,
"Y"
},
{
""
},
{}));
auto
grad_ops
=
AddBackwardOp
(
net
);
for
(
auto
&
op
:
grad_ops
->
ops_
)
{
op
->
DebugString
();
}
}
//
TODO(zhihong): add fc grad without registering
.
//
! TODO(yuyang18): Refine Backward Op
.
// TEST(AddBackwardOp, Test
No
GradOp) {
// TEST(AddBackwardOp, TestGradOp) {
//
auto net = std::make_shared<PlainNet
>();
//
auto net = std::make_shared<NetOp
>();
// ASSERT_NE(net, nullptr);
// ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
// net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
// {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
// net->AddOp(
// framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
// net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""},
// {}));
// auto grad_ops = AddBackwardOp(net);
// for (auto& op : grad_ops->ops_) {
// op->DebugString();
// op->DebugString();
// }
// }
//
}
//}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/net_proto.proto
已删除
100644 → 0
浏览文件 @
ef7e76fc
syntax
=
"proto2"
;
package
paddle
.
framework
;
import
"op_proto.proto"
;
message
NetDesc
{
// network identification
optional
string
name
=
1
;
// operator contains in network
repeated
OpProto
operators
=
2
;
// network type to run with. e.g "plainNet", "DAG"
optional
string
net_type
=
3
;
// num worker always
optional
int32
num_workers
=
4
;
}
paddle/framework/operator.h
浏览文件 @
b1b43645
...
@@ -90,15 +90,17 @@ class OperatorBase {
...
@@ -90,15 +90,17 @@ class OperatorBase {
virtual
void
Run
(
const
std
::
shared_ptr
<
Scope
>&
scope
,
virtual
void
Run
(
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
=
0
;
const
platform
::
DeviceContext
&
dev_ctx
)
const
=
0
;
// Get a input with argument's name described in `op_proto`
virtual
bool
IsNetOp
()
const
{
return
false
;
}
//! Get a input with argument's name described in `op_proto`
const
std
::
string
&
Input
(
const
std
::
string
&
name
)
const
;
const
std
::
string
&
Input
(
const
std
::
string
&
name
)
const
;
// Get a input which has multiple variables.
//
!
Get a input which has multiple variables.
// TODO add a vector_view to prevent memory copy.
//
!
TODO add a vector_view to prevent memory copy.
std
::
vector
<
std
::
string
>
Inputs
(
const
std
::
string
&
name
)
const
;
std
::
vector
<
std
::
string
>
Inputs
(
const
std
::
string
&
name
)
const
;
// Get a output with argument's name described in `op_proto`
//
!
Get a output with argument's name described in `op_proto`
const
std
::
string
&
Output
(
const
std
::
string
&
name
)
const
;
const
std
::
string
&
Output
(
const
std
::
string
&
name
)
const
;
// Get an output which has multiple variables.
//
!
Get an output which has multiple variables.
// TODO add a vector_view to prevent memory copy.
//
!
TODO add a vector_view to prevent memory copy.
std
::
vector
<
std
::
string
>
Outputs
(
const
std
::
string
&
name
)
const
;
std
::
vector
<
std
::
string
>
Outputs
(
const
std
::
string
&
name
)
const
;
public:
public:
...
...
paddle/operators/fc_op.cc
浏览文件 @
b1b43645
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
class
FullyConnectedOp
:
public
PlainNet
{
class
FullyConnectedOp
:
public
NetOp
{
public:
public:
void
Init
()
override
{
void
Init
()
override
{
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
...
...
paddle/operators/type_alias.h
浏览文件 @
b1b43645
...
@@ -43,7 +43,7 @@ using OpProto = framework::OpProto;
...
@@ -43,7 +43,7 @@ using OpProto = framework::OpProto;
using
OpAttrChecker
=
framework
::
OpAttrChecker
;
using
OpAttrChecker
=
framework
::
OpAttrChecker
;
using
CPUPlace
=
platform
::
CPUPlace
;
using
CPUPlace
=
platform
::
CPUPlace
;
using
GPUPlace
=
platform
::
GPUPlace
;
using
GPUPlace
=
platform
::
GPUPlace
;
using
PlainNet
=
framework
::
PlainNet
;
using
NetOp
=
framework
::
NetOp
;
using
OpRegistry
=
framework
::
OpRegistry
;
using
OpRegistry
=
framework
::
OpRegistry
;
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
...
paddle/pybind/pybind.cc
浏览文件 @
b1b43645
...
@@ -146,22 +146,22 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -146,22 +146,22 @@ All parameter, weight, gradient are variables in Paddle.
});
});
ExposeOperator
(
operator_base
);
ExposeOperator
(
operator_base
);
using
PlainNetPtr
=
std
::
shared_ptr
<
pd
::
PlainNet
>
;
py
::
class_
<
pd
::
NetOp
,
std
::
shared_ptr
<
pd
::
NetOp
>>
net
(
m
,
"Net"
);
py
::
class_
<
pd
::
PlainNet
,
PlainNetPtr
>
net
(
m
,
"Net"
);
net
.
def_static
(
"create"
,
net
.
def_static
(
"create"
,
[]()
->
std
::
shared_ptr
<
pd
::
PlainNet
>
{
[]()
->
std
::
shared_ptr
<
pd
::
NetOp
>
{
auto
retv
=
std
::
make_shared
<
pd
::
PlainNet
>
();
auto
retv
=
std
::
make_shared
<
pd
::
NetOp
>
();
retv
->
type_
=
"plain_net"
;
retv
->
type_
=
"plain_net"
;
return
retv
;
return
retv
;
})
})
.
def
(
"add_op"
,
&
pd
::
PlainNet
::
AddOp
)
.
def
(
"add_op"
,
&
pd
::
NetOp
::
AddOp
)
.
def
(
"add_op"
,
.
def
(
"add_op"
,
[](
PlainNetPtr
&
self
,
const
PlainNetPtr
&
net
)
->
void
{
[](
pd
::
NetOp
&
self
,
const
std
::
shared_ptr
<
pd
::
NetOp
>
&
net
)
->
void
{
self
->
AddOp
(
std
::
static_pointer_cast
<
pd
::
OperatorBase
>
(
net
));
self
.
AddOp
(
std
::
static_pointer_cast
<
pd
::
OperatorBase
>
(
net
));
})
})
.
def
(
"complete_add_op"
,
&
pd
::
PlainNet
::
CompleteAddOp
)
.
def
(
"complete_add_op"
,
&
pd
::
NetOp
::
CompleteAddOp
)
.
def
(
"complete_add_op"
,
[](
PlainNetPtr
&
self
)
{
self
->
CompleteAddOp
();
});
.
def
(
"complete_add_op"
,
[](
std
::
shared_ptr
<
pd
::
NetOp
>&
self
)
{
self
->
CompleteAddOp
();
});
ExposeOperator
(
net
);
ExposeOperator
(
net
);
m
.
def
(
"unique_integer"
,
UniqueIntegerGenerator
);
m
.
def
(
"unique_integer"
,
UniqueIntegerGenerator
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录