Commit 4c96008a

Authored by Yi Wang on Oct 06, 2017; committed by GitHub on Oct 06, 2017.
Parents: f8b5d54c, 803b7b62

Merge pull request #4566 from reyoung/feature/grad_reg_mechanism_cont2

Complete register gradient for compile time

Showing 18 changed files, with 271 additions and 491 deletions (+271, -491).
Changed files:

  paddle/framework/CMakeLists.txt                     +1   -3
  paddle/framework/backward.cc                        +31  -1
  paddle/framework/backward_test.cc                   +23  -36
  paddle/framework/framework.proto                    +0   -1
  paddle/framework/grad_op_builder.cc                 +0   -97
  paddle/framework/grad_op_builder.h                  +0   -28
  paddle/framework/grad_op_builder_test.cc            +0   -186
  paddle/framework/op_desc.h                          +16  -2
  paddle/framework/op_info.h                          +8   -8
  paddle/framework/op_proto_maker.h                   +0   -5
  paddle/framework/op_registry.cc                     +6   -4
  paddle/framework/op_registry.h                      +47  -40
  paddle/operators/mean_op.cc                         +18  -2
  paddle/operators/minus_op.cc                        +32  -24
  paddle/operators/pad_op.cc                          +20  -3
  paddle/operators/scale_op.cc                        +15  -18
  paddle/operators/softmax_with_cross_entropy_op.cc   +32  -13
  paddle/operators/sum_op.cc                          +22  -20
paddle/framework/CMakeLists.txt

@@ -26,10 +26,8 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
 cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
-cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc)
-cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info)
+cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
-cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry sum_op)
 py_proto_compile(framework_py_proto SRCS framework.proto)
 # Generate an empty __init__.py to make framework_py_proto as a valid python module.
paddle/framework/backward.cc

@@ -13,6 +13,7 @@
 limitations under the License. */
 
 #include "paddle/framework/backward.h"
+#include "paddle/operators/net_op.h"
 
 #include <list>
 #include <memory>

@@ -24,6 +25,35 @@
 namespace paddle {
 namespace framework {
 
+static inline std::unique_ptr<OperatorBase> CreateGradOp(
+    const OperatorBase& op) {
+  OpDescBind op_desc;
+  op_desc.SetInputMap(op.Inputs());
+  op_desc.SetOutputMap(op.Outputs());
+  op_desc.SetType(op.Type());
+  op_desc.SetAttrMap(op.Attrs());
+  auto& info = OpInfoMap::Instance().Get(op.Type());
+  auto grad_descs = info.GradOpMaker()(op_desc);
+  std::vector<std::unique_ptr<OperatorBase>> grad_ops;
+  grad_ops.reserve(grad_descs.size());
+  std::transform(
+      grad_descs.begin(), grad_descs.end(), std::back_inserter(grad_ops),
+      [](const std::unique_ptr<OpDescBind>& grad_desc) {
+        return OpRegistry::CreateOp(*grad_desc);
+      });
+  PADDLE_ENFORCE(!grad_ops.empty());
+  if (grad_ops.size() == 1) {
+    return std::move(grad_ops[0]);
+  } else {
+    auto net_op = new operators::NetOp();
+    for (auto& grad_op : grad_ops) {
+      net_op->AppendOp(std::move(grad_op));
+    }
+    net_op->CompleteAddOp();
+    return std::unique_ptr<OperatorBase>(net_op);
+  }
+}
+
 template <typename Map, typename T>
 static void ForEachVarName(const Map& names, T callback) {
   for (auto& name : names) {

@@ -171,7 +201,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
       net->InsertOp(pos.first + 1, std::move(pos.second));
     }
   } else {
-    std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));
+    std::unique_ptr<OperatorBase> grad_op(CreateGradOp(forwardOp));
 
     ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
                                           const std::string& grad_input) {
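The dispatch rule inside CreateGradOp above (one gradient desc becomes a single operator; several are composed into a NetOp) can be modeled in isolation. The sketch below is illustrative only: Op, Net, and WrapGradOps are hypothetical stand-ins, not Paddle types, and the real code additionally enforces that the list is non-empty.

// Standalone model (hypothetical types) of CreateGradOp's wrapping rule.
#include <memory>
#include <utility>
#include <vector>

struct Op {
  virtual ~Op() = default;
};

struct Net : Op {
  std::vector<std::unique_ptr<Op>> ops;
  void Append(std::unique_ptr<Op> op) { ops.push_back(std::move(op)); }
};

std::unique_ptr<Op> WrapGradOps(std::vector<std::unique_ptr<Op>> grad_ops) {
  if (grad_ops.size() == 1) {
    return std::move(grad_ops[0]);  // single gradient op: use it directly
  }
  auto net = std::make_unique<Net>();  // several: compose them into a net
  for (auto& op : grad_ops) {
    net->Append(std::move(op));
  }
  return net;
}

This is why a maker such as sum_op's (below) may return one desc per input while mean_op's returns exactly one: the caller never needs to know which case it got.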
paddle/framework/backward_test.cc

@@ -21,24 +21,34 @@
 namespace paddle {
 namespace framework {
 
-using OperatorBase = framework::OperatorBase;
-using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
-using OpProto = framework::OpProto;
-using OpAttrChecker = framework::OpAttrChecker;
-using Scope = framework::Scope;
 using DeviceContext = platform::DeviceContext;
 
 class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
  public:
   RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "Input X of Add").NotInGradient();
-    AddInput("b", "Bias of Add").NotInGradient();
-    AddOutput("Out", "Out of Add").NotInGradient();
+    AddInput("X", "Input X of Add");
+    AddInput("b", "Bias of Add");
+    AddOutput("Out", "Out of Add");
     AddComment("Add Op");
   }
 };
 
+class RowWiseAddGradMaker : public SingleGradOpDescMaker {
+ public:
+  using SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<OpDescBind> Apply() const override {
+    auto grad_op = new OpDescBind();
+    grad_op->SetInput(GradVarName("Out"), OutputGrad("Out"));
+    grad_op->SetOutput(GradVarName("X"), InputGrad("X"));
+    grad_op->SetOutput(GradVarName("b"), InputGrad("b"));
+    grad_op->SetType("rowwise_add_grad");
+    return std::unique_ptr<OpDescBind>(grad_op);
+  }
+};
+
 class MulOpMaker : public OpProtoAndCheckerMaker {
  public:
   MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)

@@ -137,10 +147,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "the input tensors of sum operator.")
-        .AsDuplicable()
-        .NotInGradient();
-    AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
+    AddInput("X", "the input tensors of sum operator.").AsDuplicable();
+    AddOutput("Out", "the output tensor of sum operator.");
     AddComment("");
   }
 };

@@ -151,8 +159,9 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
 namespace f = paddle::framework;
 namespace ops = paddle::operators;
 using EnforceNotMet = paddle::platform::EnforceNotMet;
-REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad,
-            f::NOP);
+REGISTER_OPERATOR(rowwise_add, f::NOP, f::RowWiseAddOpMaker,
+                  f::RowWiseAddGradMaker);
+REGISTER_OPERATOR(rowwise_add_grad, f::NOP);
 REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
 REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
 REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);

@@ -162,17 +171,6 @@ REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
 REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
             f::NOP);
 
-TEST(Backward, simple_op_grad) {
-  auto fwd = f::OpRegistry::CreateOp(
-      "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
-  ASSERT_NE(fwd, nullptr);
-  auto gop = f::OpRegistry::CreateGradOp(*fwd);
-  ASSERT_EQ(1UL, gop->Inputs().size());
-  ASSERT_EQ("rowwise_add_grad", gop->Type());
-  ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X")));
-  ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b")));
-}
-
 TEST(Backward, simple_op_not_need_grad) {
   auto fwd = f::OpRegistry::CreateOp(
       "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});

@@ -289,17 +287,6 @@ TEST(Backward, net_shared_weight) {
   ASSERT_EQ("sum", bwd_net->ops_[2]->Type());
 }
 
-TEST(Backward, op_register_grad_not_for_network) {
-  auto fwd =
-      f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {"b"}}},
-                              {{"mul_result", {"mul_out"}},
-                               {"add_result", {"add_out"}},
-                               {"Out", {"out1"}}},
-                              {{"temporary_index", std::vector<int>{0, 1}}});
-  ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
-}
-
 TEST(Backward, op_all_input_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp(
       "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
paddle/framework/framework.proto

@@ -66,7 +66,6 @@ message OpProto {
     optional bool duplicable = 3 [ default = false ];
     optional bool intermediate = 4 [ default = false ];
-    optional bool not_in_gradient = 5 [ default = false ];
   }
 
 // AttrProto describes the C++ type Attribute.
paddle/framework/grad_op_builder.cc (deleted; file mode 100644 → 0)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
express or implied. See the License for the specific language governing
permissions and limitations under the License. */

#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace framework {

enum class OpArgType { IN, OUT };

static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
                       bool is_grad, VariableNameMap* vars) {
  const auto& src_inout =
      src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs();
  auto& dst_inout = *vars;
  auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto();
  const auto& src_arg_list =
      src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
  for (const auto& arg : src_arg_list) {
    if (arg.not_in_gradient() && !is_grad) continue;
    const std::string src_name = arg.name();
    std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
    dst_inout[dst_name].reserve(src_inout.at(src_name).size());
    for (auto& var_name : src_inout.at(src_name)) {
      std::string s = is_grad ? GradVarName(var_name) : var_name;
      dst_inout[dst_name].emplace_back(s);
    }
  }
}

OperatorBase* BuildGradOp(const OperatorBase* op) {
  auto& info = OpInfoMap::Instance().Get(op->Type());
  PADDLE_ENFORCE(info.HasGradientOp());

  VariableNameMap inputs;
  VariableNameMap outputs;
  TransOpArg(op, OpArgType::IN, false, &inputs);   // I
  TransOpArg(op, OpArgType::OUT, false, &inputs);  // O
  TransOpArg(op, OpArgType::OUT, true, &inputs);   // OG
  TransOpArg(op, OpArgType::IN, true, &outputs);   // IG

  auto& grad_info = OpInfoMap::Instance().Get(info.grad_op_type_);
  return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs());
}

static void TransOpDescArg(const OpDescBind* src_op, const OpArgType& src_type,
                           bool is_grad, OpDescBind* dst_op,
                           const OpArgType& dst_type) {
  PADDLE_ENFORCE(dst_op != nullptr,
                 "Protobuf desc of gradient op must be initialized first.");
  const auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto();
  const auto& src_arg_list =
      src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
  for (const auto& arg : src_arg_list) {
    if (arg.not_in_gradient() && !is_grad) continue;
    const std::string src_name = arg.name();
    std::vector<std::string> vars = src_type == OpArgType::IN
                                        ? src_op->Input(src_name)
                                        : src_op->Output(src_name);
    if (is_grad) {
      for (std::string& var : vars) {
        var = GradVarName(var);
      }
    }
    std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
    dst_type == OpArgType::IN ? dst_op->SetInput(dst_name, vars)
                              : dst_op->SetOutput(dst_name, vars);
  }
}

void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op) {
  auto& info = OpInfoMap::Instance().Get(forw_op->Type());
  PADDLE_ENFORCE(info.HasGradientOp());

  grad_op->SetType(info.grad_op_type_);
  TransOpDescArg(forw_op, OpArgType::IN, false, grad_op, OpArgType::IN);
  TransOpDescArg(forw_op, OpArgType::OUT, false, grad_op, OpArgType::IN);
  TransOpDescArg(forw_op, OpArgType::OUT, true, grad_op, OpArgType::IN);
  TransOpDescArg(forw_op, OpArgType::IN, true, grad_op, OpArgType::OUT);
  grad_op->SetAttrMap(forw_op->GetAttrMap());
}

}  // namespace framework
}  // namespace paddle
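The deleted builder worked by blanket name mangling: every forward variable name was mapped to its gradient name with a fixed suffix, and not_in_gradient() was the escape hatch for arguments the gradient op did not need. A minimal sketch of that naming convention (the "@GRAD" suffix matches Paddle's GradVarName convention; the function here is an illustrative re-implementation, not the framework's):

#include <iostream>
#include <string>

// Illustrative re-implementation of the naming rule the builder relied on.
std::string GradVarName(const std::string& var_name) {
  return var_name + "@GRAD";  // gradient-variable suffix
}

int main() {
  std::cout << GradVarName("Out") << "\n";  // prints "Out@GRAD"
  return 0;
}

With compile-time GradOpMakers, each operator now states its gradient inputs and outputs explicitly instead of relying on this blanket transformation, which is why the not_in_gradient escape hatch could be deleted from framework.proto as well.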
paddle/framework/grad_op_builder.h (deleted; file mode 100644 → 0)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace framework {

OperatorBase* BuildGradOp(const OperatorBase* op);

void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op);

}  // namespace framework
}  // namespace paddle
paddle/framework/grad_op_builder_test.cc (deleted; file mode 100644 → 0)

#include "paddle/framework/grad_op_builder.h"
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"

USE_OP(sum);

namespace paddle {
namespace framework {

class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
 public:
  MutiInOutOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("In1", "a single input");
    AddInput("In2_mult", "a multiple input").AsDuplicable();
    AddInput("In3", "another single input");
    AddOutput("Out1", "a single output");
    AddOutput("Out2_mult", "a multiple output").AsDuplicable();
    AddComment("test op with multiple inputs and outputs");
  }
};

class IOIgnoredOpMaker : public OpProtoAndCheckerMaker {
 public:
  IOIgnoredOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("In1", "a single input");
    AddInput("In2_mult", "a multiple input").AsDuplicable().NotInGradient();
    AddInput("In3_mult", "another multiple input").AsDuplicable();
    AddOutput("Out1_mult", "a multiple output").AsDuplicable();
    AddOutput("Out2", "a single output").NotInGradient();
    AddComment("op with inputs and outputs ignored in gradient calculating");
  }
};

}  // namespace framework
}  // namespace paddle

namespace f = paddle::framework;

REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);

TEST(GradOpBuilder, MutiInOut) {
  std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
      "mult_io",
      {{"In1", {"in1"}},
       {"In2_mult", {"in2_1", "in2_2", "in2_3"}},
       {"In3", {"in3"}}},
      {{"Out1", {"out1"}}, {"Out2_mult", {"out2_1", "out2_2"}}}, {}));
  std::shared_ptr<f::OperatorBase> grad_test_op =
      f::OpRegistry::CreateGradOp(*test_op);

  ASSERT_EQ(grad_test_op->Inputs().size(), 3UL + 2UL + 2UL);
  EXPECT_EQ(grad_test_op->Input("In1"), "in1");
  EXPECT_EQ(grad_test_op->Inputs("In2_mult"),
            std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
  EXPECT_EQ(grad_test_op->Input("In3"), "in3");
  EXPECT_EQ(grad_test_op->Input("Out1"), "out1");
  EXPECT_EQ(grad_test_op->Inputs("Out2_mult"),
            std::vector<std::string>({"out2_1", "out2_2"}));
  EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out1")),
            f::GradVarName("out1"));
  EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out2_mult")),
            std::vector<std::string>(
                {f::GradVarName("out2_1"), f::GradVarName("out2_2")}));

  ASSERT_EQ(grad_test_op->Outputs().size(), 3UL);
  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")),
            f::GradVarName("in1"));
  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
            std::vector<std::string>({f::GradVarName("in2_1"),
                                      f::GradVarName("in2_2"),
                                      f::GradVarName("in2_3")}));
  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In3")),
            f::GradVarName("in3"));
}

TEST(GradOpBuilder, IOIgnoredInGradient) {
  std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
      "io_ignored",
      {{"In1", {"in1"}},
       {"In2_mult", {"in2_1", "in2_2"}},
       {"In3_mult", {"in3_1", "in3_2"}}},
      {{"Out1_mult", {"out1_1", "out1_2"}}, {"Out2", {"out2"}}}, {}));
  std::shared_ptr<f::OperatorBase> grad_test_op =
      f::OpRegistry::CreateGradOp(*test_op);

  // 'In2' and 'Out2' are ignored in gradient calculating
  ASSERT_EQ(grad_test_op->Inputs().size(), 2UL + 1UL + 2UL);
  EXPECT_EQ(grad_test_op->Input("In1"), "in1");
  EXPECT_EQ(grad_test_op->Inputs("In3_mult"),
            std::vector<std::string>({"in3_1", "in3_2"}));
  EXPECT_EQ(grad_test_op->Inputs("Out1_mult"),
            std::vector<std::string>({"out1_1", "out1_2"}));
  EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out1_mult")),
            std::vector<std::string>(
                {f::GradVarName("out1_1"), f::GradVarName("out1_2")}));
  EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")),
            f::GradVarName("out2"));

  ASSERT_EQ(grad_test_op->Outputs().size(), 3UL);
  EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")),
            f::GradVarName("in1"));
  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")),
            std::vector<std::string>(
                {f::GradVarName("in2_1"), f::GradVarName("in2_2")}));
  EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In3_mult")),
            std::vector<std::string>(
                {f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
}

TEST(GradOpDescBuilder, MutiInOut) {
  f::OpDescBind* forw_op = new f::OpDescBind();
  forw_op->SetType("mult_io");
  forw_op->SetInput("In1", {"in1"});
  forw_op->SetInput("In2_mult", {"in2_1", "in2_2", "in2_3"});
  forw_op->SetInput("In3", {"in3"});
  forw_op->SetOutput("Out1", {"out1"});
  forw_op->SetOutput("Out2_mult", {"out2_1", "out2_2"});

  f::OpDescBind* grad_op = new f::OpDescBind();
  f::CompleteGradOpDesc(forw_op, grad_op);

  EXPECT_EQ(grad_op->Type(), "mult_io_grad");
  ASSERT_EQ(grad_op->InputNames().size(), 3UL + 2UL + 2UL);
  EXPECT_EQ(grad_op->Input("In1"), std::vector<std::string>({"in1"}));
  EXPECT_EQ(grad_op->Input("In2_mult"),
            std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
  EXPECT_EQ(grad_op->Input("In3"), std::vector<std::string>({"in3"}));
  EXPECT_EQ(grad_op->Input("Out1"), std::vector<std::string>({"out1"}));
  EXPECT_EQ(grad_op->Input("Out2_mult"),
            std::vector<std::string>({"out2_1", "out2_2"}));
  EXPECT_EQ(grad_op->Input(f::GradVarName("Out1")),
            std::vector<std::string>({f::GradVarName("out1")}));
  EXPECT_EQ(grad_op->Input(f::GradVarName("Out2_mult")),
            std::vector<std::string>(
                {f::GradVarName("out2_1"), f::GradVarName("out2_2")}));

  ASSERT_EQ(grad_op->OutputNames().size(), 3UL);
  EXPECT_EQ(grad_op->Output(f::GradVarName("In1")),
            std::vector<std::string>({f::GradVarName("in1")}));
  EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")),
            std::vector<std::string>({f::GradVarName("in2_1"),
                                      f::GradVarName("in2_2"),
                                      f::GradVarName("in2_3")}));
  EXPECT_EQ(grad_op->Output(f::GradVarName("In3")),
            std::vector<std::string>({f::GradVarName("in3")}));
  delete forw_op;
  delete grad_op;
}

TEST(GradOpDescBuilder, IOIgnoredInGradient) {
  f::OpDescBind* forw_op = new f::OpDescBind();
  forw_op->SetType("io_ignored");
  forw_op->SetInput("In1", {"in1"});
  forw_op->SetInput("In2_mult", {"in2_1", "in2_2"});
  forw_op->SetInput("In3_mult", {"in3_1", "in3_2"});
  forw_op->SetOutput("Out1_mult", {"out1_1", "out1_2"});
  forw_op->SetOutput("Out2", {"out2"});

  f::OpDescBind* grad_op = new f::OpDescBind();
  f::CompleteGradOpDesc(forw_op, grad_op);

  EXPECT_EQ(grad_op->Type(), "io_ignored_grad");
  // 'In2' and 'Out2' are ignored in gradient calculating
  ASSERT_EQ(grad_op->InputNames().size(), 2UL + 1UL + 2UL);
  EXPECT_EQ(grad_op->Input("In1"), std::vector<std::string>({"in1"}));
  EXPECT_EQ(grad_op->Input("In3_mult"),
            std::vector<std::string>({"in3_1", "in3_2"}));
  EXPECT_EQ(grad_op->Input("Out1_mult"),
            std::vector<std::string>({"out1_1", "out1_2"}));
  EXPECT_EQ(grad_op->Input(f::GradVarName("Out1_mult")),
            std::vector<std::string>(
                {f::GradVarName("out1_1"), f::GradVarName("out1_2")}));
  EXPECT_EQ(grad_op->Input(f::GradVarName("Out2")),
            std::vector<std::string>({f::GradVarName("out2")}));

  ASSERT_EQ(grad_op->OutputNames().size(), 3UL);
  EXPECT_EQ(grad_op->Output(f::GradVarName("In1")),
            std::vector<std::string>({f::GradVarName("in1")}));
  EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")),
            std::vector<std::string>(
                {f::GradVarName("in2_1"), f::GradVarName("in2_2")}));
  EXPECT_EQ(grad_op->Output(f::GradVarName("In3_mult")),
            std::vector<std::string>(
                {f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
  delete forw_op;
  delete grad_op;
}
paddle/framework/op_desc.h

@@ -70,6 +70,22 @@ class OpDescBind {
   std::vector<std::string> InputNames() const { return MapKeys(inputs_); }
   std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
 
+  void SetInputMap(const VariableNameMap &input) {
+    this->inputs_ = input;
+    this->need_update_ = true;
+  }
+
+  void SetOutputMap(const VariableNameMap &output) {
+    this->outputs_ = output;
+    this->need_update_ = true;
+  }
+
+  void Sync();
+
+  const VariableNameMap &Inputs() const { return inputs_; }
+
+  const VariableNameMap &Outputs() const { return outputs_; }
+
  private:
   template <typename MapType>
   static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {

@@ -81,8 +97,6 @@ class OpDescBind {
     return ret_val;
   }
 
-  void Sync();
-
   OpDesc op_desc_;
   VariableNameMap inputs_;
   VariableNameMap outputs_;
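These whole-map setters exist so that a runtime operator can be snapshotted into a compile-time description, which is exactly how backward.cc's CreateGradOp feeds a GradOpMaker. A minimal standalone model of that pattern (the types here are simplified stand-ins, not Paddle's):

#include <map>
#include <string>
#include <vector>

using VariableNameMap = std::map<std::string, std::vector<std::string>>;

// Simplified stand-ins for OperatorBase and OpDescBind.
struct RuntimeOp {
  std::string type;
  VariableNameMap inputs, outputs;
};

struct Desc {
  std::string type;
  VariableNameMap inputs, outputs;
  bool need_update = false;
  void SetInputMap(const VariableNameMap& in) {
    inputs = in;
    need_update = true;  // mark the cached protobuf as stale, as Sync() expects
  }
  void SetOutputMap(const VariableNameMap& out) {
    outputs = out;
    need_update = true;
  }
};

// Snapshot a runtime op into a description in one shot, map by map.
Desc Snapshot(const RuntimeOp& op) {
  Desc d;
  d.type = op.type;
  d.SetInputMap(op.inputs);
  d.SetOutputMap(op.outputs);
  return d;
}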
paddle/framework/op_info.h

@@ -17,6 +17,7 @@
 #include <map>
 #include <string>
 #include <unordered_map>
 #include "paddle/framework/attribute.h"
+#include "paddle/framework/op_desc.h"
 #include "paddle/framework/type_defs.h"

@@ -27,7 +28,6 @@ namespace framework {
 struct OpInfo {
   OpCreator creator_;
-  std::string grad_op_type_;
   GradOpMakerFN grad_op_maker_;
   OpProto* proto_{nullptr};
   OpAttrChecker* checker_{nullptr};

@@ -43,19 +43,19 @@ struct OpInfo {
     return *proto_;
   }
 
-  const OpAttrChecker& Checker() const {
-    PADDLE_ENFORCE_NOT_NULL(checker_,
-                            "Operator Checker has not been registered");
-    return *checker_;
-  }
-
   const OpCreator& Creator() const {
     PADDLE_ENFORCE_NOT_NULL(creator_,
                             "Operator Creator has not been registered");
     return creator_;
   }
 
-  bool HasGradientOp() const { return !grad_op_type_.empty(); }
+  const GradOpMakerFN& GradOpMaker() const {
+    PADDLE_ENFORCE_NOT_NULL(grad_op_maker_,
+                            "Operator GradOpMaker has not been registered.");
+    return grad_op_maker_;
+  }
+
+  const OpAttrChecker* Checker() const { return checker_; }
 };
 
 class OpInfoMap {
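Note the contract change on Checker(): the old accessor enforced registration and returned a reference, while the new one returns a possibly-null pointer and leaves the decision to the caller (see the matching change in op_registry.cc). A standalone sketch of the two styles, with illustrative stand-in types:

#include <stdexcept>

struct AttrChecker {
  void Check() const {}
};

struct OpInfoOld {
  AttrChecker* checker = nullptr;
  const AttrChecker& Checker() const {  // old style: enforce non-null
    if (checker == nullptr) throw std::runtime_error("Checker not registered");
    return *checker;
  }
};

struct OpInfoNew {
  AttrChecker* checker = nullptr;
  const AttrChecker* Checker() const { return checker; }  // new: may be null
};

void CreateOp(const OpInfoNew& info) {
  if (info.Checker() != nullptr) {
    info.Checker()->Check();  // attribute checking becomes optional
  }
}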
paddle/framework/op_proto_maker.h

@@ -44,11 +44,6 @@ class OpProtoAndCheckerMaker {
       var_->set_intermediate(true);
       return *this;
     }
-
-    VariableBuilder& NotInGradient() {
-      var_->set_not_in_gradient(true);
-      return *this;
-    }
   };
 
   VariableBuilder AddInput(const std::string& name, const std::string& comment);
paddle/framework/op_registry.cc

@@ -23,7 +23,9 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
     const std::string& type, const VariableNameMap& inputs,
     const VariableNameMap& outputs, AttributeMap attrs) {
   auto& info = OpInfoMap::Instance().Get(type);
-  info.Checker().Check(attrs);
+  if (info.Checker() != nullptr) {
+    info.Checker()->Check(attrs);
+  }
   auto op = info.Creator()(type, inputs, outputs, attrs);
   return std::unique_ptr<OperatorBase>(op);
 }

@@ -52,9 +54,9 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
   return CreateOp(op_desc.type(), inputs, outputs, attrs);
 }
 
-std::unique_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) {
-  PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops");
-  return std::unique_ptr<OperatorBase>(BuildGradOp(&op));
+std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
+  return CreateOp(op_desc.Type(), op_desc.Inputs(), op_desc.Outputs(),
+                  op_desc.GetAttrMap());
 }
 
 }  // namespace framework
paddle/framework/op_registry.h

@@ -23,25 +23,37 @@ limitations under the License. */
 #include "paddle/framework/attribute.h"
 #include "paddle/framework/details/op_registry.h"
 #include "paddle/framework/framework.pb.h"
-#include "paddle/framework/grad_op_builder.h"
+#include "paddle/framework/grad_op_desc_maker.h"
+#include "paddle/framework/op_desc.h"
 #include "paddle/framework/operator.h"
 #include "paddle/framework/scope.h"
 
 namespace paddle {
 namespace framework {
 
+class Registrar {
+ public:
+  // In our design, various kinds of classes, e.g., operators and kernels,
+  // have their corresponding registry and registrar. The action of
+  // registration is in the constructor of a global registrar variable, which,
+  // however, are not used in the code that calls package framework, and would
+  // be removed from the generated binary file by the linker. To avoid such
+  // removal, we add Touch to all registrar classes and make USE_OP macros to
+  // call this method. So, as long as the callee code calls USE_OP, the global
+  // registrar variable won't be removed by the linker.
+  void Touch() {}
+};
+
 template <typename... ARGS>
-struct OperatorRegistrar {
+struct OperatorRegistrar : public Registrar {
   explicit OperatorRegistrar(const char* op_type) : op_type(op_type) {
     PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
                    "'%s' is registered more than once.", op_type);
     static_assert(sizeof...(ARGS) != 0,
                   "OperatorRegistrar should be invoked at least by OpClass");
     details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info);
+    OpInfoMap::Instance().Insert(op_type, info);
   }
 
-  ~OperatorRegistrar() { OpInfoMap::Instance().Insert(op_type, info); }
-
   const char* op_type;
 
   OpInfo info;

@@ -67,20 +79,7 @@ class OpRegistry {
   static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
-  static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);
-};
-
-class Registrar {
- public:
-  // In our design, various kinds of classes, e.g., operators and kernels,
-  // have their corresponding registry and registrar. The action of
-  // registration is in the constructor of a global registrar variable, which,
-  // however, are not used in the code that calls package framework, and would
-  // be removed from the generated binary file by the linker. To avoid such
-  // removal, we add Touch to all registrar classes and make USE_OP macros to
-  // call this method. So, as long as the callee code calls USE_OP, the global
-  // registrar variable won't be removed by the linker.
-  void Touch() {}
+  static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
 };
 
 template <typename OpType, typename ProtoMakerType, typename GradOpType>

@@ -138,33 +137,41 @@ class OpKernelRegistrar : public Registrar {
                        __test_global_namespace_##uniq_name##__>::value, \
       msg)
 
+#define REGISTER_OPERATOR(op_type, op_class, ...)                      \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                      \
+      __reg_op__##op_type,                                             \
+      "REGISTER_OPERATOR must be called in global namespace");         \
+  class _OpClass_##op_type##_ : public op_class {                      \
+   public:                                                             \
+    DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_);                     \
+    DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class);            \
+  };                                                                   \
+  static ::paddle::framework::OperatorRegistrar<_OpClass_##op_type##_, \
+                                                ##__VA_ARGS__>         \
+      __op_registrar_##op_type##__(#op_type);                          \
+  int TouchOpRegistrar_##op_type() {                                   \
+    __op_registrar_##op_type##__.Touch();                              \
+    return 0;                                                          \
+  }
+
 /**
  * Macro to register Operator.
  */
-#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type,         \
-                    grad_op_class)                                           \
-  STATIC_ASSERT_GLOBAL_NAMESPACE(                                            \
-      __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
-  class _OpClass_##op_type##_ : public op_class {                            \
-   public:                                                                   \
-    DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_);                           \
-    DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class);                  \
-  };                                                                         \
-  class _OpGradClass_##op_type##_ : public grad_op_class {                   \
-   public:                                                                   \
-    DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_);                       \
-    DEFINE_OP_CONSTRUCTOR(_OpGradClass_##op_type##_, grad_op_class);         \
-  };                                                                         \
-  static ::paddle::framework::OpRegistrar<                                   \
-      _OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_>      \
-      __op_registrar_##op_type##__(#op_type, #grad_op_type);                 \
-  int TouchOpRegistrar_##op_type() {                                         \
-    __op_registrar_##op_type##__.Touch();                                    \
-    return 0;                                                                \
-  }
+#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type,          \
+                    grad_op_class)                                            \
+  REGISTER_OPERATOR(grad_op_type, grad_op_class);                             \
+  class _GradOpDescMaker_##grad_op_type##_                                    \
+      : public ::paddle::framework::DefaultGradOpDescMaker {                  \
+    using ::paddle::framework::DefaultGradOpDescMaker::DefaultGradOpDescMaker;\
+                                                                              \
+   protected:                                                                 \
+    virtual std::string GradOpType() const { return #grad_op_type; }          \
+  };                                                                          \
+  REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_,    \
+                    op_maker_class);
 
 #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
-  REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)
+  REGISTER_OPERATOR(op_type, op_class, op_maker_class)
 
 /**
  * Macro to register OperatorKernel.
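To see what the rewritten REGISTER_OP buys, here is a hand expansion (approximate, whitespace aside) of REGISTER_OP(scale, ScaleOp, ScaleOpMaker, scale_grad, ScaleGradOp) under the new macro: the gradient op becomes an ordinary registered operator, and the only glue is a DefaultGradOpDescMaker subclass that supplies the grad op's type name.

// Approximate hand expansion of the new REGISTER_OP; illustrative only.
REGISTER_OPERATOR(scale_grad, ScaleGradOp);
class _GradOpDescMaker_scale_grad_
    : public ::paddle::framework::DefaultGradOpDescMaker {
  using ::paddle::framework::DefaultGradOpDescMaker::DefaultGradOpDescMaker;

 protected:
  virtual std::string GradOpType() const { return "scale_grad"; }
};
REGISTER_OPERATOR(scale, ScaleOp, _GradOpDescMaker_scale_grad_, ScaleOpMaker);

Operators that need custom gradient wiring skip REGISTER_OP entirely and pass their own maker to REGISTER_OPERATOR, as mean_op.cc and minus_op.cc do below.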
paddle/operators/mean_op.cc

@@ -36,7 +36,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
   MeanOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "The input of mean op");
-    AddOutput("Out", "The output of mean op").NotInGradient();
+    AddOutput("Out", "The output of mean op");
     AddComment(R"DOC( Mean Operator
 )DOC");
   }

@@ -52,11 +52,27 @@ class MeanGradOp : public framework::OperatorWithKernel {
   }
 };
 
+class MeanGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDescBind> Apply() const override {
+    auto* grad_op = new framework::OpDescBind();
+    grad_op->SetType("mean_grad");
+    grad_op->SetInput("X", Input("X"));
+    grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    return std::unique_ptr<framework::OpDescBind>(grad_op);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp);
+REGISTER_OPERATOR(mean, ops::MeanOp, ops::MeanOpMaker, ops::MeanGradMaker);
+REGISTER_OPERATOR(mean_grad, ops::MeanGradOp);
 REGISTER_OP_CPU_KERNEL(mean,
                        ops::MeanKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(mean_grad,
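A worked example of what MeanGradMaker produces (variable names are hypothetical, and the @GRAD suffix assumes InputGrad/OutputGrad derive names via GradVarName): given a forward desc mean(X=["x"], Out=["out"]), Apply() returns a desc equivalent to

  type:    mean_grad
  inputs:  X = ["x"], Out@GRAD = ["out@GRAD"]
  outputs: X@GRAD = ["x@GRAD"]

so the gradient op sees the forward input, the incoming output gradient, and nothing else.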
paddle/operators/minus_op.cc

@@ -49,9 +49,9 @@ class MinusOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   MinusOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The left tensor of minus operator.").NotInGradient();
-    AddInput("Y", "The right tensor of minus operator.").NotInGradient();
-    AddOutput("Out", "The output tensor of minus operator.").NotInGradient();
+    AddInput("X", "The left tensor of minus operator.");
+    AddInput("Y", "The right tensor of minus operator.");
+    AddOutput("Out", "The output tensor of minus operator.");
 
     AddComment(R"DOC(Minus Operator

@@ -64,26 +64,35 @@ or not. But the output only shares the LoD with input `X`.
 )DOC");
   }
 };
-template <typename AttrType>
-class MinusGradOp : public NetOp {
+
+class MinusGradMaker : public framework::GradOpDescMakerBase {
  public:
-  MinusGradOp(const std::string& type, const framework::VariableNameMap& inputs,
-              const framework::VariableNameMap& outputs,
-              const framework::AttributeMap& attrs)
-      : NetOp(type, inputs, outputs, attrs) {
-    auto out_grad = Input(framework::GradVarName("Out"));
-    auto x_grad = Output(framework::GradVarName("X"));
-    auto y_grad = Output(framework::GradVarName("Y"));
-
-    // x_grad = out_grad
-    AppendOp(framework::OpRegistry::CreateOp("identity", {{"X", {out_grad}}},
-                                             {{"Y", {x_grad}}}, {}));
-
-    framework::AttributeMap scale_attr;
-    scale_attr["scale"] = static_cast<AttrType>(-1);
-    AppendOp(framework::OpRegistry::CreateOp("scale", {{"X", {out_grad}}},
-                                             {{"Out", {y_grad}}}, scale_attr));
-    CompleteAddOp(false);
-  }
+  using framework::GradOpDescMakerBase::GradOpDescMakerBase;
+
+  std::vector<std::unique_ptr<framework::OpDescBind>> operator()()
+      const override {
+    std::vector<std::unique_ptr<framework::OpDescBind>> ops;
+    auto x_g = InputGrad("X");
+    if (!x_g.empty()) {
+      auto* x_g_op = new framework::OpDescBind();
+      x_g_op->SetType("scale");
+      x_g_op->SetInput("X", OutputGrad("Out"));
+      x_g_op->SetOutput("Out", x_g);
+      x_g_op->SetAttr("scale", 1.0f);
+      ops.emplace_back(x_g_op);
+    }
+
+    auto y_g = InputGrad("Y");
+    if (!y_g.empty()) {
+      auto* y_g_op = new framework::OpDescBind();
+      y_g_op->SetType("scale");
+      y_g_op->SetInput("X", OutputGrad("Out"));
+      y_g_op->SetOutput("Out", y_g);
+      y_g_op->SetAttr("scale", -1.0f);
+      ops.emplace_back(y_g_op);
+    }
+    return ops;
+  }
 };

@@ -91,7 +100,6 @@ class MinusGradOp : public NetOp {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad,
-            ops::MinusGradOp<float>);
+REGISTER_OPERATOR(minus, ops::MinusOp, ops::MinusOpMaker, ops::MinusGradMaker);
 REGISTER_OP_CPU_KERNEL(minus,
                        ops::MinusKernel<paddle::platform::CPUPlace, float>);
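Because MinusGradMaker consults InputGrad(...) before emitting each branch, inputs excluded from differentiation simply produce no operator. For example (illustrative names): for minus(X=["a"], Y=["b"], Out=["o"]) where only "a" requires a gradient, operator() returns the single desc scale(X=["o@GRAD"], Out=["a@GRAD"], scale=1.0) and skips the scale=-1.0 branch for "b@GRAD" entirely; the old NetOp-based MinusGradOp always generated both branches.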
paddle/operators/pad_op.cc

@@ -56,8 +56,7 @@ class PadOpMaker : public framework::OpProtoAndCheckerMaker {
              "The input should be a k-D tensor(k > 0 and k < 7)");
     AddOutput("Out",
               "The output of pad op."
-              "A tensor with the same shape as X.")
-        .NotInGradient();
+              "A tensor with the same shape as X.");
     AddComment(R"DOC(
 Pad input into output, as specified by paddings and pad_value. The input should be a k-D tensor(k > 0 and k < 7). As an example:

@@ -111,11 +110,29 @@ class PadOpGrad : public framework::OperatorWithKernel {
   }
 };
 
+class PadOpGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDescBind> Apply() const override {
+    auto* bind = new framework::OpDescBind();
+    bind->SetInput("X", Input("X"));
+    bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    bind->SetAttrMap(Attrs());
+    bind->SetType("pad_grad");
+    return std::unique_ptr<framework::OpDescBind>(bind);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
-REGISTER_OP(pad, ops::PadOp, ops::PadOpMaker, pad_grad, ops::PadOpGrad);
+REGISTER_OPERATOR(pad, ops::PadOp, ops::PadOpMaker, ops::PadOpGradMaker);
+REGISTER_OPERATOR(pad_grad, ops::PadOpGrad);
 REGISTER_OP_CPU_KERNEL(pad, ops::PadKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(pad_grad,
                        ops::PadGradKernel<paddle::platform::CPUPlace, float>);
paddle/operators/scale_op.cc

@@ -41,8 +41,8 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The input tensor of scale operator.").NotInGradient();
-    AddOutput("Out", "The output tensor of scale operator.").NotInGradient();
+    AddInput("X", "The input tensor of scale operator.");
+    AddOutput("Out", "The output tensor of scale operator.");
     AddComment(R"DOC(Scale operator
 
 The equation is: Out = scale*X

@@ -52,21 +52,18 @@ The equation is: Out = scale*X
   }
 };
 
-// The operator to calculate gradients of a scale operator is just the scale
-// operator itself.
-// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out))
-template <typename AttrType>
-class ScaleGradOp : public NetOp {
+class ScaleGradMaker : public framework::SingleGradOpDescMaker {
  public:
-  ScaleGradOp(const std::string &type, const framework::VariableNameMap &inputs,
-              const framework::VariableNameMap &outputs,
-              const framework::AttributeMap &attrs)
-      : NetOp(type, inputs, outputs, attrs) {
-    AppendOp(framework::OpRegistry::CreateOp(
-        "scale", {{"X", {Input(framework::GradVarName("Out"))}}},
-        {{"Out", {Output(framework::GradVarName("X"))}}},
-        {{"scale", Attr<AttrType>("scale")}}));
-    CompleteAddOp(false);
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDescBind> Apply() const override {
+    auto *grad_op = new framework::OpDescBind();
+    grad_op->SetType("scale");
+    grad_op->SetInput("X", OutputGrad("Out"));
+    grad_op->SetOutput("Out", InputGrad("X"));
+    grad_op->SetAttr("scale", GetAttr("scale"));
+    return std::unique_ptr<framework::OpDescBind>(grad_op);
   }
 };

@@ -75,7 +72,7 @@ class ScaleGradOp : public NetOp {
 namespace ops = paddle::operators;
 
-REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker<float>, scale_grad,
-            ops::ScaleGradOp<float>);
+REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker<float>,
+                  ops::ScaleGradMaker);
 REGISTER_OP_CPU_KERNEL(scale,
                        ops::ScaleKernel<paddle::platform::CPUPlace, float>);
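The maker is a direct transcription of the comment the old code carried: for Out = scale * X, the chain rule gives

  dL/dX = scale * dL/dOut

so the gradient of a scale op is another scale op with the same "scale" attribute, reading the gradient of Out and writing the gradient of X.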
paddle/operators/softmax_with_cross_entropy_op.cc

@@ -14,6 +14,7 @@
 #include "paddle/operators/softmax_with_cross_entropy_op.h"
 #include <paddle/function/TensorType.h>
+#include <iostream>
 
 namespace paddle {
 namespace operators {

@@ -27,15 +28,14 @@ class SoftmaxWithCrossEntropyOpMaker
     AddInput("Logits",
              "(Tensor, default: Tensor<float>), The unscaled log probabilities "
              "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
-             "and K is the class number.")
-        .NotInGradient();
-    AddInput("Label",
-             "(Tensor, default: Tensor<int>), The ground truth which is a 2-D "
-             "tensor. "
-             "If softLable is set to 0, Label is a Tensor<int> with shape [N x "
-             "1]. "
-             "If softLable is set to 1, Label is a Tensor<float/double> "
-             "with shape [N x K].");
+             "and K is the class number.");
+    AddInput(
+        "Label",
+        "(Tensor, default: Tensor<int>), The ground truth which is a 2-D "
+        "tensor. "
+        "If softLable is set to 0, Label is a Tensor<int> with shape [N x 1]. "
+        "If softLable is set to 1, Label is a Tensor<float/double> "
+        "with shape [N x K].");
     AddOutput(
         "Softmax",
         "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "

@@ -163,15 +163,34 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
   }
 };
 
+class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDescBind> Apply() const override {
+    auto* grad_op = new framework::OpDescBind();
+    grad_op->SetType("softmax_with_cross_entropy_grad");
+    grad_op->SetInput("Label", Input("Label"));
+    grad_op->SetInput("Softmax", Output("Softmax"));
+    grad_op->SetInput("Loss", Output("Loss"));
+    grad_op->SetInput(framework::GradVarName("Softmax"), OutputGrad("Softmax"));
+    grad_op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
+    grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
+    grad_op->SetAttrMap(Attrs());
+    return std::unique_ptr<framework::OpDescBind>(grad_op);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
-REGISTER_OP(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
-            ops::SoftmaxWithCrossEntropyOpMaker,
-            softmax_with_cross_entropy_grad,
-            ops::SoftmaxWithCrossEntropyOpGrad);
+REGISTER_OPERATOR(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
+                  ops::SoftmaxWithCrossEntropyOpMaker, ops::SoftmaxGradMaker);
+REGISTER_OPERATOR(softmax_with_cross_entropy_grad,
+                  ops::SoftmaxWithCrossEntropyOpGrad);
 REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
                        ops::SoftmaxWithCrossEntropyKernel<float>);
 REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
paddle/operators/sum_op.cc

@@ -45,10 +45,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SumOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "the input tensors of sum operator.")
-        .AsDuplicable()
-        .NotInGradient();
-    AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
+    AddInput("X", "the input tensors of sum operator.").AsDuplicable();
+    AddOutput("Out", "the output tensor of sum operator.");
     AddComment(R"DOC(
 Sum the input tensors.

@@ -58,23 +56,26 @@ or not. But the output only shares the LoD with the first input.
   }
 };
 
-class SumGradOp : public NetOp {
+class SumGradMaker : public framework::GradOpDescMakerBase {
  public:
-  SumGradOp(const std::string& type, const framework::VariableNameMap& inputs,
-            const framework::VariableNameMap& outputs,
-            const framework::AttributeMap& attrs)
-      : NetOp(type, inputs, outputs, attrs) {
-    auto& x_grad_names = Outputs(framework::GradVarName("X"));
-    auto out_grad_name = this->Input(framework::GradVarName("Out"));
-
-    framework::AttributeMap grad_attrs;
-    grad_attrs["scale"] = 1.0f;
-    for (auto& x_grad_name : x_grad_names) {
-      AppendOp(framework::OpRegistry::CreateOp(
-          "scale", {{"X", {out_grad_name}}}, {{"Out", {x_grad_name}}},
-          grad_attrs));
-    }
-    CompleteAddOp(false);
-  }
+  using framework::GradOpDescMakerBase::GradOpDescMakerBase;
+
+  std::vector<std::unique_ptr<framework::OpDescBind>> operator()()
+      const override {
+    auto x_grads = InputGrad("X");
+    std::vector<std::unique_ptr<framework::OpDescBind>> grad_ops;
+    grad_ops.reserve(x_grads.size());
+    auto og = OutputGrad("Out");
+    std::transform(x_grads.begin(), x_grads.end(),
+                   std::back_inserter(grad_ops),
+                   [&og](const std::string& x_grad) {
+                     auto* grad_op = new framework::OpDescBind();
+                     grad_op->SetType("scale");
+                     grad_op->SetInput("X", og);
+                     grad_op->SetOutput("Out", {x_grad});
+                     grad_op->SetAttr("scale", 1.0f);
+                     return std::unique_ptr<framework::OpDescBind>(grad_op);
+                   });
+    return grad_ops;
+  }
 };

@@ -82,5 +83,6 @@ class SumGradOp : public NetOp {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
+REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker);
 REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);
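A worked example of SumGradMaker (illustrative names, @GRAD suffix assumed from GradVarName): for sum(X=["a","b","c"], Out=["o"]), operator() returns three descs, one per input gradient:

  scale(X=["o@GRAD"], Out=["a@GRAD"], scale=1.0)
  scale(X=["o@GRAD"], Out=["b@GRAD"], scale=1.0)
  scale(X=["o@GRAD"], Out=["c@GRAD"], scale=1.0)

CreateGradOp in backward.cc then wraps the three into a single NetOp, since the maker returned more than one desc.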