diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4783095194dc9c6409dc31c95588f46c9bee7c61..1252e7539816016dfdf1b90b8941fa42e6bb85e0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -105,6 +105,12 @@ if (WITH_C_API AND WITH_PYTHON)
                   "different Python interpreter from compiling.")
 endif()
 
+if(MOBILE_INFERENCE)
+  set(THIRD_PARTY_BUILD_TYPE MinSizeRel)
+else()
+  set(THIRD_PARTY_BUILD_TYPE Release)
+endif()
+
 ########################################################################################
 
 include(external/mklml)     # download mklml package
diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake
index f7483f6be9169eb58f0148cd3a956a8c881e1fe3..bd853d921b4362ac7ac5e17e629552b2a200f08a 100644
--- a/cmake/external/eigen.cmake
+++ b/cmake/external/eigen.cmake
@@ -8,7 +8,7 @@ ExternalProject_Add(
     extern_eigen3
     ${EXTERNAL_PROJECT_LOG_ARGS}
     GIT_REPOSITORY  "https://github.com/RLovelett/eigen.git"
-    GIT_TAG         "master"
+    GIT_TAG         4e79cb69b9425f5f8c3a84be4350d4ab75b5fd9d
     PREFIX          ${EIGEN_SOURCE_DIR}
     UPDATE_COMMAND  ""
     CONFIGURE_COMMAND ""
diff --git a/cmake/external/gflags.cmake b/cmake/external/gflags.cmake
index 957f8271e4841836956b0c3f2cf3d8c88a31192a..c819eb4d70898e48eab499c666168d78262d4240 100644
--- a/cmake/external/gflags.cmake
+++ b/cmake/external/gflags.cmake
@@ -36,6 +36,7 @@ ExternalProject_Add(
     # change this back to the official Github repo once my PR is
     # merged.
     GIT_REPOSITORY  "https://github.com/wangkuiyi/gflags.git"
+    GIT_TAG         986964c07427ecb9cdb5bd73f73ebbd40e54dadb
     PREFIX          ${GFLAGS_SOURCES_DIR}
     UPDATE_COMMAND  ""
     CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
@@ -45,11 +46,11 @@ ExternalProject_Add(
                     -DCMAKE_INSTALL_PREFIX=${GFLAGS_INSTALL_DIR}
                     -DCMAKE_POSITION_INDEPENDENT_CODE=ON
                     -DBUILD_TESTING=OFF
-                    -DCMAKE_BUILD_TYPE=Release
+                    -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
                     ${EXTERNAL_OPTIONAL_ARGS}
     CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GFLAGS_INSTALL_DIR}
                      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-                     -DCMAKE_BUILD_TYPE:STRING=Release
+                     -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
 )
 
 ADD_LIBRARY(gflags STATIC IMPORTED GLOBAL)
diff --git a/cmake/external/glog.cmake b/cmake/external/glog.cmake
index b3fef738ccc0b5886bb0a32501bb7b7adade0ff1..08bdc1e1623b0d917061c7368e9b2a8f7e9517fd 100644
--- a/cmake/external/glog.cmake
+++ b/cmake/external/glog.cmake
@@ -31,6 +31,7 @@ ExternalProject_Add(
     ${EXTERNAL_PROJECT_LOG_ARGS}
     DEPENDS gflags
     GIT_REPOSITORY  "https://github.com/google/glog.git"
+    GIT_TAG         v0.3.5
     PREFIX          ${GLOG_SOURCES_DIR}
     UPDATE_COMMAND  ""
     CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
@@ -43,12 +44,12 @@ ExternalProject_Add(
                     -DWITH_GFLAGS=ON
                     -Dgflags_DIR=${GFLAGS_INSTALL_DIR}/lib/cmake/gflags
                     -DBUILD_TESTING=OFF
-                    -DCMAKE_BUILD_TYPE=Release
+                    -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
                     ${EXTERNAL_OPTIONAL_ARGS}
     CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GLOG_INSTALL_DIR}
                      -DCMAKE_INSTALL_LIBDIR:PATH=${GLOG_INSTALL_DIR}/lib
                      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-                     -DCMAKE_BUILD_TYPE:STRING=Release
+                     -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
 )
 
 ADD_LIBRARY(glog STATIC IMPORTED GLOBAL)
diff --git a/cmake/external/gtest.cmake b/cmake/external/gtest.cmake
index 6a2a79b7631b32e8a099797de509af64533bbb95..5a4aa7a5b71a4fdfd556a46037e6d1846d668fc4 100644
--- a/cmake/external/gtest.cmake
+++ b/cmake/external/gtest.cmake
@@ -56,11 +56,11 @@ IF(WITH_TESTING)
                         -DBUILD_GMOCK=ON
                         -Dgtest_disable_pthreads=ON
                         -Dgtest_force_shared_crt=ON
-                        -DCMAKE_BUILD_TYPE=Release
+                        -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
                         ${EXTERNAL_OPTIONAL_ARGS}
        CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GTEST_INSTALL_DIR}
                         -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-                        -DCMAKE_BUILD_TYPE:STRING=Release
