Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
53542d93
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
53542d93
编写于
10月 12, 2017
作者:
Q
qijun
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'baidu/develop' into selected_rows
上级
c49adb86
6316b40a
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
158 addition
and
69 deletion
+158
-69
CMakeLists.txt
CMakeLists.txt
+6
-0
cmake/external/eigen.cmake
cmake/external/eigen.cmake
+1
-1
cmake/external/gflags.cmake
cmake/external/gflags.cmake
+3
-2
cmake/external/glog.cmake
cmake/external/glog.cmake
+3
-2
cmake/external/gtest.cmake
cmake/external/gtest.cmake
+2
-2
cmake/external/protobuf.cmake
cmake/external/protobuf.cmake
+2
-2
cmake/external/warpctc.cmake
cmake/external/warpctc.cmake
+3
-2
cmake/external/zlib.cmake
cmake/external/zlib.cmake
+2
-2
paddle/framework/backward.cc
paddle/framework/backward.cc
+9
-10
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+62
-8
paddle/framework/details/op_registry.h
paddle/framework/details/op_registry.h
+4
-2
paddle/framework/executor_test.cc
paddle/framework/executor_test.cc
+10
-0
paddle/framework/grad_op_desc_maker.h
paddle/framework/grad_op_desc_maker.h
+34
-12
paddle/framework/op_registry.cc
paddle/framework/op_registry.cc
+0
-6
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+12
-14
paddle/framework/type_defs.h
paddle/framework/type_defs.h
+2
-2
paddle/operators/multiplex_op.cc
paddle/operators/multiplex_op.cc
+3
-2
未找到文件。
CMakeLists.txt
浏览文件 @
53542d93
...
...
@@ -105,6 +105,12 @@ if (WITH_C_API AND WITH_PYTHON)
"different Python interpreter from compiling."
)
endif
()
if
(
MOBILE_INFERENCE
)
set
(
THIRD_PARTY_BUILD_TYPE MinSizeRel
)
else
()
set
(
THIRD_PARTY_BUILD_TYPE Release
)
endif
()
########################################################################################
include
(
external/mklml
)
# download mklml package
...
...
cmake/external/eigen.cmake
浏览文件 @
53542d93
...
...
@@ -8,7 +8,7 @@ ExternalProject_Add(
extern_eigen3
${
EXTERNAL_PROJECT_LOG_ARGS
}
GIT_REPOSITORY
"https://github.com/RLovelett/eigen.git"
GIT_TAG
"master"
GIT_TAG
4e79cb69b9425f5f8c3a84be4350d4ab75b5fd9d
PREFIX
${
EIGEN_SOURCE_DIR
}
UPDATE_COMMAND
""
CONFIGURE_COMMAND
""
...
...
cmake/external/gflags.cmake
浏览文件 @
53542d93
...
...
@@ -36,6 +36,7 @@ ExternalProject_Add(
# change this back to the official Github repo once my PR is
# merged.
GIT_REPOSITORY
"https://github.com/wangkuiyi/gflags.git"
GIT_TAG 986964c07427ecb9cdb5bd73f73ebbd40e54dadb
PREFIX
${
GFLAGS_SOURCES_DIR
}
UPDATE_COMMAND
""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
...
...
@@ -45,11 +46,11 @@ ExternalProject_Add(
-DCMAKE_INSTALL_PREFIX=
${
GFLAGS_INSTALL_DIR
}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DBUILD_TESTING=OFF
-DCMAKE_BUILD_TYPE=
Release
-DCMAKE_BUILD_TYPE=
${
THIRD_PARTY_BUILD_TYPE
}
${
EXTERNAL_OPTIONAL_ARGS
}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=
${
GFLAGS_INSTALL_DIR
}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=
Release
-DCMAKE_BUILD_TYPE:STRING=
${
THIRD_PARTY_BUILD_TYPE
}
)
ADD_LIBRARY
(
gflags STATIC IMPORTED GLOBAL
)
...
...
cmake/external/glog.cmake
浏览文件 @
53542d93
...
...
@@ -31,6 +31,7 @@ ExternalProject_Add(
${
EXTERNAL_PROJECT_LOG_ARGS
}
DEPENDS gflags
GIT_REPOSITORY
"https://github.com/google/glog.git"
GIT_TAG v0.3.5
PREFIX
${
GLOG_SOURCES_DIR
}
UPDATE_COMMAND
""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
...
...
@@ -43,12 +44,12 @@ ExternalProject_Add(
-DWITH_GFLAGS=ON
-Dgflags_DIR=
${
GFLAGS_INSTALL_DIR
}
/lib/cmake/gflags
-DBUILD_TESTING=OFF
-DCMAKE_BUILD_TYPE=
Release
-DCMAKE_BUILD_TYPE=
${
THIRD_PARTY_BUILD_TYPE
}
${
EXTERNAL_OPTIONAL_ARGS
}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=
${
GLOG_INSTALL_DIR
}
-DCMAKE_INSTALL_LIBDIR:PATH=
${
GLOG_INSTALL_DIR
}
/lib
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=
Release
-DCMAKE_BUILD_TYPE:STRING=
${
THIRD_PARTY_BUILD_TYPE
}
)
ADD_LIBRARY
(
glog STATIC IMPORTED GLOBAL
)
...
...
cmake/external/gtest.cmake
浏览文件 @
53542d93
...
...
@@ -56,11 +56,11 @@ IF(WITH_TESTING)
-DBUILD_GMOCK=ON
-Dgtest_disable_pthreads=ON
-Dgtest_force_shared_crt=ON
-DCMAKE_BUILD_TYPE=
Release
-DCMAKE_BUILD_TYPE=
${
THIRD_PARTY_BUILD_TYPE
}
${
EXTERNAL_OPTIONAL_ARGS
}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=
${
GTEST_INSTALL_DIR
}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=
Release
-DCMAKE_BUILD_TYPE:STRING=
${
THIRD_PARTY_BUILD_TYPE
}
)
ADD_LIBRARY
(
gtest STATIC IMPORTED GLOBAL
)
...
...
cmake/external/protobuf.cmake
浏览文件 @
53542d93
...
...
@@ -191,12 +191,12 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
${
OPTIONAL_ARGS
}
-Dprotobuf_BUILD_TESTS=OFF
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=
Release
-DCMAKE_BUILD_TYPE=
${
THIRD_PARTY_BUILD_TYPE
}
-DCMAKE_INSTALL_PREFIX=
${
PROTOBUF_INSTALL_DIR
}
-DCMAKE_INSTALL_LIBDIR=lib
CMAKE_CACHE_ARGS
-DCMAKE_INSTALL_PREFIX:PATH=
${
PROTOBUF_INSTALL_DIR
}
-DCMAKE_BUILD_TYPE:STRING=
Release
-DCMAKE_BUILD_TYPE:STRING=
${
THIRD_PARTY_BUILD_TYPE
}
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
${
OPTIONAL_CACHE_ARGS
}
...
...
cmake/external/warpctc.cmake
浏览文件 @
53542d93
...
...
@@ -35,6 +35,7 @@ ExternalProject_Add(
extern_warpctc
${
EXTERNAL_PROJECT_LOG_ARGS
}
GIT_REPOSITORY
"https://github.com/gangliao/warp-ctc.git"
GIT_TAG b63a0644654a3e0ed624c85a1767bc8193aead09
PREFIX
${
WARPCTC_SOURCES_DIR
}
UPDATE_COMMAND
""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=
${
CMAKE_CXX_COMPILER
}
...
...
@@ -48,9 +49,9 @@ ExternalProject_Add(
-DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON
-DBUILD_SHARED=ON
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=
Release
-DCMAKE_BUILD_TYPE=
${
THIRD_PARTY_BUILD_TYPE
}
${
EXTERNAL_OPTIONAL_ARGS
}
CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=
Release
CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=
${
THIRD_PARTY_BUILD_TYPE
}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_INSTALL_PREFIX:PATH=
${
WARPCTC_INSTALL_DIR
}
)
...
...
cmake/external/zlib.cmake
浏览文件 @
53542d93
...
...
@@ -42,11 +42,11 @@ ExternalProject_Add(
-DBUILD_SHARED_LIBS=OFF
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_MACOSX_RPATH=ON
-DCMAKE_BUILD_TYPE=
Release
-DCMAKE_BUILD_TYPE=
${
THIRD_PARTY_BUILD_TYPE
}
${
EXTERNAL_OPTIONAL_ARGS
}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=
${
ZLIB_INSTALL_DIR
}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=
Release
-DCMAKE_BUILD_TYPE:STRING=
${
THIRD_PARTY_BUILD_TYPE
}
)
LIST
(
APPEND external_project_dependencies zlib
)
...
...
paddle/framework/backward.cc
浏览文件 @
53542d93
...
...
@@ -28,14 +28,15 @@ namespace paddle {
namespace
framework
{
static
inline
std
::
unique_ptr
<
OperatorBase
>
CreateGradOp
(
const
OperatorBase
&
op
)
{
const
OperatorBase
&
op
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
)
{
OpDescBind
op_desc
;
op_desc
.
SetInputMap
(
op
.
Inputs
());
op_desc
.
SetOutputMap
(
op
.
Outputs
());
op_desc
.
SetType
(
op
.
Type
());
op_desc
.
SetAttrMap
(
op
.
Attrs
());
auto
&
info
=
OpInfoMap
::
Instance
().
Get
(
op
.
Type
());
auto
grad_descs
=
info
.
GradOpMaker
()(
op_desc
);
auto
grad_descs
=
info
.
GradOpMaker
()(
op_desc
,
no_grad_set
);
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>
grad_ops
;
grad_ops
.
reserve
(
grad_descs
.
size
());
std
::
transform
(
grad_descs
.
begin
(),
grad_descs
.
end
(),
...
...
@@ -187,7 +188,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
net
->
InsertOp
(
pos
.
first
+
1
,
std
::
move
(
pos
.
second
));
}
}
else
{
std
::
unique_ptr
<
OperatorBase
>
grad_op
(
CreateGradOp
(
forwardOp
));
std
::
unique_ptr
<
OperatorBase
>
grad_op
(
CreateGradOp
(
forwardOp
,
no_grad_names
));
ForEachVarName
(
grad_op
->
Inputs
(),
[
&
no_grad_names
,
&
net
,
&
grad_op
](
const
std
::
string
&
grad_input
)
{
...
...
@@ -272,7 +274,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
const
std
::
unique_ptr
<
OpDescBind
>&
op_desc
,
std
::
unordered_set
<
std
::
string
>&
no_grad_vars
)
{
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
grad_op_descs
;
// All input gradients of forwarding operator do not need to calculat.
// All input gradients of forwarding operator do not need to calculat
e
.
const
std
::
vector
<
std
::
string
>&
inputs
=
op_desc
->
InputArgumentNames
();
if
(
AllGradInSet
(
inputs
,
no_grad_vars
))
{
return
grad_op_descs
;
// empty vector
...
...
@@ -286,7 +288,9 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return
grad_op_descs
;
// empty vector
}
grad_op_descs
=
OpRegistry
::
CreateGradOpDescs
(
op_desc
.
get
());
grad_op_descs
=
OpInfoMap
::
Instance
()
.
Get
(
op_desc
->
Type
())
.
GradOpMaker
()(
*
op_desc
,
no_grad_vars
);
std
::
list
<
std
::
unique_ptr
<
OpDescBind
>>
pending_fill_zeros_ops
;
for
(
auto
&
desc
:
grad_op_descs
)
{
...
...
@@ -301,11 +305,6 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
pending_fill_zeros_ops
.
push_back
(
std
::
move
(
fill_zeros_op
));
}
}
for
(
const
std
::
string
&
out_name
:
desc
->
OutputArgumentNames
())
{
if
(
no_grad_vars
.
count
(
out_name
))
{
desc
->
Rename
(
out_name
,
kEmptyVarName
);
}
}
}
for
(
auto
&
p
:
pending_fill_zeros_ops
)
{
...
...
paddle/framework/backward_test.cc
浏览文件 @
53542d93
...
...
@@ -169,6 +169,45 @@ class MultInOutOpMaker : public OpProtoAndCheckerMaker {
}
};
class
MinusGradOpDescMaker
:
public
GradOpDescMakerBase
{
public:
using
GradOpDescMakerBase
::
GradOpDescMakerBase
;
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
operator
()()
const
override
{
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
retv
;
auto
x_g
=
InputGrad
(
"X"
);
if
(
!
x_g
.
empty
())
{
auto
*
op_desc
=
new
OpDescBind
();
op_desc
->
SetType
(
"scale"
);
op_desc
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
op_desc
->
SetOutput
(
"Out"
,
x_g
);
op_desc
->
SetAttr
(
"scale"
,
1.0
f
);
retv
.
emplace_back
(
op_desc
);
}
auto
y_g
=
InputGrad
(
"Y"
);
if
(
!
y_g
.
empty
())
{
auto
*
op_desc
=
new
OpDescBind
();
op_desc
->
SetType
(
"scale"
);
op_desc
->
SetInput
(
"X"
,
OutputGrad
(
"Out"
));
op_desc
->
SetOutput
(
"Out"
,
y_g
);
op_desc
->
SetAttr
(
"scale"
,
-
1.0
f
);
retv
.
emplace_back
(
op_desc
);
}
return
retv
;
}
};
class
MinusOpMaker
:
public
OpProtoAndCheckerMaker
{
public:
MinusOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
""
);
AddInput
(
"Y"
,
""
);
AddOutput
(
"Out"
,
""
);
AddComment
(
"minus for unittest"
);
}
};
}
// namespace framework
}
// namespace paddle
...
...
@@ -187,6 +226,7 @@ REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP
(
many_output_op
,
f
::
NOP
,
f
::
ManyOutputOpMaker
,
many_output_op_grad
,
f
::
NOP
);
REGISTER_OP
(
mult_in_out
,
f
::
NOP
,
f
::
MultInOutOpMaker
,
mult_in_out_grad
,
f
::
NOP
);
REGISTER_OPERATOR
(
minus
,
f
::
NOP
,
f
::
MinusOpMaker
,
f
::
MinusGradOpDescMaker
);
TEST
(
Backward
,
simple_op_not_need_grad
)
{
auto
fwd
=
f
::
OpRegistry
::
CreateOp
(
...
...
@@ -395,12 +435,13 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
2UL
/* external input number */
+
1UL
/* external output number*/
+
1UL
/* number of gradient of external output*/
+
2U
/* internal variable number*/
);
+
2UL
/* internal variable number*/
);
EXPECT_EQ
(
grad_fc
.
Outputs
(
all
).
size
(),
2UL
/* input number of mul*/
+
2UL
/* input number of rowwise_add
*/
+
1UL
/* input number of sigmod
*/
);
+
2UL
/* input number of rowwise_add
*/
+
1UL
/* input number of sigmod
*/
-
1UL
/* out2 is not needed
*/
);
EXPECT_EQ
(
bwd_net
->
ops_
[
1
]
->
Inputs
(
all
).
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
1
]
->
Outputs
(
all
).
size
(),
0UL
);
EXPECT_EQ
(
bwd_net
->
ops_
[
2
]
->
Inputs
(
all
).
size
(),
0UL
);
...
...
@@ -580,8 +621,7 @@ TEST(Backward, intermedia_var_no_grad) {
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"out4"
)}));
EXPECT_EQ
(
grad_op4
->
Output
(
f
::
GradVarName
(
"X"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"out1"
)}));
EXPECT_EQ
(
grad_op4
->
Output
(
f
::
GradVarName
(
"Y"
)),
std
::
vector
<
std
::
string
>
({
f
::
kEmptyVarName
}));
EXPECT_EQ
(
grad_op4
->
Output
(
f
::
GradVarName
(
"Y"
)),
std
::
vector
<
std
::
string
>
());
}
TEST
(
Backward
,
var_no_grad
)
{
...
...
@@ -619,8 +659,7 @@ TEST(Backward, var_no_grad) {
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"z2"
)}));
EXPECT_EQ
(
grad_op2
->
Output
(
f
::
GradVarName
(
"X"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"y1"
)}));
EXPECT_EQ
(
grad_op2
->
Output
(
f
::
GradVarName
(
"H"
)),
std
::
vector
<
std
::
string
>
({
f
::
kEmptyVarName
}));
EXPECT_EQ
(
grad_op2
->
Output
(
f
::
GradVarName
(
"H"
)),
std
::
vector
<
std
::
string
>
());
f
::
OpDescBind
*
fill_zero_op
=
block
->
AllOps
()[
3
];
ASSERT_EQ
(
fill_zero_op
->
Type
(),
"fill_zeros_like"
);
...
...
@@ -718,4 +757,19 @@ TEST(Backward, shared_var) {
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"x1"
)}));
EXPECT_EQ
(
grad_op1
->
Output
(
f
::
GradVarName
(
"b"
)),
std
::
vector
<
std
::
string
>
({
f
::
GradVarName
(
"b1"
)}));
}
TEST
(
Backward
,
half_backward
)
{
f
::
ProgramDesc
*
program_desc
=
GetNewProgramDesc
();
f
::
ProgramDescBind
&
program
=
f
::
ProgramDescBind
::
Instance
(
program_desc
);
f
::
BlockDescBind
*
block
=
program
.
Block
(
0
);
auto
*
op1
=
block
->
AppendOp
();
op1
->
SetType
(
"minus"
);
op1
->
SetInput
(
"X"
,
{
"a"
});
op1
->
SetInput
(
"Y"
,
{
"b"
});
op1
->
SetOutput
(
"Out"
,
{
"out"
});
AppendBackward
(
program
,
{
"b"
});
auto
ops
=
block
->
AllOps
();
ASSERT_EQ
(
2UL
,
ops
.
size
());
}
\ No newline at end of file
paddle/framework/details/op_registry.h
浏览文件 @
53542d93
...
...
@@ -97,8 +97,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
template
<
typename
T
>
struct
OpInfoFiller
<
T
,
kGradOpDescMaker
>
{
void
operator
()(
const
char
*
op_type
,
OpInfo
*
info
)
const
{
info
->
grad_op_maker_
=
[](
const
OpDescBind
&
fwd_op
)
{
T
maker
(
fwd_op
);
info
->
grad_op_maker_
=
[](
const
OpDescBind
&
fwd_op
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
)
{
T
maker
(
fwd_op
,
no_grad_set
);
return
maker
();
};
}
...
...
paddle/framework/executor_test.cc
浏览文件 @
53542d93
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <memory>
#include <vector>
#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/backward.h"
...
...
@@ -317,3 +318,12 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) {
}
}
#endif
DECLARE_double
(
fraction_of_gpu_memory_to_use
);
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
// Use less GPU memory for unittest.
FLAGS_fraction_of_gpu_memory_to_use
=
0.25
;
return
RUN_ALL_TESTS
();
}
\ No newline at end of file
paddle/framework/grad_op_desc_maker.h
浏览文件 @
53542d93
...
...
@@ -13,6 +13,8 @@
limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"
...
...
@@ -21,27 +23,44 @@ namespace framework {
class
GradOpDescMakerBase
{
public:
explicit
GradOpDescMakerBase
(
const
OpDescBind
&
fwd_op
)
:
fwd_op_
(
fwd_op
)
{}
explicit
GradOpDescMakerBase
(
const
OpDescBind
&
fwd_op
,
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set
)
:
fwd_op_
(
fwd_op
),
no_grad_set_
(
no_grad_set
)
{}
virtual
~
GradOpDescMakerBase
()
=
default
;
virtual
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
operator
()()
const
=
0
;
protected:
st
atic
std
::
vector
<
std
::
string
>
ToGradNames
(
const
std
::
vector
<
std
::
string
>&
var_names
)
{
st
d
::
vector
<
std
::
string
>
InputGrad
(
const
std
::
string
&
name
,
bool
drop_empty_grad
=
true
)
const
{
std
::
vector
<
std
::
string
>
ret_val
;
auto
var_names
=
this
->
Input
(
name
);
ret_val
.
reserve
(
var_names
.
size
());
std
::
transform
(
var_names
.
begin
(),
var_names
.
end
(),
std
::
back_inserter
(
ret_val
),
GradVarName
);
return
ret_val
;
}
std
::
vector
<
std
::
string
>
InputGrad
(
const
std
::
string
&
name
)
const
{
return
ToGradNames
(
fwd_op_
.
Input
(
name
));
std
::
transform
(
var_names
.
begin
(),
var_names
.
end
(),
std
::
back_inserter
(
ret_val
),
[
this
](
const
std
::
string
&
fwd_var_name
)
->
std
::
string
{
auto
g_name
=
GradVarName
(
fwd_var_name
);
return
no_grad_set_
.
count
(
g_name
)
==
0
?
g_name
:
kEmptyVarName
;
});
if
(
!
drop_empty_grad
)
{
return
ret_val
;
}
std
::
vector
<
std
::
string
>
dropped_ret_val
;
dropped_ret_val
.
reserve
(
ret_val
.
size
());
std
::
copy_if
(
ret_val
.
begin
(),
ret_val
.
end
(),
std
::
back_inserter
(
dropped_ret_val
),
[](
const
std
::
string
&
str
)
{
return
str
!=
kEmptyVarName
;
});
return
dropped_ret_val
;
}
std
::
vector
<
std
::
string
>
OutputGrad
(
const
std
::
string
&
name
)
const
{
return
ToGradNames
(
fwd_op_
.
Output
(
name
));
std
::
vector
<
std
::
string
>
ret_val
;
auto
onames
=
this
->
Output
(
name
);
ret_val
.
reserve
(
onames
.
size
());
std
::
transform
(
onames
.
begin
(),
onames
.
end
(),
std
::
back_inserter
(
ret_val
),
GradVarName
);
return
ret_val
;
}
std
::
vector
<
std
::
string
>
InputNames
()
const
{
...
...
@@ -75,6 +94,7 @@ class GradOpDescMakerBase {
private:
const
OpDescBind
&
fwd_op_
;
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set_
;
};
class
SingleGradOpDescMaker
:
public
GradOpDescMakerBase
{
...
...
@@ -91,6 +111,7 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
virtual
std
::
unique_ptr
<
OpDescBind
>
Apply
()
const
=
0
;
};
template
<
bool
DropEmptyIG
=
true
>
class
DefaultGradOpDescMaker
:
public
SingleGradOpDescMaker
{
public:
using
SingleGradOpDescMaker
::
SingleGradOpDescMaker
;
...
...
@@ -102,7 +123,8 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
for
(
auto
&
input_param
:
this
->
InputNames
())
{
grad
->
SetInput
(
input_param
,
this
->
Input
(
input_param
));
grad
->
SetOutput
(
GradVarName
(
input_param
),
this
->
InputGrad
(
input_param
));
grad
->
SetOutput
(
GradVarName
(
input_param
),
this
->
InputGrad
(
input_param
,
DropEmptyIG
));
}
for
(
auto
&
output_param
:
this
->
OutputNames
())
{
...
...
paddle/framework/op_registry.cc
浏览文件 @
53542d93
...
...
@@ -59,11 +59,5 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
op_desc
.
GetAttrMap
());
}
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
OpRegistry
::
CreateGradOpDescs
(
OpDescBind
*
op_desc
)
{
auto
&
info
=
OpInfoMap
::
Instance
().
Get
(
op_desc
->
Type
());
return
info
.
grad_op_maker_
(
*
op_desc
);
}
}
// namespace framework
}
// namespace paddle
paddle/framework/op_registry.h
浏览文件 @
53542d93
...
...
@@ -79,9 +79,6 @@ class OpRegistry {
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
);
static
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
CreateGradOpDescs
(
OpDescBind
*
op_desc
);
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
OpDescBind
&
op_desc
);
};
...
...
@@ -160,17 +157,18 @@ class OpKernelRegistrar : public Registrar {
/**
* Macro to register Operator.
*/
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \
REGISTER_OPERATOR(grad_op_type, grad_op_class); \
class _GradOpDescMaker_##grad_op_type##_ \
: public ::paddle::framework::DefaultGradOpDescMaker { \
using ::paddle::framework::DefaultGradOpDescMaker::DefaultGradOpDescMaker; \
\
protected: \
virtual std::string GradOpType() const { return #grad_op_type; } \
}; \
REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \
REGISTER_OPERATOR(grad_op_type, grad_op_class); \
class _GradOpDescMaker_##grad_op_type##_ \
: public ::paddle::framework::DefaultGradOpDescMaker<true> { \
using ::paddle::framework::DefaultGradOpDescMaker< \
true>::DefaultGradOpDescMaker; \
\
protected: \
virtual std::string GradOpType() const { return #grad_op_type; } \
}; \
REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
op_maker_class);
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
...
...
paddle/framework/type_defs.h
浏览文件 @
53542d93
...
...
@@ -36,8 +36,8 @@ using OpCreator = std::function<OperatorBase*(
const
std
::
string
&
/*type*/
,
const
VariableNameMap
&
/*inputs*/
,
const
VariableNameMap
&
/*outputs*/
,
const
AttributeMap
&
/*attrs*/
)
>
;
using
GradOpMakerFN
=
std
::
function
<
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
(
const
OpDescBind
&
)
>
;
using
GradOpMakerFN
=
std
::
function
<
std
::
vector
<
std
::
unique_ptr
<
OpDescBind
>>
(
const
OpDescBind
&
,
const
std
::
unordered_set
<
std
::
string
>&
/*no_grad_set*/
)
>
;
}
// namespace framework
}
// namespace paddle
paddle/operators/multiplex_op.cc
浏览文件 @
53542d93
...
...
@@ -115,8 +115,9 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
multiplex
,
ops
::
MultiplexOp
,
ops
::
MultiplexOpMaker
,
multiplex_grad
,
ops
::
MultiplexGradOp
);
REGISTER_OPERATOR
(
multiplex
,
ops
::
MultiplexOp
,
ops
::
MultiplexOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
false
>
);
REGISTER_OPERATOR
(
multiplex_grad
,
ops
::
MultiplexGradOp
);
REGISTER_OP_CPU_KERNEL
(
multiplex
,
ops
::
MultiplexCPUKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录