Commit 72b5bd93
Authored on Jul 25, 2017 by fengjiayi; committed via GitHub on Jul 25, 2017
Merge pull request #3036 from Canpio/dev_update_backward
update gradient operator registry mechanism
Parents: 91689b6b, e8a0e92b
Showing 10 changed files with 55 additions and 49 deletions (+55 −49)
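In short, gradient-operator registration no longer stores creator lambdas in a dedicated grad_creators() map keyed by the forward op type. Every operator, forward or gradient, now gets a creator in a single op_creators() map keyed by its own type, and a separate grad_ops() map records which gradient op type belongs to each forward op type; GradOpBuilder (renamed from GradOpCreator) resolves that chain at build time. The following standalone C++ sketch mirrors the two-map mechanism; the types and names are simplified stand-ins, not the actual Paddle headers:

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Simplified stand-ins for paddle::framework types -- illustrative only.
struct OperatorBase {
  std::string type_;
  virtual ~OperatorBase() = default;
};

using OpCreator = std::function<OperatorBase*()>;

struct OpRegistrySketch {
  // One creator map for every operator type, gradient ops included.
  static std::unordered_map<std::string, OpCreator>& op_creators() {
    static std::unordered_map<std::string, OpCreator> op_creators_;
    return op_creators_;
  }
  // Maps a forward op type to the type name of its gradient op.
  static std::unordered_map<std::string, std::string>& grad_ops() {
    static std::unordered_map<std::string, std::string> grad_ops_;
    return grad_ops_;
  }

  template <typename GradOpType>
  static void RegisterGradOp(const std::string& op_type,
                             const std::string& grad_op_type) {
    op_creators()[grad_op_type] = [] { return new GradOpType; };
    grad_ops()[op_type] = grad_op_type;
  }

  // Mirrors the new GradOpBuilder::Build(): look up the grad op type first,
  // then create it through the shared creator map and stamp its type_.
  static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
    const std::string& grad_op_type = grad_ops().at(op.type_);
    std::shared_ptr<OperatorBase> grad_op(op_creators().at(grad_op_type)());
    grad_op->type_ = grad_op_type;
    return grad_op;
  }
};

struct AddOpGrad : OperatorBase {};

int main() {
  // Corresponds to REGISTER_GRADIENT_OP(add_two, add_two_grad, AddOpGrad).
  OpRegistrySketch::RegisterGradOp<AddOpGrad>("add_two", "add_two_grad");

  OperatorBase add_op;
  add_op.type_ = "add_two";
  std::cout << OpRegistrySketch::CreateGradOp(add_op)->type_ << "\n";
  // Prints: add_two_grad
}

Changed files: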
paddle/framework/CMakeLists.txt (+3 −3)
paddle/framework/grad_op_builder.cc (+9 −8)
paddle/framework/grad_op_builder.h (+3 −3)
paddle/framework/grad_op_builder_test.cc (+2 −2)
paddle/framework/op_registry.h (+31 −26)
paddle/operators/add_op.cc (+1 −1)
paddle/operators/add_op_test.cc (+3 −3)
paddle/operators/mul_op.cc (+1 −1)
paddle/operators/sigmoid_op.cc (+1 −1)
paddle/operators/softmax_op.cc (+1 −1)
paddle/framework/CMakeLists.txt

@@ -19,10 +19,10 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
 cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
-cc_library(grad_op_creator SRCS grad_op_creator.cc DEPS op_proto operator)
-cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_creator)
+cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator)
+cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_builder)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
-cc_test(grad_op_creator_test SRCS grad_op_creator_test.cc DEPS grad_op_creator op_registry add_op)
+cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
 py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
 # Generate an empty __init__.py to make framework_py_proto as a valid python module.
paddle/framework/grad_op_creator.cc → paddle/framework/grad_op_builder.cc

@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/framework/grad_op_creator.h"
+#include "paddle/framework/grad_op_builder.h"
 #include "paddle/framework/op_registry.h"
 
 namespace paddle {
 namespace framework {
 
-OperatorBase* GradOpCreator::Create() {
+OperatorBase* GradOpBuilder::Build() {
   BuildOpInOutArgList();
-  OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)();
+  std::string grad_op_type = OpRegistry::grad_ops().at(op_->type_);
+  OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
+  grad_op->type_ = grad_op_type;
   CompleteGradOp(grad_op);
   return grad_op;
 }
 
-OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
+OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
                                     const VarIndexMap& var_map,
                                     const std::vector<int>& format,
                                     InOutType type) {

@@ -36,7 +38,7 @@ OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
                         end_idx);
 }
 
-void GradOpCreator::BuildOpInOutArgList() {
+void GradOpBuilder::BuildOpInOutArgList() {
   const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
   const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
   const std::vector<int>& in_format =

@@ -57,7 +59,7 @@ void GradOpCreator::BuildOpInOutArgList() {
   }
 }
 
-void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
+void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
                                      std::vector<std::string>& in_out,
                                      std::vector<int>& format,
                                      VarIndexMap* varmap, int& idx,

@@ -80,8 +82,7 @@ void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
   format.push_back(in_out.size());
 }
 
-void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
-  grad_op->type_ = op_->type_ + "@GRAD";  // not necessary
+void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const {
   grad_op->attrs_ = op_->attrs_;
   grad_op->attrs_.erase("input_format");
   grad_op->attrs_.erase("output_format");
paddle/framework/grad_op_creator.h → paddle/framework/grad_op_builder.h

@@ -25,12 +25,12 @@ struct OpInOutArg {
   size_t end_idx_;
 };
 
-class GradOpCreator {
+class GradOpBuilder {
   using VarIndexMap = std::unordered_map<std::string, int>;
 
  public:
-  GradOpCreator(const OperatorBase* op) : op_(op) {}
-  OperatorBase* Create();
+  GradOpBuilder(const OperatorBase* op) : op_(op) {}
+  OperatorBase* Build();
 
  private:
  OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
paddle/framework/grad_op_creator_test.cc → paddle/framework/grad_op_builder_test.cc

-#include "paddle/framework/grad_op_creator.h"
+#include "paddle/framework/grad_op_builder.h"
 
 #include <gtest/gtest.h>
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"

@@ -8,7 +8,7 @@ USE_OP(add_two);
 namespace paddle {
 namespace framework {
 
-TEST(GradOpCreator, AddTwo) {
+TEST(GradOpBuilder, AddTwo) {
   std::shared_ptr<OperatorBase> add_op(
       OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
   std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
paddle/framework/op_registry.h

@@ -20,7 +20,7 @@ limitations under the License. */
 #include <unordered_map>
 #include <unordered_set>
 #include "paddle/framework/attr_checker.h"
-#include "paddle/framework/grad_op_creator.h"
+#include "paddle/framework/grad_op_builder.h"
 #include "paddle/framework/op_desc.pb.h"
 #include "paddle/framework/scope.h"

@@ -222,7 +222,7 @@ class OpRegistry {
  public:
   template <typename OpType, typename ProtoMakerType>
   static void RegisterOp(const std::string& op_type) {
-    creators()[op_type] = [] { return new OpType; };
+    op_creators()[op_type] = [] { return new OpType; };
     OpAttrChecker& op_checker = op_checkers()[op_type];
     OpProto& op_proto = protos()[op_type];
     auto maker = ProtoMakerType(&op_proto, &op_checker);

@@ -245,17 +245,19 @@ class OpRegistry {
     }
   }
 
-  template <typename OpType>
-  static void RegisterGradOp(const std::string& op_type) {
-    grad_creators()[op_type] = [] { return new OpType; };
+  template <typename GradOpType>
+  static void RegisterGradOp(const std::string& op_type,
+                             const std::string& grad_op_type) {
+    op_creators()[grad_op_type] = [] { return new GradOpType; };
+    grad_ops()[op_type] = grad_op_type;
   }
 
   static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
                                                 const VarNameList& inputs,
                                                 const VarNameList& outputs,
                                                 const AttributeMap& attrs) {
-    auto op_create_it = creators().find(type);
-    PADDLE_ENFORCE(op_create_it != creators().end(),
+    auto op_create_it = op_creators().find(type);
+    PADDLE_ENFORCE(op_create_it != op_creators().end(),
                    "Operator %s cannot be found.", type);
     auto op = op_create_it->second();

@@ -300,8 +302,8 @@ class OpRegistry {
   static std::shared_ptr<OperatorBase> CreateGradOp(
       std::shared_ptr<OperatorBase> op) {
-    GradOpCreator creator(op.get());
-    std::shared_ptr<OperatorBase> grad_op(creator.Create());
+    GradOpBuilder builder(op.get());
+    std::shared_ptr<OperatorBase> grad_op(builder.Build());
     grad_op->Init();
     return grad_op;
   }

@@ -311,9 +313,9 @@ class OpRegistry {
     return protos_;
   };
 
-  static std::unordered_map<std::string, OpCreator>& grad_creators() {
-    static std::unordered_map<std::string, OpCreator> grad_creators_;
-    return grad_creators_;
+  static std::unordered_map<std::string, std::string>& grad_ops() {
+    static std::unordered_map<std::string, std::string> grad_ops_;
+    return grad_ops_;
   }
 
   static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&

@@ -322,12 +324,12 @@ class OpRegistry {
     return maps_;
   }
 
- private:
-  static std::unordered_map<std::string, OpCreator>& creators() {
-    static std::unordered_map<std::string, OpCreator> creators_;
-    return creators_;
+  static std::unordered_map<std::string, OpCreator>& op_creators() {
+    static std::unordered_map<std::string, OpCreator> op_creators_;
+    return op_creators_;
   }
 
+ private:
   static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
     static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
     return op_checkers_;

@@ -353,11 +355,11 @@ class OpRegisterHelper {
   }
 };
 
-template <typename OpType>
+template <typename GradOpType>
 class GradOpRegisterHelper {
  public:
-  GradOpRegisterHelper(const char* op_type) {
-    OpRegistry::RegisterGradOp<OpType>(op_type);
+  GradOpRegisterHelper(const char* op_type, const char* grad_op_type) {
+    OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
   }
 };

@@ -383,13 +385,16 @@ class GradOpRegisterHelper {
 /**
  * Macro to Register Gradient Operator.
  */
-#define REGISTER_GRADIENT_OP(__op_type, __op_class)                \
-  STATIC_ASSERT_GLOBAL_NAMESPACE(                                  \
-      __reg_gradient_op__##__op_type,                              \
-      "REGISTER_GRADIENT_OP must be in global namespace");         \
-  static ::paddle::framework::GradOpRegisterHelper<__op_class>     \
-      __op_gradient_register_##__op_type##__(#__op_type);          \
-  int __op_gradient_register_##__op_type##_handle__() { return 0; }
+#define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class)   \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                          \
+      __reg_gradient_op__##__op_type##__grad_op_type,                      \
+      "REGISTER_GRADIENT_OP must be in global namespace");                 \
+  static ::paddle::framework::GradOpRegisterHelper<__grad_op_class>        \
+      __op_gradient_register_##__op_type##__grad_op_type##__(#__op_type,   \
+                                                             #__grad_op_type); \
+  int __op_gradient_register_##__op_type##__grad_op_type##_handle__() {    \
+    return 0;                                                              \
+  }
 
 /**
  * Macro to Register OperatorKernel.
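To make the macro change concrete: with the new three-argument form, a registration such as the one in paddle/operators/add_op.cc below expands, roughly, to a file-scope helper object whose constructor performs the two-map registration. This is a sketch of the expansion with token-pasted identifiers abbreviated for readability, not literal preprocessor output:

// Rough expansion of
//   REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad);
// Identifiers are simplified; not literal preprocessor output.
STATIC_ASSERT_GLOBAL_NAMESPACE(
    __reg_gradient_op__add_two_add_two_grad,
    "REGISTER_GRADIENT_OP must be in global namespace");
static ::paddle::framework::GradOpRegisterHelper<paddle::operators::AddOpGrad>
    __op_gradient_register_add_two_add_two_grad__("add_two", "add_two_grad");
int __op_gradient_register_add_two_add_two_grad_handle__() { return 0; }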
paddle/operators/add_op.cc

@@ -65,6 +65,6 @@ protected:
 }  // namespace paddle
 
 REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
-REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);
+REGISTER_GRADIENT_OP(add_two, add_two_grad, paddle::operators::AddOpGrad);
 REGISTER_OP_CPU_KERNEL(
     add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
paddle/operators/add_op_test.cc

@@ -22,7 +22,7 @@ TEST(AddOp, GetOpProto) {
   auto& protos = paddle::framework::OpRegistry::protos();
   auto it = protos.find("add_two");
   ASSERT_NE(it, protos.end());
-  auto& grad_creators = paddle::framework::OpRegistry::grad_creators();
-  auto it1 = grad_creators.find("add_two");
-  ASSERT_NE(it1, grad_creators.end());
+  auto& op_creators = paddle::framework::OpRegistry::op_creators();
+  auto it1 = op_creators.find("add_two_grad");
+  ASSERT_NE(it1, op_creators.end());
 }
paddle/operators/mul_op.cc

@@ -67,7 +67,7 @@ protected:
 }  // namespace paddle
 
 REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
-REGISTER_GRADIENT_OP(mul, paddle::operators::MulOpGrad);
+REGISTER_GRADIENT_OP(mul, mul_grad, paddle::operators::MulOpGrad);
 REGISTER_OP_CPU_KERNEL(
     mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
paddle/operators/sigmoid_op.cc

@@ -56,7 +56,7 @@ protected:
 REGISTER_OP(sigmoid, paddle::operators::SigmoidOp,
             paddle::operators::SigmoidOpMaker);
-REGISTER_GRADIENT_OP(sigmoid, paddle::operators::SigmoidOpGrad);
+REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, paddle::operators::SigmoidOpGrad);
 REGISTER_OP_CPU_KERNEL(sigmoid,
paddle/operators/softmax_op.cc

@@ -59,6 +59,6 @@ protected:
 namespace ops = paddle::operators;
 
 REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
-REGISTER_GRADIENT_OP(softmax, paddle::operators::SoftmaxOpGrad);
+REGISTER_GRADIENT_OP(softmax, softmax_grad, paddle::operators::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(softmax,
                        ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);