+                        -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
     )
     ADD_LIBRARY(gtest STATIC IMPORTED GLOBAL)
diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake
index 7cf7ba85cca4c248dcc74e078124c0b3815ee380..be7f6a9465970711170bd15dcecaadeaa8a55f86 100644
--- a/cmake/external/protobuf.cmake
+++ b/cmake/external/protobuf.cmake
@@ -191,12 +191,12 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
             ${OPTIONAL_ARGS}
             -Dprotobuf_BUILD_TESTS=OFF
             -DCMAKE_POSITION_INDEPENDENT_CODE=ON
-            -DCMAKE_BUILD_TYPE=Release
+            -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
             -DCMAKE_INSTALL_PREFIX=${PROTOBUF_INSTALL_DIR}
             -DCMAKE_INSTALL_LIBDIR=lib
         CMAKE_CACHE_ARGS
             -DCMAKE_INSTALL_PREFIX:PATH=${PROTOBUF_INSTALL_DIR}
-            -DCMAKE_BUILD_TYPE:STRING=Release
+            -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
             -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
             -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
             ${OPTIONAL_CACHE_ARGS}
diff --git a/cmake/external/warpctc.cmake b/cmake/external/warpctc.cmake
index bb258c7b5581fc22b44f4fe15c119f8081f4767e..8bd058222880b4df3b08da09c02f9fe7f1d0ee66 100644
--- a/cmake/external/warpctc.cmake
+++ b/cmake/external/warpctc.cmake
@@ -35,6 +35,7 @@ ExternalProject_Add(
     extern_warpctc
     ${EXTERNAL_PROJECT_LOG_ARGS}
     GIT_REPOSITORY  "https://github.com/gangliao/warp-ctc.git"
+    GIT_TAG         b63a0644654a3e0ed624c85a1767bc8193aead09
     PREFIX          ${WARPCTC_SOURCES_DIR}
     UPDATE_COMMAND  ""
     CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
@@ -48,9 +49,9 @@ ExternalProject_Add(
                     -DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON
                     -DBUILD_SHARED=ON
                     -DCMAKE_POSITION_INDEPENDENT_CODE=ON
-                    -DCMAKE_BUILD_TYPE=Release
+                    -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
                     ${EXTERNAL_OPTIONAL_ARGS}
-    CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=Release
+    CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
                      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
                      -DCMAKE_INSTALL_PREFIX:PATH=${WARPCTC_INSTALL_DIR}
 )
diff --git a/cmake/external/zlib.cmake b/cmake/external/zlib.cmake
index c496a52b780364f3014f8fa3dfbc944a7aa7430e..e2c9fe56f335ae5b627b4d8d4bb17e4a2a466677 100644
--- a/cmake/external/zlib.cmake
+++ b/cmake/external/zlib.cmake
@@ -42,11 +42,11 @@ ExternalProject_Add(
                     -DBUILD_SHARED_LIBS=OFF
                     -DCMAKE_POSITION_INDEPENDENT_CODE=ON
                     -DCMAKE_MACOSX_RPATH=ON
-                    -DCMAKE_BUILD_TYPE=Release
+                    -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
                     ${EXTERNAL_OPTIONAL_ARGS}
     CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${ZLIB_INSTALL_DIR}
                      -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-                     -DCMAKE_BUILD_TYPE:STRING=Release
+                     -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
 )
 
 LIST(APPEND external_project_dependencies zlib)
diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 063b108500d95c94d5859cf6e1a5a88dcdb2ed31..c966f97c2d5b553f6ab67bb2f7aac27108b80409 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -28,14 +28,15 @@ namespace paddle {
 namespace framework {
 
 static inline std::unique_ptr<OperatorBase> CreateGradOp(
-    const OperatorBase& op) {
+    const OperatorBase& op,
+    const std::unordered_set<std::string>& no_grad_set) {
   OpDescBind op_desc;
   op_desc.SetInputMap(op.Inputs());
   op_desc.SetOutputMap(op.Outputs());
   op_desc.SetType(op.Type());
   op_desc.SetAttrMap(op.Attrs());
   auto& info = OpInfoMap::Instance().Get(op.Type());
-  auto grad_descs = info.GradOpMaker()(op_desc);
+  auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set);
   std::vector<std::unique_ptr<OperatorBase>> grad_ops;
   grad_ops.reserve(grad_descs.size());
   std::transform(grad_descs.begin(), grad_descs.end(),
@@ -187,7 +188,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
       net->InsertOp(pos.first + 1, std::move(pos.second));
     }
   } else {
-    std::unique_ptr<OperatorBase> grad_op(CreateGradOp(forwardOp));
+    std::unique_ptr<OperatorBase> grad_op(
+        CreateGradOp(forwardOp, no_grad_names));
 
     ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
                                           const std::string& grad_input) {
@@ -272,7 +274,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
     const std::unique_ptr<OpDescBind>& op_desc,
     std::unordered_set<std::string>& no_grad_vars) {
   std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
-  // All input gradients of forwarding operator do not need to calculat.
+  // All input gradients of forwarding operator do not need to calculate.
   const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
   if (AllGradInSet(inputs, no_grad_vars)) {
     return grad_op_descs;  // empty vector
@@ -286,7 +288,9 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
     return grad_op_descs;  // empty vector
   }
 
-  grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc.get());
+  grad_op_descs = OpInfoMap::Instance()
+                      .Get(op_desc->Type())
+                      .GradOpMaker()(*op_desc, no_grad_vars);
 
   std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
   for (auto& desc : grad_op_descs) {
@@ -301,11 +305,6 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
         pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
       }
     }
-    for (const std::string& out_name : desc->OutputArgumentNames()) {
-      if (no_grad_vars.count(out_name)) {
-        desc->Rename(out_name, kEmptyVarName);
-      }
-    }
   }
 
   for (auto& p : pending_fill_zeros_ops) {
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 8fd8c826187e3e3f830fdb318c4edacbad8b7333..9b15331dff380a2c437e1cc42da92d7fd239d31c 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -169,6 +169,45 @@ class MultInOutOpMaker : public OpProtoAndCheckerMaker {
   }
 };
 
