diff --git a/doc/design/simple_op_design.md b/doc/design/simple_op_design.md
index 49ca5db5da9e400fd2c54eb8903b0dd2eb832d44..5e07c29c56d21728599195d420d3222213d77e7c 100644
--- a/doc/design/simple_op_design.md
+++ b/doc/design/simple_op_design.md
@@ -49,6 +49,7 @@ message AttrProto {
 message VarProto {
   required string name = 1;
   required string comment = 2;
+  required bool is_tensor = 3;
 };
 
 message OpProto {
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 760d84e51e7473d359a415e4790251db3d139ab2..a76a95644dae2755a9599a57259a1f9b2ed604b7 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -19,8 +19,10 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
 cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
 
-cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
-cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
+cc_library(grad_op_creator SRCS grad_op_creator.cc DEPS op_proto operator)
+cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_creator)
+cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
+cc_test(grad_op_creator_test SRCS grad_op_creator_test.cc DEPS grad_op_creator op_registry add_op)
 
 py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
 # Generate an empty __init__.py to make framework_py_proto as a valid python module.
@@ -28,5 +30,6 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
 add_dependencies(framework_py_proto framework_py_proto_init)
 
 proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
+# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
 cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
-cc_test(net_op_test SRCS net_op_test.cc DEPS net)
+cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
diff --git a/paddle/framework/grad_op_creator.cc b/paddle/framework/grad_op_creator.cc
new file mode 100644
index 0000000000000000000000000000000000000000..106c2eae9dade9ef1829fc2f1b793faf483947d4
--- /dev/null
+++ b/paddle/framework/grad_op_creator.cc
@@ -0,0 +1,115 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/grad_op_creator.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace framework {
+
+OperatorBase* GradOpCreator::Create() {
+  BuildOpInOutArgList();
+  OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)();
+  CompleteGradOp(grad_op);
+  return grad_op;
+}
+
+OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
+                                    const VarIndexMap& var_map,
+                                    const std::vector<int>& format,
+                                    InOutType type) {
+  int idx = var_map.at(var.name());
+  int begin_idx = format.empty() ? idx : format.at(idx);
+  int end_idx = format.empty() ? idx + 1 : format.at(idx + 1);
+  return new OpInOutArg(var.name(), type, !var.ignore_gradient(), begin_idx,
+                        end_idx);
+}
+
+void GradOpCreator::BuildOpInOutArgList() {
+  const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
+  const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
+  const std::vector<int>& in_format =
+      op_->attrs_.count("input_format")
+          ? op_->GetAttr<std::vector<int>>("input_format")
+          : std::vector<int>();
+  const std::vector<int>& out_format =
+      op_->attrs_.count("output_format")
+          ? op_->GetAttr<std::vector<int>>("output_format")
+          : std::vector<int>();
+  for (const auto& var : op_proto.inputs()) {
+    arg_list_.emplace_back(
+        std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, in_format, IN)));
+  }
+  for (const auto& var : op_proto.outputs()) {
+    arg_list_.emplace_back(
+        std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, out_format, OUT)));
+  }
+}
+
+void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
+                                     std::vector<std::string>& in_out,
+                                     std::vector<int>& format,
+                                     VarIndexMap* varmap, int& idx,
+                                     bool is_grad) const {
+  std::string var_name = arg->proto_name_;
+  if (is_grad) {
+    var_name += OperatorBase::GRAD_VAR_SUFFIX();
+  }
+  (*varmap)[var_name] = idx++;
+  size_t pre_sz = in_out.size();
+  auto base_it =
+      arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin();
+  std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_,
+            std::back_inserter(in_out));
+  if (is_grad) {
+    for (size_t i = pre_sz; i < in_out.size(); ++i) {
+      in_out[i] += OperatorBase::GRAD_VAR_SUFFIX();
+    }
+  }
+  format.push_back(in_out.size());
+}
+
+void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
+  grad_op->type_ = op_->type_ + "@GRAD";  // not necessary
+  grad_op->attrs_ = op_->attrs_;
+  grad_op->attrs_.erase("input_format");
+  grad_op->attrs_.erase("output_format");
+  VarIndexMap* grad_varmap = new VarIndexMap();
+  int in_idx = 0;
+  int out_idx = 0;
+  std::vector<int> in_format({0});
+  std::vector<int> out_format({0});
+  for (const auto& arg : arg_list_) {
+    // op_'s inputs_ and outputs_
+    if (arg->needed_in_grad_) {
+      AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
+                       in_idx, false);
+    }
+    if (arg->type_ == IN) {
+      // gradients of op_'s inputs_
+      AddArgIntoGradOp(arg.get(), grad_op->outputs_, out_format, grad_varmap,
+                       out_idx, true);
+    } else {
+      // gradients of op_'s outputs_
+      AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
+                       in_idx, true);
+    }
+  }
+  grad_op->attrs_["input_format"] = in_format;
+  grad_op->attrs_["output_format"] = out_format;
+  grad_op->in_out_idxs_.reset(grad_varmap);
+}
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/grad_op_creator.h b/paddle/framework/grad_op_creator.h
new file mode 100644
index 0000000000000000000000000000000000000000..21b160a73f3f6402a0571e2f13be06b26b5c30bc
--- /dev/null
+++ b/paddle/framework/grad_op_creator.h
@@ -0,0 +1,48 @@
+#pragma once
+
+#include "paddle/framework/op_proto.pb.h"
+#include "paddle/framework/operator.h"
+
+namespace paddle {
+namespace framework {
+class OpRegistry;
+
+enum InOutType { IN, OUT };
+
+struct OpInOutArg {
+  OpInOutArg(const std::string& proto_name, const InOutType& type,
+             bool needed_in_grad, size_t begin_idx, size_t end_idx)
+      : proto_name_(proto_name),
+        type_(type),
+        needed_in_grad_(needed_in_grad),
+        begin_idx_(begin_idx),
+        end_idx_(end_idx) {}
+
+  std::string proto_name_;
+  InOutType type_;
+  bool needed_in_grad_;
+  size_t begin_idx_;
+  size_t end_idx_;
+};
+
+class GradOpCreator {
+  using VarIndexMap = std::unordered_map<std::string, int>;
+
+ public:
+  GradOpCreator(const OperatorBase* op) : op_(op) {}
+  OperatorBase* Create();
+
+ private:
+  OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
+                       const std::vector<int>& format, InOutType type);
+  void BuildOpInOutArgList();
+  void AddArgIntoGradOp(const OpInOutArg* arg, std::vector<std::string>& in_out,
+                        std::vector<int>& format, VarIndexMap* varmap, int& idx,
+                        bool is_grad) const;
+  void CompleteGradOp(OperatorBase* grad_op) const;
+  const OperatorBase* op_;
+  std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
+};
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/grad_op_creator_test.cc b/paddle/framework/grad_op_creator_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..27ac65813120a2a682535a02bcecb882c4a7640d
--- /dev/null
+++ b/paddle/framework/grad_op_creator_test.cc
@@ -0,0 +1,26 @@
+#include "paddle/framework/grad_op_creator.h"
+#include <gtest/gtest.h>
+#include "paddle/framework/op_registry.h"
+#include "paddle/framework/operator.h"
+
+USE_OP(add_two);
+
+namespace paddle {
+namespace framework {
+
+TEST(GradOpCreator, AddTwo) {
+  std::shared_ptr<OperatorBase> add_op(
+      OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
+  std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
+  EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
+  EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
+  EXPECT_EQ(grad_add_op->Input("X"), "x");
+  EXPECT_EQ(grad_add_op->Input("Y"), "y");
+  EXPECT_EQ(grad_add_op->Input("Out"), "out");
+  EXPECT_EQ(grad_add_op->Input("Out@GRAD"), "out@GRAD");
+  EXPECT_EQ(grad_add_op->Output("X@GRAD"), "x@GRAD");
+  EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD");
+}
+
+}  // namespace framework
+}  // namespace paddle
\ No newline at end of file
diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc
index 139425b356989f20f035d27ed4b678126d9417d6..bc23b63b35d37eea01ae6b9b8891e9cd94615898 100644
--- a/paddle/framework/net.cc
+++ b/paddle/framework/net.cc
@@ -15,14 +15,24 @@
  */
 
 #include "paddle/framework/net.h"
+#include "paddle/framework/op_registry.h"
 
 namespace paddle {
 namespace framework {
 
+std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
+  auto grad_ops = std::make_shared<PlainNet>();
+  for (auto& op : ForwardOps->ops_) {
+    auto op_grad = OpRegistry::CreateGradOp(op);
+    grad_ops->AddOp(op_grad);
+  }
+  grad_ops->CompleteAddOp();
+  return grad_ops;
+}
+
 void PlainNet::CompleteAddOp(bool calc) {
   add_op_done_ = true;
   if (!calc) return;
-
   std::unordered_set<std::string> input_set;
   std::unordered_set<std::string> output_set;
   std::unordered_set<std::string> temp_output;
diff --git a/paddle/framework/net.h b/paddle/framework/net.h
index b2c64a8675cbb592dfb5d7233c8f73b22cf25621..3264f1f565e3efc188e7835cb9b44e5741e1eea8 100644
--- a/paddle/framework/net.h
+++ b/paddle/framework/net.h
@@ -100,5 +100,7 @@ class PlainNet : public Net {
   }
 };
 
+std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc
index c179042c81a04741ba1d30ec00adc369b576b941..20b42cbb4923590804a7806ac42347590c73d62f 100644
--- a/paddle/framework/net_op_test.cc
+++ b/paddle/framework/net_op_test.cc
@@ -3,17 +3,24 @@
 #include <paddle/framework/op_registry.h>
 #include <paddle/framework/operator.h>
 
-namespace pd = paddle::framework;
+USE_OP(add_two);
+USE_OP(mul);
+USE_OP(sigmoid);
+USE_OP(softmax);
+
+namespace paddle {
+namespace framework {
 
 static int infer_shape_cnt = 0;
 static int run_cnt = 0;
 
-class TestOp : public pd::OperatorBase {
+class TestOp : public OperatorBase {
  public:
-  void InferShape(const std::shared_ptr<pd::Scope>& scope) const override {
+  void InferShape(
+      const std::shared_ptr<framework::Scope>& scope) const override {
     ++infer_shape_cnt;
   }
-  void Run(const std::shared_ptr<pd::Scope>& scope,
+  void Run(const std::shared_ptr<framework::Scope>& scope,
            const paddle::platform::DeviceContext& dev_ctx) const override {
     ++run_cnt;
   }
@@ -33,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
 }
 
 TEST(OpKernel, all) {
-  auto net = std::make_shared<pd::PlainNet>();
+  auto net = std::make_shared<PlainNet>();
   ASSERT_NE(net, nullptr);
 
   auto op1 = std::make_shared<TestOp>();
@@ -55,13 +62,37 @@ TEST(OpKernel, all) {
   ASSERT_EQ(1UL, tmp_idx.size());
   ASSERT_EQ("y", net->outputs_[tmp_idx[0]]);
 
-  auto scope = std::make_shared<pd::Scope>();
-  paddle::platform::CPUDeviceContext dev_ctx;
+  auto scope = std::make_shared<Scope>();
+  platform::CPUDeviceContext dev_ctx;
 
   net->InferShape(scope);
   net->Run(scope, dev_ctx);
   ASSERT_EQ(2, infer_shape_cnt);
   ASSERT_EQ(2, run_cnt);
 
-  ASSERT_THROW(net->AddOp(op2), std::runtime_error);
 }
+TEST(AddBackwardOp, TestGradOp) {
+  auto net = std::make_shared<PlainNet>();
+  ASSERT_NE(net, nullptr);
+  net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
+  net->AddOp(
+      framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
+  net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
+  auto grad_ops = AddBackwardOp(net);
+  for (auto& op : grad_ops->ops_) {
+    op->DebugString();
+  }
+}
+
+// TODO(zhihong): add fc grad without registering.
+// TEST(AddBackwardOp, TestNoGradOp) {
+//   auto net = std::make_shared<PlainNet>();
+//   ASSERT_NE(net, nullptr);
+//   net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
+//   {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
+//     op->DebugString();
+//   }
+// }
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/net_test.cc b/paddle/framework/net_test.cc
deleted file mode 100644
index a8e31c1497519ce60da004bc0a3e52403593497c..0000000000000000000000000000000000000000
--- a/paddle/framework/net_test.cc
+++ /dev/null
@@ -1,24 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-   http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
-
-#include "paddle/framework/net.h"
-#include "paddle/framework/op_registry.h"
-
-#include <gtest/gtest.h>
-
-namespace paddle {
-namespace framework {
-class FakeFC : public Operator {}
-}  // namespace framework
-}  // namespace paddle
diff --git a/paddle/framework/op_proto.proto b/paddle/framework/op_proto.proto
index 596b8588e783722362815f75db876931f83484ec..366c84e53dc29e41eefbaef0a6452e01c4fe37bd 100644
--- a/paddle/framework/op_proto.proto
+++ b/paddle/framework/op_proto.proto
@@ -84,6 +84,11 @@ message VarProto {
   // "temporary_index": [1]
   // }
   optional bool temporary = 4 [default=false];
+
+  // Whether the gradient of this variable can be ignored immediately.
+  // e.g. for operator AddOp, y = x1 + x2, the gradients dy/dx1 and dy/dx2
+  // can be ignored to allow future optimization of the graph.
+  optional bool ignore_gradient = 6;
 }
 
 // Op protocol message for 3rd-party language binding.
@@ -105,4 +110,5 @@ message OpProto {
 
   // The type of that Op.
   required string type = 5;
+
 }
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 165a68c1cf7abbbfacd136e890eaa1f18ed39e69..41c78309327342ff47982fc105eadf777c7e59c7 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -1,3 +1,17 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
 #pragma once
 
 #include <algorithm>
@@ -6,9 +20,9 @@
 #include <unordered_map>
 #include <unordered_set>
 #include "paddle/framework/attr_checker.h"
+#include "paddle/framework/grad_op_creator.h"
 #include "paddle/framework/op_desc.pb.h"
-#include "paddle/framework/op_proto.pb.h"
-#include "paddle/framework/operator.h"
+#include "paddle/framework/scope.h"
 
 namespace paddle {
 namespace framework {
@@ -73,25 +87,29 @@ class OpProtoAndCheckerMaker {
 
  protected:
   void AddInput(const std::string& name, const std::string& comment,
-                bool multiple = false) {
+                bool multiple = false, bool ignore_gradient = false) {
     auto input = proto_->mutable_inputs()->Add();
     *input->mutable_name() = name;
     *input->mutable_comment() = comment;
+    input->set_ignore_gradient(ignore_gradient);
     input->set_multiple(multiple);
     if (multiple) {
       SetHasMultipleInput();
     }
   }
 
-  void AddInputs(const std::string& name, const std::string& comment) {
-    AddInput(name, comment, true);
+  void AddInputs(const std::string& name, const std::string& comment,
+                 bool ignore_gradient = false) {
+    AddInput(name, comment, true, ignore_gradient);
   }
 
   void AddOutput(const std::string& name, const std::string& comment,
-                 bool temporary = false, bool multiple = false) {
+                 bool temporary = false, bool multiple = false,
+                 bool ignore_gradient = false) {
     auto output = proto_->mutable_outputs()->Add();
     *output->mutable_name() = name;
     *output->mutable_comment() = comment;
+    output->set_ignore_gradient(ignore_gradient);
     output->set_multiple(multiple);
     if (multiple) {
       SetHasMultipleOutput();
@@ -103,8 +121,8 @@ class OpProtoAndCheckerMaker {
   }
 
   void AddOutputs(const std::string& name, const std::string& comment,
-                  bool temporary = false) {
-    AddOutput(name, comment, temporary, true);
+                  bool temporary = false, bool ignore_gradient = false) {
+    AddOutput(name, comment, temporary, true, ignore_gradient);
   }
 
   template <typename T>
@@ -205,8 +223,8 @@ class OpRegistry {
   template <typename OpType, typename ProtoMakerType>
   static void RegisterOp(const std::string& op_type) {
     creators()[op_type] = [] { return new OpType; };
-    OpProto& op_proto = protos()[op_type];
     OpAttrChecker& op_checker = op_checkers()[op_type];
+    OpProto& op_proto = protos()[op_type];
     auto maker = ProtoMakerType(&op_proto, &op_checker);
     maker.Validate();
     *op_proto.mutable_type() = op_type;
@@ -227,18 +245,24 @@ class OpRegistry {
     }
   }
 
+  template <typename OpType>
+  static void RegisterGradOp(const std::string& op_type) {
+    grad_creators()[op_type] = [] { return new OpType; };
+  }
+
   static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
                                                 const VarNameList& inputs,
                                                 const VarNameList& outputs,
                                                 const AttributeMap& attrs) {
     auto op_create_it = creators().find(type);
     PADDLE_ENFORCE(op_create_it != creators().end(),
-                   "Operator %s cannot be found", type);
+                   "Operator %s cannot be found.", type);
 
     auto op = op_create_it->second();
     op->type_ = type;
     op->inputs_ = inputs;
     op->outputs_ = outputs;
+
     op->attrs_ = attrs;
     op_checkers().at(type).Check(op->attrs_);
 
@@ -274,18 +298,41 @@ class OpRegistry {
     return CreateOp(op_desc.type(), inputs, outputs, attrs);
   }
 
+  static std::shared_ptr<OperatorBase> CreateGradOp(
+      std::shared_ptr<OperatorBase> op) {
+    GradOpCreator creator(op.get());
+    std::shared_ptr<OperatorBase> grad_op(creator.Create());
+    grad_op->Init();
+    return grad_op;
+  }
+
   static std::unordered_map<std::string, OpProto>& protos() {
     static std::unordered_map<std::string, OpProto> protos_;
     return protos_;
   };
 
- private:
+  static std::unordered_map<std::string, OpCreator>& grad_creators() {
+    static std::unordered_map<std::string, OpCreator> grad_creators_;
+    return grad_creators_;
+  }
+
   static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
   VarIndexMaps() {
     static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
     return maps_;
   }
 
+ private:
+  static std::unordered_map<std::string, OpCreator>& creators() {
+    static std::unordered_map<std::string, OpCreator> creators_;
+    return creators_;
+  }
+
+  static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
+    static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
+    return op_checkers_;
+  };
+
   static void GenerateTempVariableName(OperatorBase* op) {
     static std::atomic<size_t> gUniqId(0UL);
     for (auto& outname : op->outputs_) {
@@ -296,16 +343,6 @@ class OpRegistry {
       }
     }
   }
-
-  static std::unordered_map<std::string, OpCreator>& creators() {
-    static std::unordered_map<std::string, OpCreator> creators_;
-    return creators_;
-  }
-
-  static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
-    static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
-    return op_checkers_;
-  };
 };
 
 template <typename OpType, typename ProtoMakerType>
@@ -316,6 +353,14 @@ class OpRegisterHelper {
   }
 };
 
+template <typename OpType>
+class GradOpRegisterHelper {
+ public:
+  GradOpRegisterHelper(const char* op_type) {
+    OpRegistry::RegisterGradOp<OpType>(op_type);
+  }
+};
+
 /**
  * check if MACRO is used in GLOBAL NAMESPACE.
  */
@@ -335,6 +380,17 @@ class OpRegisterHelper {
       __op_register_##__op_type##__(#__op_type);                      \
   int __op_register_##__op_type##_handle__() { return 0; }
 
+/**
+ * Macro to Register Gradient Operator.
+ */
+#define REGISTER_GRADIENT_OP(__op_type, __op_class)                \
+  STATIC_ASSERT_GLOBAL_NAMESPACE(                                  \
+      __reg_gradient_op__##__op_type,                              \
+      "REGISTER_GRADIENT_OP must be in global namespace");         \
+  static ::paddle::framework::GradOpRegisterHelper<__op_class>     \
+      __op_gradient_register_##__op_type##__(#__op_type);          \
+  int __op_gradient_register_##__op_type##_handle__() { return 0; }
+
 /**
  * Macro to Register OperatorKernel.
  */
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 6b8dbb39acd997321526bf3d44f1ced6acacdae3..f59314f8288d37f0c645b99811b1355f9a496c00 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -62,6 +62,11 @@ class OperatorBase {
   /// but it will be convert to a unique name in scope after OpCreator.
   static std::string TMP_VAR_NAME() { return "@TEMP@"; }
 
+  /// If a variable's name has a certain suffix, it means that the
+  /// variable is the gradient of another variable.
+  /// e.g. Variable "x@GRAD" is the gradient of variable "x".
+  static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }
+
   virtual ~OperatorBase() {}
 
   template <typename T>
diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc
index ebe9ceebe488437866fd6097531623eeb547f67a..ff60f9b314c86ad92218caea15ca5d9f6d996b4e 100644
--- a/paddle/operators/add_op.cc
+++ b/paddle/operators/add_op.cc
@@ -49,9 +49,22 @@ The equation is: Out = X + Y
 )DOC");
   }
 };
+
+class AddOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "AddOpGrad";
+    return "";
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
+REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);
 REGISTER_OP_CPU_KERNEL(
     add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/add_op_test.cc b/paddle/operators/add_op_test.cc
index 53b354fedcacf2176aed8b504daf2046bdf96bb6..7fc1049893e171a17af92da7e813b2463874c9de 100644
--- a/paddle/operators/add_op_test.cc
+++ b/paddle/operators/add_op_test.cc
@@ -16,8 +16,13 @@ limitations under the License. */
 #define private public
 #include <paddle/framework/op_registry.h>
 USE_OP(add_two);
+// USE_OP(add_two_grad);
+
 TEST(AddOp, GetOpProto) {
   auto& protos = paddle::framework::OpRegistry::protos();
   auto it = protos.find("add_two");
   ASSERT_NE(it, protos.end());
-}
\ No newline at end of file
+  auto& grad_creators = paddle::framework::OpRegistry::grad_creators();
+  auto it1 = grad_creators.find("add_two");
+  ASSERT_NE(it1, grad_creators.end());
+}
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 079a5800804345762b0b4bc7b8bc9ca042856ccc..89e0375a7a043730685c4c0883ac672bdd688159 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -52,9 +52,22 @@ The equation is: Out = X * Y
   }
 };
 
+class MulOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "MulGrad";
+    return "";
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
+REGISTER_GRADIENT_OP(mul, paddle::operators::MulOpGrad);
+
 REGISTER_OP_CPU_KERNEL(
     mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc
index 91f7d86aebae2e67b2fc18bf2c558fbe2e03de92..7dc58bbb10007545cd281ae7da359e4c2b32fae0 100644
--- a/paddle/operators/sigmoid_op.cc
+++ b/paddle/operators/sigmoid_op.cc
@@ -39,12 +39,25 @@ public:
   }
 };
 
+class SigmoidOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "SigmoidGrad";
+    return "";
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 REGISTER_OP(sigmoid, paddle::operators::SigmoidOp,
             paddle::operators::SigmoidOpMaker);
 
+REGISTER_GRADIENT_OP(sigmoid, paddle::operators::SigmoidOpGrad);
+
 REGISTER_OP_CPU_KERNEL(
     sigmoid,
     paddle::operators::SigmoidKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc
index cf5e273de6be71e727f27d5e87d13d9235e31d0c..1d10a415d0208e1edb881eacad951a07fcbb8b5c 100644
--- a/paddle/operators/softmax_op.cc
+++ b/paddle/operators/softmax_op.cc
@@ -42,11 +42,23 @@ public:
   }
 };
 
+class SoftmaxOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "SoftmaxOpGrad";
+    return "";
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
 REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
+REGISTER_GRADIENT_OP(softmax, paddle::operators::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(softmax,
                        ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);