Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
816b4c8a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
816b4c8a
编写于
7月 18, 2017
作者:
D
dongzhihong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"add backward Op"
上级
83f263e6
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
226 addition
and
40 deletion
+226
-40
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+3
-0
paddle/framework/fully_connected_op.cc
paddle/framework/fully_connected_op.cc
+39
-0
paddle/framework/fully_connected_op.h
paddle/framework/fully_connected_op.h
+52
-0
paddle/framework/net.cc
paddle/framework/net.cc
+14
-0
paddle/framework/net.h
paddle/framework/net.h
+2
-0
paddle/framework/net_op_test.cc
paddle/framework/net_op_test.cc
+66
-38
paddle/framework/net_test.cc
paddle/framework/net_test.cc
+4
-1
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+46
-1
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
816b4c8a
...
@@ -15,6 +15,8 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
...
@@ -15,6 +15,8 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library
(
operator SRCS operator.cc DEPS op_desc device_context
)
cc_library
(
operator SRCS operator.cc DEPS op_desc device_context
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
# cc_library(fc_op SRCS fully_connected_op.cc DEPS operator)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto op_desc
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto op_desc
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry operator
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry operator
)
py_proto_compile
(
framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto
)
py_proto_compile
(
framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto
)
...
@@ -23,5 +25,6 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
...
@@ -23,5 +25,6 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
add_dependencies
(
framework_py_proto framework_py_proto_init
)
add_dependencies
(
framework_py_proto framework_py_proto_init
)
proto_library
(
net_proto SRCS net_proto.proto DEPS op_proto
)
proto_library
(
net_proto SRCS net_proto.proto DEPS op_proto
)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library
(
net SRCS net.cc DEPS operator net_proto op_registry
)
cc_library
(
net SRCS net.cc DEPS operator net_proto op_registry
)
cc_test
(
net_op_test SRCS net_op_test.cc DEPS net
)
cc_test
(
net_op_test SRCS net_op_test.cc DEPS net
)
paddle/framework/fully_connected_op.cc
0 → 100644
浏览文件 @
816b4c8a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/fully_connected_op.h"
#include <iostream>
namespace
paddle
{
namespace
framework
{
void
FCOp
::
Run
(
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
std
::
cout
<<
"FC"
<<
std
::
endl
;
}
void
FCOp
::
InferShape
(
const
ScopePtr
&
scope
)
const
override
{}
void
FCGradientOp
::
Run
(
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
std
::
cout
<<
"FCGrad"
<<
std
::
endl
;
}
void
FCGradientOp
::
InferShape
(
const
ScopePtr
&
scope
)
const
override
{}
REGISTER_OP
(
my_fc
,
paddle
::
framework
::
FCOp
,
paddle
::
framework
::
FCOpProtoAndCheckerMaker
);
REGISTER_OP
(
my_fc_grad
,
paddle
::
framework
::
FCGradientOp
,
paddle
::
framework
::
FCGradientOpProtoAndCheckerMaker
);
}
// namespace framework
}
// namespace paddle
paddle/framework/fully_connected_op.h
0 → 100644
浏览文件 @
816b4c8a
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
framework
{
class
FCOp
:
public
OperatorBase
{
public:
void
Run
(
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
std
::
cout
<<
"FC"
<<
std
::
endl
;
};
void
InferShape
(
const
ScopePtr
&
scope
)
const
override
{};
};
class
FCOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
public:
FCOpProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"x"
,
"input data"
);
AddInput
(
"w"
,
"weights"
);
AddInput
(
"b"
,
"bias"
);
AddOutput
(
"y"
,
"output data"
);
AddComment
(
"Fully connnect op"
);
}
};
class
FCGradientOp
:
public
OperatorBase
{
void
Run
(
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
std
::
cout
<<
"FCGrad"
<<
std
::
endl
;
};
void
InferShape
(
const
ScopePtr
&
scope
)
const
override
{};
};
// class FCGradientOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {};
}
// namespace framework
}
// namespace paddle
paddle/framework/net.cc
浏览文件 @
816b4c8a
...
@@ -15,10 +15,24 @@
...
@@ -15,10 +15,24 @@
*/
*/
#include "paddle/framework/net.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
std
::
shared_ptr
<
PlainNet
>
AddBackwardOp
(
std
::
shared_ptr
<
PlainNet
>
ForwardOps
)
{
// NetPtr->reset(new PlainNet);
// NetPtr grad_ops = new PlainNet;
std
::
shared_ptr
<
PlainNet
>
grad_ops
;
grad_ops
.
reset
(
new
PlainNet
);
for
(
auto
&
op
:
ForwardOps
->
ops_
)
{
auto
op_grad
=
OpRegistry
::
CreateGradOp
(
op
);
grad_ops
->
AddOp
(
op_grad
);
}
grad_ops
->
CompleteAddOp
();
return
grad_ops
;
}
void
PlainNet
::
CompleteAddOp
()
{
void
PlainNet
::
CompleteAddOp
()
{
std
::
unordered_set
<
std
::
string
>
input_set
;
std
::
unordered_set
<
std
::
string
>
input_set
;
std
::
unordered_set
<
std
::
string
>
output_set
;
std
::
unordered_set
<
std
::
string
>
output_set
;
...
...
paddle/framework/net.h
浏览文件 @
816b4c8a
...
@@ -99,5 +99,7 @@ class PlainNet : public Net {
...
@@ -99,5 +99,7 @@ class PlainNet : public Net {
}
}
};
};
std
::
shared_ptr
<
PlainNet
>
AddBackwardOp
(
std
::
shared_ptr
<
PlainNet
>
ForwardOps
);
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/net_op_test.cc
浏览文件 @
816b4c8a
...
@@ -3,18 +3,17 @@
...
@@ -3,18 +3,17 @@
#include <paddle/framework/op_registry.h>
#include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>
#include <paddle/framework/operator.h>
namespace
pd
=
paddle
::
framework
;
namespace
paddle
{
namespace
framework
{
static
int
infer_shape_cnt
=
0
;
static
int
infer_shape_cnt
=
0
;
static
int
run_cnt
=
0
;
static
int
run_cnt
=
0
;
class
TestOp
:
public
pd
::
OperatorBase
{
class
TestOp
:
public
OperatorBase
{
public:
public:
void
InferShape
(
const
paddle
::
framework
::
ScopePtr
&
scope
)
const
override
{
void
InferShape
(
const
ScopePtr
&
scope
)
const
override
{
++
infer_shape_cnt
;
}
++
infer_shape_cnt
;
void
Run
(
const
ScopePtr
&
scope
,
}
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
void
Run
(
const
paddle
::
framework
::
ScopePtr
&
scope
,
const
paddle
::
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
++
run_cnt
;
++
run_cnt
;
}
}
};
};
...
@@ -32,36 +31,65 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
...
@@ -32,36 +31,65 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
}
}
}
}
TEST
(
OpKernel
,
all
)
{
class
PlainNetTest
:
public
testing
::
Test
{
auto
net
=
std
::
make_shared
<
paddle
::
framework
::
PlainNet
>
();
virtual
void
SetUp
()
{
ASSERT_NE
(
net
,
nullptr
);
net_
=
std
::
make_shared
<
PlainNet
>
();
ASSERT_NE
(
net_
,
nullptr
);
auto
op1
=
std
::
make_shared
<
TestOp
>
();
auto
op1
=
std
::
make_shared
<
TestOp
>
();
op1
->
inputs_
=
{
"x"
,
"w1"
,
"b1"
};
op1
->
inputs_
=
{
"x"
,
"w1"
,
"b1"
};
op1
->
outputs_
=
{
"y"
};
op1
->
outputs_
=
{
"y"
};
net
->
AddOp
(
op1
);
net_
->
AddOp
(
op1
);
auto
op2
=
std
::
make_shared
<
TestOp
>
();
auto
op2
=
std
::
make_shared
<
TestOp
>
();
op2
->
inputs_
=
{
"y"
,
"w2"
,
"b2"
};
op2
->
inputs_
=
{
"y"
,
"w2"
,
"b2"
};
op2
->
outputs_
=
{
"z"
};
op2
->
outputs_
=
{
"z"
};
net
->
AddOp
(
op2
);
net_
->
AddOp
(
op2
);
net_
->
CompleteAddOp
();
}
virtual
void
TearDown
()
{}
net
->
CompleteAddOp
();
void
TestOpKernel
()
{
AssertSameVectorWithoutOrder
({
"x"
,
"w1"
,
"b1"
,
"w2"
,
"b2"
},
net
->
inputs_
);
AssertSameVectorWithoutOrder
({
"x"
,
"w1"
,
"b1"
,
"w2"
,
"b2"
},
net_
->
inputs_
);
AssertSameVectorWithoutOrder
({
"y"
,
"z"
},
net
->
outputs_
);
AssertSameVectorWithoutOrder
({
"y"
,
"z"
},
net_
->
outputs_
);
auto
tmp_idx_iter
=
net
->
attrs_
.
find
(
"temporary_index"
);
auto
tmp_idx_iter
=
net_
->
attrs_
.
find
(
"temporary_index"
);
ASSERT_NE
(
net
->
attrs_
.
end
(),
tmp_idx_iter
);
ASSERT_NE
(
net_
->
attrs_
.
end
(),
tmp_idx_iter
);
auto
&
tmp_idx
=
boost
::
get
<
std
::
vector
<
int
>>
(
tmp_idx_iter
->
second
);
auto
&
tmp_idx
=
boost
::
get
<
std
::
vector
<
int
>>
(
tmp_idx_iter
->
second
);
ASSERT_EQ
(
1UL
,
tmp_idx
.
size
());
ASSERT_EQ
(
1UL
,
tmp_idx
.
size
());
ASSERT_EQ
(
"y"
,
net
->
outputs_
[
tmp_idx
[
0
]]);
ASSERT_EQ
(
"y"
,
net_
->
outputs_
[
tmp_idx
[
0
]]);
auto
scope
=
std
::
make_shared
<
pd
::
Scope
>
();
auto
scope
=
std
::
make_shared
<
Scope
>
();
paddle
::
platform
::
CPUDeviceContext
dev_ctx
;
platform
::
CPUDeviceContext
dev_ctx
;
net
->
InferShape
(
scope
);
net_
->
InferShape
(
scope
);
net
->
Run
(
scope
,
dev_ctx
);
net_
->
Run
(
scope
,
dev_ctx
);
ASSERT_EQ
(
2
,
infer_shape_cnt
);
ASSERT_EQ
(
2
,
infer_shape_cnt
);
ASSERT_EQ
(
2
,
run_cnt
);
ASSERT_EQ
(
2
,
run_cnt
);
ASSERT_THROW
(
net
->
AddOp
(
op2
),
paddle
::
framework
::
EnforceNotMet
);
ASSERT_THROW
(
net_
->
AddOp
(
op2
),
EnforceNotMet
);
}
void
TestAddBackwardOp
()
{
auto
grad_ops
=
AddBackwardOp
(
net_
);
for
(
auto
&
op
:
grad_ops
->
ops_
)
{
op
->
DebugString
();
}
}
private:
std
::
shared_ptr
<
PlainNet
>
net_
;
};
TEST
(
OpKernel
,
all
)
{
PlainNetTest
net
;
net
->
TestOpKernel
();
}
}
TEST
(
AddBackwardOp
,
TestAddBackwardOp
)
{
PlainNetTest
net
;
net
->
TestAddBackwardOp
();
}
}
// namespace framework
}
// namespace paddle
paddle/framework/net_test.cc
浏览文件 @
816b4c8a
...
@@ -13,12 +13,15 @@
...
@@ -13,12 +13,15 @@
limitations under the License. */
limitations under the License. */
#include "paddle/framework/net.h"
#include "paddle/framework/net.h"
#include "paddle/framework/fully_connected_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
FakeFC
:
public
Operator
{}
TEST
(
AddBackwardOp
,
ALL
)
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/op_registry.h
浏览文件 @
816b4c8a
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -188,8 +189,8 @@ class OpRegistry {
...
@@ -188,8 +189,8 @@ class OpRegistry {
template
<
typename
OpType
,
typename
ProtoMakerType
>
template
<
typename
OpType
,
typename
ProtoMakerType
>
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
static
void
RegisterOp
(
const
std
::
string
&
op_type
)
{
creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
OpProto
&
op_proto
=
protos
()[
op_type
];
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpAttrChecker
&
op_checker
=
op_checkers
()[
op_type
];
OpProto
&
op_proto
=
protos
()[
op_type
];
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
ProtoMakerType
(
&
op_proto
,
&
op_checker
);
*
op_proto
.
mutable_type
()
=
op_type
;
*
op_proto
.
mutable_type
()
=
op_type
;
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
...
@@ -198,6 +199,11 @@ class OpRegistry {
...
@@ -198,6 +199,11 @@ class OpRegistry {
op_type
,
op_proto
.
InitializationErrorString
());
op_type
,
op_proto
.
InitializationErrorString
());
}
}
template
<
typename
OpType
>
static
void
RegisterGradOp
(
const
std
::
string
&
op_type
)
{
grad_creators
()[
op_type
]
=
[]
{
return
new
OpType
;
};
}
static
OperatorPtr
CreateOp
(
const
OpDesc
&
op_desc
)
{
static
OperatorPtr
CreateOp
(
const
OpDesc
&
op_desc
)
{
std
::
string
op_type
=
op_desc
.
type
();
std
::
string
op_type
=
op_desc
.
type
();
OperatorPtr
op
(
creators
().
at
(
op_type
)());
OperatorPtr
op
(
creators
().
at
(
op_type
)());
...
@@ -216,6 +222,21 @@ class OpRegistry {
...
@@ -216,6 +222,21 @@ class OpRegistry {
return
op
;
return
op
;
}
}
static
OperatorPtr
CreateGradOp
(
std
::
shared_ptr
<
OperatorBase
>
op
)
{
OperatorPtr
op_grad
(
grad_creators
().
at
(
op
->
type_
)());
op_grad
->
type_
=
op
->
type_
;
op_grad
->
inputs_
.
reserve
(
op
->
inputs_
.
size
());
for
(
auto
&
input
:
op
->
inputs_
)
{
op_grad
->
inputs_
.
emplace_back
(
input
);
op_grad
->
outputs_
.
emplace_back
(
input
+
"@grad"
);
}
for
(
auto
&
output
:
op
->
outputs_
)
{
op_grad
->
inputs_
.
emplace_back
(
output
);
op_grad
->
inputs_
.
emplace_back
(
output
+
"@grad"
);
}
return
op_grad
;
}
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
return
protos_
;
return
protos_
;
...
@@ -231,6 +252,11 @@ class OpRegistry {
...
@@ -231,6 +252,11 @@ class OpRegistry {
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>
op_checkers_
;
static
std
::
unordered_map
<
std
::
string
,
OpAttrChecker
>
op_checkers_
;
return
op_checkers_
;
return
op_checkers_
;
};
};
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
grad_creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
grad_creators_
;
return
grad_creators_
;
}
};
};
template
<
typename
OpType
,
typename
ProtoMakerType
>
template
<
typename
OpType
,
typename
ProtoMakerType
>
...
@@ -241,6 +267,14 @@ class OpRegisterHelper {
...
@@ -241,6 +267,14 @@ class OpRegisterHelper {
}
}
};
};
template
<
typename
OpType
>
class
GradOpRegisterHelper
{
public:
GradOpRegisterHelper
(
const
char
*
op_type
)
{
OpRegistry
::
RegisterGradOp
<
OpType
>
(
op_type
);
}
};
/**
/**
* check if MACRO is used in GLOBAL NAMESPACE.
* check if MACRO is used in GLOBAL NAMESPACE.
*/
*/
...
@@ -260,6 +294,17 @@ class OpRegisterHelper {
...
@@ -260,6 +294,17 @@ class OpRegisterHelper {
__op_register_##__op_type##__(#__op_type); \
__op_register_##__op_type##__(#__op_type); \
int __op_register_##__op_type##_handle__() { return 0; }
int __op_register_##__op_type##_handle__() { return 0; }
/**
* Macro to Register Operator.
*/
#define REGISTER_GRADIENT_OP(__op_type, __op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##__op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__op_class> \
__op_register_##__op_type##__(#__op_type); \
int __op_register_##__op_type##_handle__() { return 0; }
/**
/**
* Macro to Register OperatorKernel.
* Macro to Register OperatorKernel.
*/
*/
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录