+class MinusGradOpDescMaker : public GradOpDescMakerBase {
+ public:
+  using GradOpDescMakerBase::GradOpDescMakerBase;
+
+  std::vector<std::unique_ptr<OpDescBind>> operator()() const override {
+    std::vector<std::unique_ptr<OpDescBind>> retv;
+    auto x_g = InputGrad("X");
+    if (!x_g.empty()) {
+      auto *op_desc = new OpDescBind();
+      op_desc->SetType("scale");
+      op_desc->SetInput("X", OutputGrad("Out"));
+      op_desc->SetOutput("Out", x_g);
+      op_desc->SetAttr("scale", 1.0f);
+      retv.emplace_back(op_desc);
+    }
+
+    auto y_g = InputGrad("Y");
+    if (!y_g.empty()) {
+      auto *op_desc = new OpDescBind();
+      op_desc->SetType("scale");
+      op_desc->SetInput("X", OutputGrad("Out"));
+      op_desc->SetOutput("Out", y_g);
+      op_desc->SetAttr("scale", -1.0f);
+      retv.emplace_back(op_desc);
+    }
+    return retv;
+  }
+};
+
+class MinusOpMaker : public OpProtoAndCheckerMaker {
+ public:
+  MinusOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "");
+    AddInput("Y", "");
+    AddOutput("Out", "");
+    AddComment("minus for unittest");
+  }
+};
 }  // namespace framework
 }  // namespace paddle
 
