Commit 53542d93 authored by Q qijun

Merge remote-tracking branch 'baidu/develop' into selected_rows

@@ -105,6 +105,12 @@ if (WITH_C_API AND WITH_PYTHON)
                     "different Python interpreter from compiling.")
   endif()
+if(MOBILE_INFERENCE)
+    set(THIRD_PARTY_BUILD_TYPE MinSizeRel)
+else()
+    set(THIRD_PARTY_BUILD_TYPE Release)
+endif()
 ########################################################################################
 include(external/mklml)     # download mklml package
......
@@ -8,7 +8,7 @@ ExternalProject_Add(
     extern_eigen3
     ${EXTERNAL_PROJECT_LOG_ARGS}
     GIT_REPOSITORY  "https://github.com/RLovelett/eigen.git"
-    GIT_TAG         "master"
+    GIT_TAG         4e79cb69b9425f5f8c3a84be4350d4ab75b5fd9d
     PREFIX          ${EIGEN_SOURCE_DIR}
     UPDATE_COMMAND  ""
    CONFIGURE_COMMAND ""
......
@@ -36,6 +36,7 @@ ExternalProject_Add(
     # change this back to the official Github repo once my PR is
     # merged.
     GIT_REPOSITORY  "https://github.com/wangkuiyi/gflags.git"
+    GIT_TAG         986964c07427ecb9cdb5bd73f73ebbd40e54dadb
     PREFIX          ${GFLAGS_SOURCES_DIR}
     UPDATE_COMMAND  ""
     CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
@@ -45,11 +46,11 @@ ExternalProject_Add(
                     -DCMAKE_INSTALL_PREFIX=${GFLAGS_INSTALL_DIR}
                     -DCMAKE_POSITION_INDEPENDENT_CODE=ON
                     -DBUILD_TESTING=OFF
-                    -DCMAKE_BUILD_TYPE=Release
+                    -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
                     ${EXTERNAL_OPTIONAL_ARGS}
     CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GFLAGS_INSTALL_DIR}
                      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-                     -DCMAKE_BUILD_TYPE:STRING=Release
+                     -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
 )
 ADD_LIBRARY(gflags STATIC IMPORTED GLOBAL)
......
@@ -31,6 +31,7 @@ ExternalProject_Add(
     ${EXTERNAL_PROJECT_LOG_ARGS}
     DEPENDS         gflags
     GIT_REPOSITORY  "https://github.com/google/glog.git"
+    GIT_TAG         v0.3.5
     PREFIX          ${GLOG_SOURCES_DIR}
     UPDATE_COMMAND  ""
     CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
@@ -43,12 +44,12 @@ ExternalProject_Add(
                     -DWITH_GFLAGS=ON
                     -Dgflags_DIR=${GFLAGS_INSTALL_DIR}/lib/cmake/gflags
                     -DBUILD_TESTING=OFF
-                    -DCMAKE_BUILD_TYPE=Release
+                    -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
                     ${EXTERNAL_OPTIONAL_ARGS}
     CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GLOG_INSTALL_DIR}
                      -DCMAKE_INSTALL_LIBDIR:PATH=${GLOG_INSTALL_DIR}/lib
                      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-                     -DCMAKE_BUILD_TYPE:STRING=Release
+                     -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
 )
 ADD_LIBRARY(glog STATIC IMPORTED GLOBAL)
......
@@ -56,11 +56,11 @@ IF(WITH_TESTING)
                         -DBUILD_GMOCK=ON
                         -Dgtest_disable_pthreads=ON
                         -Dgtest_force_shared_crt=ON
-                        -DCMAKE_BUILD_TYPE=Release
+                        -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
                         ${EXTERNAL_OPTIONAL_ARGS}
         CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GTEST_INSTALL_DIR}
                          -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-                         -DCMAKE_BUILD_TYPE:STRING=Release
+                         -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
     )
     ADD_LIBRARY(gtest STATIC IMPORTED GLOBAL)
......
@@ -191,12 +191,12 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
             ${OPTIONAL_ARGS}
             -Dprotobuf_BUILD_TESTS=OFF
             -DCMAKE_POSITION_INDEPENDENT_CODE=ON
-            -DCMAKE_BUILD_TYPE=Release
+            -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
             -DCMAKE_INSTALL_PREFIX=${PROTOBUF_INSTALL_DIR}
             -DCMAKE_INSTALL_LIBDIR=lib
         CMAKE_CACHE_ARGS
             -DCMAKE_INSTALL_PREFIX:PATH=${PROTOBUF_INSTALL_DIR}
-            -DCMAKE_BUILD_TYPE:STRING=Release
+            -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
             -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
             -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
             ${OPTIONAL_CACHE_ARGS}
......
@@ -35,6 +35,7 @@ ExternalProject_Add(
     extern_warpctc
     ${EXTERNAL_PROJECT_LOG_ARGS}
     GIT_REPOSITORY  "https://github.com/gangliao/warp-ctc.git"
+    GIT_TAG         b63a0644654a3e0ed624c85a1767bc8193aead09
     PREFIX          ${WARPCTC_SOURCES_DIR}
     UPDATE_COMMAND  ""
     CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
@@ -48,9 +49,9 @@ ExternalProject_Add(
                     -DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON
                     -DBUILD_SHARED=ON
                     -DCMAKE_POSITION_INDEPENDENT_CODE=ON
-                    -DCMAKE_BUILD_TYPE=Release
+                    -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
                     ${EXTERNAL_OPTIONAL_ARGS}
-    CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=Release
+    CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
                      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
                      -DCMAKE_INSTALL_PREFIX:PATH=${WARPCTC_INSTALL_DIR}
 )
......
@@ -42,11 +42,11 @@ ExternalProject_Add(
                     -DBUILD_SHARED_LIBS=OFF
                     -DCMAKE_POSITION_INDEPENDENT_CODE=ON
                     -DCMAKE_MACOSX_RPATH=ON
-                    -DCMAKE_BUILD_TYPE=Release
+                    -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
                     ${EXTERNAL_OPTIONAL_ARGS}
     CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${ZLIB_INSTALL_DIR}
                      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-                     -DCMAKE_BUILD_TYPE:STRING=Release
+                     -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
 )
 LIST(APPEND external_project_dependencies zlib)
......
@@ -28,14 +28,15 @@ namespace paddle {
 namespace framework {

 static inline std::unique_ptr<OperatorBase> CreateGradOp(
-    const OperatorBase& op) {
+    const OperatorBase& op,
+    const std::unordered_set<std::string>& no_grad_set) {
   OpDescBind op_desc;
   op_desc.SetInputMap(op.Inputs());
   op_desc.SetOutputMap(op.Outputs());
   op_desc.SetType(op.Type());
   op_desc.SetAttrMap(op.Attrs());
   auto& info = OpInfoMap::Instance().Get(op.Type());
-  auto grad_descs = info.GradOpMaker()(op_desc);
+  auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set);
   std::vector<std::unique_ptr<OperatorBase>> grad_ops;
   grad_ops.reserve(grad_descs.size());
   std::transform(grad_descs.begin(), grad_descs.end(),
@@ -187,7 +188,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
         net->InsertOp(pos.first + 1, std::move(pos.second));
       }
     } else {
-      std::unique_ptr<OperatorBase> grad_op(CreateGradOp(forwardOp));
+      std::unique_ptr<OperatorBase> grad_op(
+          CreateGradOp(forwardOp, no_grad_names));

       ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
                                             const std::string& grad_input) {
@@ -272,7 +274,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
     const std::unique_ptr<OpDescBind>& op_desc,
     std::unordered_set<std::string>& no_grad_vars) {
   std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
-  // All input gradients of forwarding operator do not need to calculat.
+  // All input gradients of forwarding operator do not need to calculate.
   const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
   if (AllGradInSet(inputs, no_grad_vars)) {
     return grad_op_descs;  // empty vector
@@ -286,7 +288,9 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
     return grad_op_descs;  // empty vector
   }

-  grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc.get());
+  grad_op_descs = OpInfoMap::Instance()
+                      .Get(op_desc->Type())
+                      .GradOpMaker()(*op_desc, no_grad_vars);

   std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
   for (auto& desc : grad_op_descs) {
@@ -301,11 +305,6 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
         pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
       }
     }
-    for (const std::string& out_name : desc->OutputArgumentNames()) {
-      if (no_grad_vars.count(out_name)) {
-        desc->Rename(out_name, kEmptyVarName);
-      }
-    }
   }

   for (auto& p : pending_fill_zeros_ops) {
......
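Taken together, these backward.cc hunks move the no-grad bookkeeping out of MakeOpGrad: it no longer routes through OpRegistry::CreateGradOpDescs, and the rename-to-kEmptyVarName loop is deleted because the grad-op maker itself now receives the no-grad set. A minimal sketch of the resulting call path, using simplified stand-in types rather than the real framework headers:

```cpp
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

// Stand-in for the framework's OpDescBind (illustration only).
struct OpDescBind {
  std::string type;
};

// Mirrors the updated GradOpMakerFN alias from type_defs.h below.
using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>(
    const OpDescBind&, const std::unordered_set<std::string>&)>;

std::vector<std::unique_ptr<OpDescBind>> MakeOpGradSketch(
    const OpDescBind& op_desc, const GradOpMakerFN& grad_op_maker,
    const std::unordered_set<std::string>& no_grad_vars) {
  // One call replaces OpRegistry::CreateGradOpDescs plus the caller-side
  // rename-to-kEmptyVarName loop that this commit deletes.
  return grad_op_maker(op_desc, no_grad_vars);
}

int main() {
  GradOpMakerFN nop = [](const OpDescBind&,
                         const std::unordered_set<std::string>&) {
    return std::vector<std::unique_ptr<OpDescBind>>{};
  };
  auto grads = MakeOpGradSketch(OpDescBind{"minus"}, nop, {"b@GRAD"});
  return grads.empty() ? 0 : 1;
}
```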
@@ -169,6 +169,45 @@ class MultInOutOpMaker : public OpProtoAndCheckerMaker {
   }
 };

+class MinusGradOpDescMaker : public GradOpDescMakerBase {
+ public:
+  using GradOpDescMakerBase::GradOpDescMakerBase;
+
+  std::vector<std::unique_ptr<OpDescBind>> operator()() const override {
+    std::vector<std::unique_ptr<OpDescBind>> retv;
+    auto x_g = InputGrad("X");
+    if (!x_g.empty()) {
+      auto *op_desc = new OpDescBind();
+      op_desc->SetType("scale");
+      op_desc->SetInput("X", OutputGrad("Out"));
+      op_desc->SetOutput("Out", x_g);
+      op_desc->SetAttr("scale", 1.0f);
+      retv.emplace_back(op_desc);
+    }
+
+    auto y_g = InputGrad("Y");
+    if (!y_g.empty()) {
+      auto *op_desc = new OpDescBind();
+      op_desc->SetType("scale");
+      op_desc->SetInput("X", OutputGrad("Out"));
+      op_desc->SetOutput("Out", y_g);
+      op_desc->SetAttr("scale", -1.0f);
+      retv.emplace_back(op_desc);
+    }
+    return retv;
+  }
+};
+
+class MinusOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  MinusOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "");
+    AddInput("Y", "");
+    AddOutput("Out", "");
+    AddComment("minus for unittest");
+  }
+};
 }  // namespace framework
 }  // namespace paddle
@@ -187,6 +226,7 @@ REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
 REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
             f::NOP);
 REGISTER_OP(mult_in_out, f::NOP, f::MultInOutOpMaker, mult_in_out_grad, f::NOP);
+REGISTER_OPERATOR(minus, f::NOP, f::MinusOpMaker, f::MinusGradOpDescMaker);
 TEST(Backward, simple_op_not_need_grad) {
   auto fwd = f::OpRegistry::CreateOp(
@@ -395,12 +435,13 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
                 2UL       /* external input number */
                 + 1UL     /* external output number*/
                 + 1UL     /* number of gradient of external output*/
-                + 2U      /* internal variable number*/);
+                + 2UL     /* internal variable number*/
+  );
   EXPECT_EQ(grad_fc.Outputs(all).size(),
             2UL       /* input number of mul*/
-                + 2UL /* input number of rowwise_add
-                       */
-                + 1UL /* input number of sigmod */);
+                + 2UL /* input number of rowwise_add*/
+                + 1UL /* input number of sigmod */
+                - 1UL /* out2 is not needed*/);
   EXPECT_EQ(bwd_net->ops_[1]->Inputs(all).size(), 0UL);
   EXPECT_EQ(bwd_net->ops_[1]->Outputs(all).size(), 0UL);
   EXPECT_EQ(bwd_net->ops_[2]->Inputs(all).size(), 0UL);
@@ -580,8 +621,7 @@ TEST(Backward, intermedia_var_no_grad) {
             std::vector<std::string>({f::GradVarName("out4")}));
   EXPECT_EQ(grad_op4->Output(f::GradVarName("X")),
             std::vector<std::string>({f::GradVarName("out1")}));
-  EXPECT_EQ(grad_op4->Output(f::GradVarName("Y")),
-            std::vector<std::string>({f::kEmptyVarName}));
+  EXPECT_EQ(grad_op4->Output(f::GradVarName("Y")), std::vector<std::string>());
 }

 TEST(Backward, var_no_grad) {
@@ -619,8 +659,7 @@ TEST(Backward, var_no_grad) {
             std::vector<std::string>({f::GradVarName("z2")}));
   EXPECT_EQ(grad_op2->Output(f::GradVarName("X")),
             std::vector<std::string>({f::GradVarName("y1")}));
-  EXPECT_EQ(grad_op2->Output(f::GradVarName("H")),
-            std::vector<std::string>({f::kEmptyVarName}));
+  EXPECT_EQ(grad_op2->Output(f::GradVarName("H")), std::vector<std::string>());

   f::OpDescBind *fill_zero_op = block->AllOps()[3];
   ASSERT_EQ(fill_zero_op->Type(), "fill_zeros_like");
@@ -718,4 +757,19 @@ TEST(Backward, shared_var) {
             std::vector<std::string>({f::GradVarName("x1")}));
   EXPECT_EQ(grad_op1->Output(f::GradVarName("b")),
             std::vector<std::string>({f::GradVarName("b1")}));
+}
+
+TEST(Backward, half_backward) {
+  f::ProgramDesc *program_desc = GetNewProgramDesc();
+  f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
+  f::BlockDescBind *block = program.Block(0);
+  auto *op1 = block->AppendOp();
+  op1->SetType("minus");
+  op1->SetInput("X", {"a"});
+  op1->SetInput("Y", {"b"});
+  op1->SetOutput("Out", {"out"});
+
+  AppendBackward(program, {"b"});
+  auto ops = block->AllOps();
+  ASSERT_EQ(2UL, ops.size());
 }
\ No newline at end of file
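For reference, the 2UL expectation in half_backward follows from the maker above: with "b" marked no-grad, MinusGradOpDescMaker's "Y" branch sees an empty InputGrad and emits nothing, so the block ends up with just the forward minus op plus one scale op for X's gradient.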
@@ -97,8 +97,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
 template <typename T>
 struct OpInfoFiller<T, kGradOpDescMaker> {
   void operator()(const char* op_type, OpInfo* info) const {
-    info->grad_op_maker_ = [](const OpDescBind& fwd_op) {
-      T maker(fwd_op);
+    info->grad_op_maker_ = [](
+        const OpDescBind& fwd_op,
+        const std::unordered_set<std::string>& no_grad_set) {
+      T maker(fwd_op, no_grad_set);
       return maker();
     };
   }
......
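The signature change ripples through this registration glue: the type-erased grad_op_maker_ lambda must now construct the concrete maker with two arguments. A compilable sketch of that adapter pattern, again with simplified stand-in types (not the real framework headers):

```cpp
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

struct OpDescBind {};  // stand-in for illustration

using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>(
    const OpDescBind&, const std::unordered_set<std::string>&)>;

template <typename T>
GradOpMakerFN WrapGradOpMaker() {
  // Same shape as the lambda in OpInfoFiller<T, kGradOpDescMaker>: forward
  // both the op description and the no-grad set into the maker's ctor.
  return [](const OpDescBind& fwd_op,
            const std::unordered_set<std::string>& no_grad_set) {
    T maker(fwd_op, no_grad_set);
    return maker();
  };
}

// Trivial maker so the sketch links and runs.
struct NullGradMaker {
  NullGradMaker(const OpDescBind&, const std::unordered_set<std::string>&) {}
  std::vector<std::unique_ptr<OpDescBind>> operator()() const { return {}; }
};

int main() {
  auto fn = WrapGradOpMaker<NullGradMaker>();
  return fn(OpDescBind{}, {}).empty() ? 0 : 1;
}
```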
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <memory>
 #include <vector>

+#include "gflags/gflags.h"
 #include "gtest/gtest.h"
 #include "paddle/framework/attribute.h"
 #include "paddle/framework/backward.h"
@@ -317,3 +318,12 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) {
   }
 }
 #endif
+
+DECLARE_double(fraction_of_gpu_memory_to_use);
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  // Use less GPU memory for unittest.
+  FLAGS_fraction_of_gpu_memory_to_use = 0.25;
+  return RUN_ALL_TESTS();
+}
\ No newline at end of file
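The custom main above is the standard gflags split: DECLARE_double references a flag that is DEFINE'd in another translation unit (here, Paddle's GPU allocator), and assigning to the FLAGS_ variable before RUN_ALL_TESTS() means every test sees the override. A self-contained sketch of the same pattern; the DEFINE and its default value here stand in for the real one inside Paddle:

```cpp
#include "gflags/gflags.h"
#include "gtest/gtest.h"

// In the real tree this DEFINE lives in the GPU allocator; the default
// value here is illustrative only.
DEFINE_double(fraction_of_gpu_memory_to_use, 0.95,
              "Fraction of GPU memory the allocator may reserve.");

TEST(FlagsDemo, OverrideTakesEffect) {
  EXPECT_DOUBLE_EQ(0.25, FLAGS_fraction_of_gpu_memory_to_use);
}

int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  // Assigning before RUN_ALL_TESTS() means every test sees the override.
  FLAGS_fraction_of_gpu_memory_to_use = 0.25;
  return RUN_ALL_TESTS();
}
```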
@@ -13,6 +13,8 @@
 limitations under the License. */

 #pragma once
+#include <string>
+#include <unordered_set>
 #include "paddle/framework/op_desc.h"
 #include "paddle/framework/operator.h"
@@ -21,27 +23,44 @@ namespace framework {

 class GradOpDescMakerBase {
  public:
-  explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
+  explicit GradOpDescMakerBase(
+      const OpDescBind& fwd_op,
+      const std::unordered_set<std::string>& no_grad_set)
+      : fwd_op_(fwd_op), no_grad_set_(no_grad_set) {}

   virtual ~GradOpDescMakerBase() = default;
   virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;

  protected:
-  static std::vector<std::string> ToGradNames(
-      const std::vector<std::string>& var_names) {
+  std::vector<std::string> InputGrad(const std::string& name,
+                                     bool drop_empty_grad = true) const {
     std::vector<std::string> ret_val;
+    auto var_names = this->Input(name);
     ret_val.reserve(var_names.size());
-    std::transform(var_names.begin(), var_names.end(),
-                   std::back_inserter(ret_val), GradVarName);
-    return ret_val;
-  }
-
-  std::vector<std::string> InputGrad(const std::string& name) const {
-    return ToGradNames(fwd_op_.Input(name));
+    std::transform(
+        var_names.begin(), var_names.end(), std::back_inserter(ret_val),
+        [this](const std::string& fwd_var_name) -> std::string {
+          auto g_name = GradVarName(fwd_var_name);
+          return no_grad_set_.count(g_name) == 0 ? g_name : kEmptyVarName;
+        });
+    if (!drop_empty_grad) {
+      return ret_val;
+    }
+    std::vector<std::string> dropped_ret_val;
+    dropped_ret_val.reserve(ret_val.size());
+    std::copy_if(ret_val.begin(), ret_val.end(),
+                 std::back_inserter(dropped_ret_val),
+                 [](const std::string& str) { return str != kEmptyVarName; });
+    return dropped_ret_val;
   }

   std::vector<std::string> OutputGrad(const std::string& name) const {
-    return ToGradNames(fwd_op_.Output(name));
+    std::vector<std::string> ret_val;
+    auto onames = this->Output(name);
+    ret_val.reserve(onames.size());
+    std::transform(onames.begin(), onames.end(), std::back_inserter(ret_val),
+                   GradVarName);
+    return ret_val;
   }

   std::vector<std::string> InputNames() const {
@@ -75,6 +94,7 @@ class GradOpDescMakerBase {

  private:
   const OpDescBind& fwd_op_;
+  const std::unordered_set<std::string>& no_grad_set_;
 };

 class SingleGradOpDescMaker : public GradOpDescMakerBase {
@@ -91,6 +111,7 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
   virtual std::unique_ptr<OpDescBind> Apply() const = 0;
 };

+template <bool DropEmptyIG = true>
 class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
  public:
   using SingleGradOpDescMaker::SingleGradOpDescMaker;
@@ -102,7 +123,8 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
     for (auto& input_param : this->InputNames()) {
       grad->SetInput(input_param, this->Input(input_param));
-      grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
+      grad->SetOutput(GradVarName(input_param),
+                      this->InputGrad(input_param, DropEmptyIG));
     }

     for (auto& output_param : this->OutputNames()) {
......
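The reworked InputGrad is where the no-grad bookkeeping now lives: each forward variable maps to its gradient name, names in the no-grad set become kEmptyVarName, and drop_empty_grad controls whether those placeholders are filtered out (multiplex below registers with DropEmptyIG = false to keep its gradient outputs positionally aligned with its inputs). A standalone sketch of that logic, with stand-ins mirroring the framework's GradVarName and kEmptyVarName conventions:

```cpp
#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <unordered_set>
#include <vector>

// Stand-ins for the framework's conventions (illustration only).
const char kEmptyVarName[] = "@EMPTY@";
std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

std::vector<std::string> InputGradSketch(
    const std::vector<std::string>& var_names,
    const std::unordered_set<std::string>& no_grad_set,
    bool drop_empty_grad = true) {
  std::vector<std::string> ret_val;
  ret_val.reserve(var_names.size());
  // Map every forward var to its grad name, blanking no-grad entries.
  std::transform(var_names.begin(), var_names.end(),
                 std::back_inserter(ret_val),
                 [&](const std::string& fwd_var_name) -> std::string {
                   auto g_name = GradVarName(fwd_var_name);
                   return no_grad_set.count(g_name) == 0 ? g_name
                                                         : kEmptyVarName;
                 });
  if (!drop_empty_grad) {
    return ret_val;  // keep placeholders so positions stay aligned
  }
  std::vector<std::string> dropped;
  dropped.reserve(ret_val.size());
  std::copy_if(ret_val.begin(), ret_val.end(), std::back_inserter(dropped),
               [](const std::string& s) { return s != kEmptyVarName; });
  return dropped;
}

int main() {
  const std::unordered_set<std::string> no_grad = {"Y@GRAD"};
  for (bool drop : {true, false}) {
    for (const auto& g : InputGradSketch({"X", "Y"}, no_grad, drop)) {
      std::cout << g << ' ';
    }
    std::cout << '\n';  // prints "X@GRAD" then "X@GRAD @EMPTY@"
  }
}
```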
@@ -59,11 +59,5 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
                              op_desc.GetAttrMap());
 }

-std::vector<std::unique_ptr<OpDescBind>> OpRegistry::CreateGradOpDescs(
-    OpDescBind* op_desc) {
-  auto& info = OpInfoMap::Instance().Get(op_desc->Type());
-  return info.grad_op_maker_(*op_desc);
-}
-
 }  // namespace framework
 }  // namespace paddle
@@ -79,9 +79,6 @@ class OpRegistry {

   static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);

-  static std::vector<std::unique_ptr<OpDescBind>> CreateGradOpDescs(
-      OpDescBind* op_desc);
-
   static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
 };
@@ -160,17 +157,18 @@ class OpKernelRegistrar : public Registrar {
 /**
  * Macro to register Operator.
  */
-#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type,          \
-                    grad_op_class)                                             \
-  REGISTER_OPERATOR(grad_op_type, grad_op_class);                             \
-  class _GradOpDescMaker_##grad_op_type##_                                     \
-      : public ::paddle::framework::DefaultGradOpDescMaker {                   \
-    using ::paddle::framework::DefaultGradOpDescMaker::DefaultGradOpDescMaker; \
-                                                                               \
-   protected:                                                                  \
-    virtual std::string GradOpType() const { return #grad_op_type; }          \
-  };                                                                           \
-  REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_,    \
-                    op_maker_class);
+#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type,       \
+                    grad_op_class)                                         \
+  REGISTER_OPERATOR(grad_op_type, grad_op_class);                          \
+  class _GradOpDescMaker_##grad_op_type##_                                 \
+      : public ::paddle::framework::DefaultGradOpDescMaker<true> {         \
+    using ::paddle::framework::DefaultGradOpDescMaker<                     \
+        true>::DefaultGradOpDescMaker;                                     \
+                                                                           \
+   protected:                                                              \
+    virtual std::string GradOpType() const { return #grad_op_type; }       \
+  };                                                                       \
+  REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
+                    op_maker_class);

 #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
......
@@ -36,8 +36,8 @@ using OpCreator = std::function<OperatorBase*(
     const std::string& /*type*/, const VariableNameMap& /*inputs*/,
     const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;

-using GradOpMakerFN =
-    std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;
+using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>(
+    const OpDescBind&, const std::unordered_set<std::string>& /*no_grad_set*/)>;

 }  // namespace framework
 }  // namespace paddle
@@ -115,8 +115,9 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad,
-            ops::MultiplexGradOp);
+REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<false>);
+REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp);
 REGISTER_OP_CPU_KERNEL(
     multiplex, ops::MultiplexCPUKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
......