From 816b4c8ab08306b79d3994deebdc51fdd0186bd5 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Tue, 18 Jul 2017 20:18:49 +0800 Subject: [PATCH] "add backward Op" --- paddle/framework/CMakeLists.txt | 3 + paddle/framework/fully_connected_op.cc | 39 ++++++++++ paddle/framework/fully_connected_op.h | 52 +++++++++++++ paddle/framework/net.cc | 14 ++++ paddle/framework/net.h | 2 + paddle/framework/net_op_test.cc | 104 ++++++++++++++++--------- paddle/framework/net_test.cc | 5 +- paddle/framework/op_registry.h | 47 ++++++++++- 8 files changed, 226 insertions(+), 40 deletions(-) create mode 100644 paddle/framework/fully_connected_op.cc create mode 100644 paddle/framework/fully_connected_op.h diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index cc5b05ff0d5..429a9a19a91 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -15,6 +15,8 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_library(operator SRCS operator.cc DEPS op_desc device_context) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) +# cc_library(fc_op SRCS fully_connected_op.cc DEPS operator) + cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) @@ -23,5 +25,6 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch add_dependencies(framework_py_proto framework_py_proto_init) proto_library(net_proto SRCS net_proto.proto DEPS op_proto) +# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op) cc_library(net SRCS net.cc DEPS operator net_proto op_registry) cc_test(net_op_test SRCS net_op_test.cc DEPS net) diff --git a/paddle/framework/fully_connected_op.cc b/paddle/framework/fully_connected_op.cc new file mode 100644 index 00000000000..28be46366ff --- /dev/null +++ b/paddle/framework/fully_connected_op.cc @@ -0,0 +1,39 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/fully_connected_op.h" +#include +namespace paddle { +namespace framework { + +void FCOp::Run(const ScopePtr& scope, + const platform::DeviceContext& dev_ctx) const override { + std::cout << "FC" << std::endl; +} + +void FCOp::InferShape(const ScopePtr& scope) const override {} + +void FCGradientOp::Run(const ScopePtr& scope, + const platform::DeviceContext& dev_ctx) const override { + std::cout << "FCGrad" << std::endl; +} + +void FCGradientOp::InferShape(const ScopePtr& scope) const override {} + +REGISTER_OP(my_fc, paddle::framework::FCOp, + paddle::framework::FCOpProtoAndCheckerMaker); +REGISTER_OP(my_fc_grad, paddle::framework::FCGradientOp, + paddle::framework::FCGradientOpProtoAndCheckerMaker); +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/fully_connected_op.h b/paddle/framework/fully_connected_op.h new file mode 100644 index 00000000000..948116f653f --- /dev/null +++ b/paddle/framework/fully_connected_op.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace framework { +class FCOp : public OperatorBase { + public: + void Run(const ScopePtr& scope, + const platform::DeviceContext& dev_ctx) const override { + std::cout << "FC" << std::endl; + }; + void InferShape(const ScopePtr& scope) const override{}; +}; + +class FCOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + FCOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("x", "input data"); + AddInput("w", "weights"); + AddInput("b", "bias"); + AddOutput("y", "output data"); + AddComment("Fully connnect op"); + } +}; + +class FCGradientOp : public OperatorBase { + void Run(const ScopePtr& scope, + const platform::DeviceContext& dev_ctx) const override { + std::cout << "FCGrad" << std::endl; + }; + void InferShape(const ScopePtr& scope) const override{}; +}; + +// class FCGradientOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 7311cda9a9a..14329159271 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -15,10 +15,24 @@ */ #include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" namespace paddle { namespace framework { +std::shared_ptr AddBackwardOp(std::shared_ptr ForwardOps) { + // NetPtr->reset(new PlainNet); + // NetPtr grad_ops = new PlainNet; + std::shared_ptr grad_ops; + grad_ops.reset(new PlainNet); + for (auto& op : ForwardOps->ops_) { + auto op_grad = OpRegistry::CreateGradOp(op); + grad_ops->AddOp(op_grad); + } + grad_ops->CompleteAddOp(); + return grad_ops; +} + void PlainNet::CompleteAddOp() { std::unordered_set input_set; std::unordered_set output_set; diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 19a1620e29b..354319001fb 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -99,5 +99,7 @@ class PlainNet : public Net { } }; +std::shared_ptr AddBackwardOp(std::shared_ptr ForwardOps); + } // namespace framework } // namespace paddle diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc index f5e1c22400a..d61233a8b40 100644 --- a/paddle/framework/net_op_test.cc +++ b/paddle/framework/net_op_test.cc @@ -3,18 +3,17 @@ #include #include -namespace pd = paddle::framework; +namespace paddle { +namespace framework { static int infer_shape_cnt = 0; static int run_cnt = 0; -class TestOp : public pd::OperatorBase { +class TestOp : public OperatorBase { public: - void InferShape(const paddle::framework::ScopePtr& scope) const override { - ++infer_shape_cnt; - } - void Run(const paddle::framework::ScopePtr& scope, - const paddle::platform::DeviceContext& dev_ctx) const override { + void InferShape(const ScopePtr& scope) const override { ++infer_shape_cnt; } + void Run(const ScopePtr& scope, + const platform::DeviceContext& dev_ctx) const override { ++run_cnt; } }; @@ -32,36 +31,65 @@ void AssertSameVectorWithoutOrder(const std::vector& expected, } } +class PlainNetTest : public testing::Test { + virtual void SetUp() { + net_ = std::make_shared(); + ASSERT_NE(net_, nullptr); + + auto op1 = std::make_shared(); + op1->inputs_ = {"x", "w1", "b1"}; + op1->outputs_ = {"y"}; + net_->AddOp(op1); + + auto op2 = std::make_shared(); + op2->inputs_ = {"y", "w2", "b2"}; + op2->outputs_ = {"z"}; + net_->AddOp(op2); + net_->CompleteAddOp(); + } + + virtual void TearDown() {} + + void TestOpKernel() { + AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, net_->inputs_); + AssertSameVectorWithoutOrder({"y", "z"}, net_->outputs_); + auto tmp_idx_iter = net_->attrs_.find("temporary_index"); + ASSERT_NE(net_->attrs_.end(), tmp_idx_iter); + auto& tmp_idx = boost::get>(tmp_idx_iter->second); + ASSERT_EQ(1UL, tmp_idx.size()); + ASSERT_EQ("y", net_->outputs_[tmp_idx[0]]); + + auto scope = std::make_shared(); + platform::CPUDeviceContext dev_ctx; + + net_->InferShape(scope); + net_->Run(scope, dev_ctx); + ASSERT_EQ(2, infer_shape_cnt); + ASSERT_EQ(2, run_cnt); + + ASSERT_THROW(net_->AddOp(op2), EnforceNotMet); + } + + void TestAddBackwardOp() { + auto grad_ops = AddBackwardOp(net_); + for (auto& op : grad_ops->ops_) { + op->DebugString(); + } + } + + private: + std::shared_ptr net_; +}; + TEST(OpKernel, all) { - auto net = std::make_shared(); - ASSERT_NE(net, nullptr); - - auto op1 = std::make_shared(); - op1->inputs_ = {"x", "w1", "b1"}; - op1->outputs_ = {"y"}; - net->AddOp(op1); - - auto op2 = std::make_shared(); - op2->inputs_ = {"y", "w2", "b2"}; - op2->outputs_ = {"z"}; - net->AddOp(op2); - - net->CompleteAddOp(); - AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, net->inputs_); - AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_); - auto tmp_idx_iter = net->attrs_.find("temporary_index"); - ASSERT_NE(net->attrs_.end(), tmp_idx_iter); - auto& tmp_idx = boost::get>(tmp_idx_iter->second); - ASSERT_EQ(1UL, tmp_idx.size()); - ASSERT_EQ("y", net->outputs_[tmp_idx[0]]); - - auto scope = std::make_shared(); - paddle::platform::CPUDeviceContext dev_ctx; - - net->InferShape(scope); - net->Run(scope, dev_ctx); - ASSERT_EQ(2, infer_shape_cnt); - ASSERT_EQ(2, run_cnt); - - ASSERT_THROW(net->AddOp(op2), paddle::framework::EnforceNotMet); + PlainNetTest net; + net->TestOpKernel(); +} + +TEST(AddBackwardOp, TestAddBackwardOp) { + PlainNetTest net; + net->TestAddBackwardOp(); } + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/net_test.cc b/paddle/framework/net_test.cc index a8e31c14975..5afc0d9204b 100644 --- a/paddle/framework/net_test.cc +++ b/paddle/framework/net_test.cc @@ -13,12 +13,15 @@ limitations under the License. */ #include "paddle/framework/net.h" +#include "paddle/framework/fully_connected_op.h" #include "paddle/framework/op_registry.h" #include namespace paddle { namespace framework { -class FakeFC : public Operator {} + +TEST(AddBackwardOp, ALL) + } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 24f56b28128..9183a8b1dff 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -8,6 +8,7 @@ #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/operator.h" +#include "paddle/framework/scope.h" namespace paddle { namespace framework { @@ -188,8 +189,8 @@ class OpRegistry { template static void RegisterOp(const std::string& op_type) { creators()[op_type] = [] { return new OpType; }; - OpProto& op_proto = protos()[op_type]; OpAttrChecker& op_checker = op_checkers()[op_type]; + OpProto& op_proto = protos()[op_type]; ProtoMakerType(&op_proto, &op_checker); *op_proto.mutable_type() = op_type; PADDLE_ENFORCE( @@ -198,6 +199,11 @@ class OpRegistry { op_type, op_proto.InitializationErrorString()); } + template + static void RegisterGradOp(const std::string& op_type) { + grad_creators()[op_type] = [] { return new OpType; }; + } + static OperatorPtr CreateOp(const OpDesc& op_desc) { std::string op_type = op_desc.type(); OperatorPtr op(creators().at(op_type)()); @@ -216,6 +222,21 @@ class OpRegistry { return op; } + static OperatorPtr CreateGradOp(std::shared_ptr op) { + OperatorPtr op_grad(grad_creators().at(op->type_)()); + op_grad->type_ = op->type_; + op_grad->inputs_.reserve(op->inputs_.size()); + for (auto& input : op->inputs_) { + op_grad->inputs_.emplace_back(input); + op_grad->outputs_.emplace_back(input + "@grad"); + } + for (auto& output : op->outputs_) { + op_grad->inputs_.emplace_back(output); + op_grad->inputs_.emplace_back(output + "@grad"); + } + return op_grad; + } + static std::unordered_map& protos() { static std::unordered_map protos_; return protos_; @@ -231,6 +252,11 @@ class OpRegistry { static std::unordered_map op_checkers_; return op_checkers_; }; + + static std::unordered_map& grad_creators() { + static std::unordered_map grad_creators_; + return grad_creators_; + } }; template @@ -241,6 +267,14 @@ class OpRegisterHelper { } }; +template +class GradOpRegisterHelper { + public: + GradOpRegisterHelper(const char* op_type) { + OpRegistry::RegisterGradOp(op_type); + } +}; + /** * check if MACRO is used in GLOBAL NAMESPACE. */ @@ -260,6 +294,17 @@ class OpRegisterHelper { __op_register_##__op_type##__(#__op_type); \ int __op_register_##__op_type##_handle__() { return 0; } +/** + * Macro to Register Operator. + */ +#define REGISTER_GRADIENT_OP(__op_type, __op_class) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_op__##__op_type, \ + "REGISTER_GRADIENT_OP must be in global namespace"); \ + static ::paddle::framework::GradOpRegisterHelper<__op_class> \ + __op_register_##__op_type##__(#__op_type); \ + int __op_register_##__op_type##_handle__() { return 0; } + /** * Macro to Register OperatorKernel. */ -- GitLab