@@ -187,6 +226,7 @@ REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
 REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
             f::NOP);
 REGISTER_OP(mult_in_out, f::NOP, f::MultInOutOpMaker, mult_in_out_grad, f::NOP);
+REGISTER_OPERATOR(minus, f::NOP, f::MinusOpMaker, f::MinusGradOpDescMaker);
 
 TEST(Backward, simple_op_not_need_grad) {
   auto fwd = f::OpRegistry::CreateOp(
@@ -395,12 +435,13 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
                 2UL       /* external input number */
                     + 1UL /* external output number*/
                     + 1UL /* number of gradient of external output*/
-                    + 2U /* internal variable number*/);
+                    + 2UL /* internal variable number*/
+  );
 
   EXPECT_EQ(grad_fc.Outputs(all).size(),
             2UL       /* input number of mul*/
-                + 2UL /* input number of rowwise_add
-                       */
-                + 1UL /* input number of sigmod */);
+                + 2UL /* input number of rowwise_add*/
+                + 1UL /* input number of sigmod */
+                - 1UL /* out2 is not needed*/);
   EXPECT_EQ(bwd_net->ops_[1]->Inputs(all).size(), 0UL);
   EXPECT_EQ(bwd_net->ops_[1]->Outputs(all).size(), 0UL);
   EXPECT_EQ(bwd_net->ops_[2]->Inputs(all).size(), 0UL);
@@ -580,8 +621,7 @@ TEST(Backward, intermedia_var_no_grad) {
             std::vector<std::string>({f::GradVarName("out4")}));
   EXPECT_EQ(grad_op4->Output(f::GradVarName("X")),
             std::vector<std::string>({f::GradVarName("out1")}));
-  EXPECT_EQ(grad_op4->Output(f::GradVarName("Y")),
-            std::vector<std::string>({f::kEmptyVarName}));
+  EXPECT_EQ(grad_op4->Output(f::GradVarName("Y")), std::vector<std::string>());
 }
 
 TEST(Backward, var_no_grad) {
@@ -619,8 +659,7 @@ TEST(Backward, var_no_grad) {
             std::vector<std::string>({f::GradVarName("z2")}));
   EXPECT_EQ(grad_op2->Output(f::GradVarName("X")),
             std::vector<std::string>({f::GradVarName("y1")}));
-  EXPECT_EQ(grad_op2->Output(f::GradVarName("H")),
-            std::vector<std::string>({f::kEmptyVarName}));
+  EXPECT_EQ(grad_op2->Output(f::GradVarName("H")), std::vector<std::string>());
 
   f::OpDescBind *fill_zero_op = block->AllOps()[3];
   ASSERT_EQ(fill_zero_op->Type(), "fill_zeros_like");
@@ -718,4 +757,19 @@ TEST(Backward, shared_var) {
             std::vector<std::string>({f::GradVarName("x1")}));
   EXPECT_EQ(grad_op1->Output(f::GradVarName("b")),
             std::vector<std::string>({f::GradVarName("b1")}));
+}
+
+TEST(Backward, half_backward) {
+  f::ProgramDesc *program_desc = GetNewProgramDesc();
+  f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
+  f::BlockDescBind *block = program.Block(0);
+  auto *op1 = block->AppendOp();
+  op1->SetType("minus");
+  op1->SetInput("X", {"a"});
+  op1->SetInput("Y", {"b"});
+  op1->SetOutput("Out", {"out"});
+
+  AppendBackward(program, {"b"});
+  auto ops = block->AllOps();
+  ASSERT_EQ(2UL, ops.size());
 }
\ No newline at end of file
diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h
index daa474e8c5a223589018720da29a5c3363b5934d..ca8584b78ab081138e0d73b8a71ae4cc111a1b4c 100644
--- a/paddle/framework/details/op_registry.h
+++ b/paddle/framework/details/op_registry.h
@@ -97,8 +97,10 @@ struct OpInfoFiller {
 template <typename T>
 struct OpInfoFiller<T, kGradOpDescMaker> {
   void operator()(const char* op_type, OpInfo* info) const {
-    info->grad_op_maker_ = [](const OpDescBind& fwd_op) {
-      T maker(fwd_op);
+    info->grad_op_maker_ = [](
+        const OpDescBind& fwd_op,
+        const std::unordered_set<std::string>& no_grad_set) {
+      T maker(fwd_op, no_grad_set);
       return maker();
     };
   }
diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc
index eaa9c9414b631a986af1ec2de0ebad84ed27f983..8c7db3b290483b32e5b40a74d63b6fac7f7aa7e1 100644
--- a/paddle/framework/executor_test.cc
+++ b/paddle/framework/executor_test.cc
@@ -17,6 +17,7 @@ limitations under the License. */
 #include
 #include
+#include "gflags/gflags.h"
 #include "gtest/gtest.h"
 #include "paddle/framework/attribute.h"
 #include "paddle/framework/backward.h"
@@ -317,3 +318,12 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) {
   }
 }
 #endif
+
+DECLARE_double(fraction_of_gpu_memory_to_use);
+
+int main(int argc, char** argv) {
+  testing::InitGoogleTest(&argc, argv);
+  // Use less GPU memory for unittest.
+  FLAGS_fraction_of_gpu_memory_to_use = 0.25;
+  return RUN_ALL_TESTS();
+}
\ No newline at end of file
diff --git a/paddle/framework/grad_op_desc_maker.h b/paddle/framework/grad_op_desc_maker.h
index e9ae6e22060850fe229998d3b651d08a5ca2033a..d7366b11ec94403e0d8d5d8a3485896f0dc691c0 100644
--- a/paddle/framework/grad_op_desc_maker.h
+++ b/paddle/framework/grad_op_desc_maker.h
@@ -13,6 +13,8 @@ limitations under the License. */
 
 #pragma once
 
+#include <string>
+#include <unordered_set>
 #include "paddle/framework/op_desc.h"
 #include "paddle/framework/operator.h"
 
@@ -21,27 +23,44 @@ namespace framework {
 
 class GradOpDescMakerBase {
  public:
-  explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
+  explicit GradOpDescMakerBase(
+      const OpDescBind& fwd_op,
+      const std::unordered_set<std::string>& no_grad_set)
+      : fwd_op_(fwd_op), no_grad_set_(no_grad_set) {}
 
   virtual ~GradOpDescMakerBase() = default;
   virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
 
  protected:
-  static std::vector<std::string> ToGradNames(
-      const std::vector<std::string>& var_names) {
+  std::vector<std::string> InputGrad(const std::string& name,
+                                     bool drop_empty_grad = true) const {
     std::vector<std::string> ret_val;
+    auto var_names = this->Input(name);
     ret_val.reserve(var_names.size());
-    std::transform(var_names.begin(), var_names.end(),
-                   std::back_inserter(ret_val), GradVarName);
-    return ret_val;
-  }
-
-  std::vector<std::string> InputGrad(const std::string& name) const {
-    return ToGradNames(fwd_op_.Input(name));
+    std::transform(
+        var_names.begin(), var_names.end(), std::back_inserter(ret_val),
+        [this](const std::string& fwd_var_name) -> std::string {
+          auto g_name = GradVarName(fwd_var_name);
+          return no_grad_set_.count(g_name) == 0 ? g_name : kEmptyVarName;
+        });
+    if (!drop_empty_grad) {
+      return ret_val;
+    }
+    std::vector<std::string> dropped_ret_val;
+    dropped_ret_val.reserve(ret_val.size());
+    std::copy_if(ret_val.begin(), ret_val.end(),
+                 std::back_inserter(dropped_ret_val),
+                 [](const std::string& str) { return str != kEmptyVarName; });
+    return dropped_ret_val;
   }
 
   std::vector<std::string> OutputGrad(const std::string& name) const {
-    return ToGradNames(fwd_op_.Output(name));
+    std::vector<std::string> ret_val;
+    auto onames = this->Output(name);
+    ret_val.reserve(onames.size());
+    std::transform(onames.begin(), onames.end(), std::back_inserter(ret_val),
+                   GradVarName);
+    return ret_val;
   }
 
   std::vector<std::string> InputNames() const {
@@ -75,6 +94,7 @@ class GradOpDescMakerBase {
 
  private:
   const OpDescBind& fwd_op_;
+  const std::unordered_set<std::string>& no_grad_set_;
 };
 
 class SingleGradOpDescMaker : public GradOpDescMakerBase {
@@ -91,6 +111,7 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
   virtual std::unique_ptr<OpDescBind> Apply() const = 0;
 };
 
+template <bool DropEmptyIG = true>
 class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
  public:
   using SingleGradOpDescMaker::SingleGradOpDescMaker;
@@ -102,7 +123,8 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
 
     for (auto& input_param : this->InputNames()) {
       grad->SetInput(input_param, this->Input(input_param));
-      grad->SetOutput(GradVarName(input_param), this->InputGrad(input_param));
+      grad->SetOutput(GradVarName(input_param),
+                      this->InputGrad(input_param, DropEmptyIG));
     }
 
     for (auto& output_param : this->OutputNames()) {
diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc
index 94f75b0f309417e52bb5ad3117797d8c25233257..504afbd5dbacf7185f92e0000d19666230e2fb42 100644
--- a/paddle/framework/op_registry.cc
+++ b/paddle/framework/op_registry.cc
@@ -59,11 +59,5 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
                               op_desc.GetAttrMap());
 }
 
-std::vector<std::unique_ptr<OpDescBind>> OpRegistry::CreateGradOpDescs(
-    OpDescBind* op_desc) {
-  auto& info = OpInfoMap::Instance().Get(op_desc->Type());
-  return info.grad_op_maker_(*op_desc);
-}
-
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 5ca3af52a6909eeee21f647d0e60c7a690f90190..226e8ddcd4b1a2630e0eea00ee6c9f6af6bd5d20 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -79,9 +79,6 @@ class OpRegistry {
 
   static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
 
-  static std::vector<std::unique_ptr<OpDescBind>> CreateGradOpDescs(
-      OpDescBind* op_desc);
-
   static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
 };
 
@@ -160,17 +157,18 @@ class OpKernelRegistrar : public Registrar {
 /**
  * Macro to register Operator.
  */
-#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type,           \
-                    grad_op_class)                                              \
-  REGISTER_OPERATOR(grad_op_type, grad_op_class);                               \
-  class _GradOpDescMaker_##grad_op_type##_                                      \
-      : public ::paddle::framework::DefaultGradOpDescMaker {                    \
-    using ::paddle::framework::DefaultGradOpDescMaker::DefaultGradOpDescMaker;  \
-                                                                                \
-   protected:                                                                   \
-    virtual std::string GradOpType() const { return #grad_op_type; }           \
-  };                                                                            \
-  REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_,     \
+#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type,       \
+                    grad_op_class)                                         \
+  REGISTER_OPERATOR(grad_op_type, grad_op_class);                          \
+  class _GradOpDescMaker_##grad_op_type##_                                  \
+      : public ::paddle::framework::DefaultGradOpDescMaker<true> {         \
+    using ::paddle::framework::DefaultGradOpDescMaker<                     \
+        true>::DefaultGradOpDescMaker;                                     \
+                                                                           \
+   protected:                                                              \
+    virtual std::string GradOpType() const { return #grad_op_type; }       \
+  };                                                                       \
+  REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
                     op_maker_class);
 
 #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h
index 6f65a942ba2a4073e6aa1047875ec5c3283c23a6..7e1b79c97b5b4c3f0292fbfa205a8ca541702fbc 100644
--- a/paddle/framework/type_defs.h
+++ b/paddle/framework/type_defs.h
@@ -36,8 +36,8 @@ using OpCreator = std::function<OperatorBase*(
     const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
 
-using GradOpMakerFN =
-    std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>;
+using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>(
+    const OpDescBind&, const std::unordered_set<std::string>& /*no_grad_set*/)>;
 
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc
index a86685b6dde4761cf74f9521bd9609b0864b9bdf..051051b051961c6da064bd9319460b3f41cea3e8 100644
--- a/paddle/operators/multiplex_op.cc
+++ b/paddle/operators/multiplex_op.cc
@@ -115,8 +115,9 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker, multiplex_grad,
-            ops::MultiplexGradOp);
+REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<false>);
+REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp);
 REGISTER_OP_CPU_KERNEL(
     multiplex, ops::MultiplexCPUKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(