提交 816b4c8a 编写于 作者: D dongzhihong

"add backward Op"

上级 83f263e6
...@@ -15,6 +15,8 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) ...@@ -15,6 +15,8 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc device_context) cc_library(operator SRCS operator.cc DEPS op_desc device_context)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
# cc_library(fc_op SRCS fully_connected_op.cc DEPS operator)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc) cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
...@@ -23,5 +25,6 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch ...@@ -23,5 +25,6 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
add_dependencies(framework_py_proto framework_py_proto_init) add_dependencies(framework_py_proto framework_py_proto_init)
proto_library(net_proto SRCS net_proto.proto DEPS op_proto) proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library(net SRCS net.cc DEPS operator net_proto op_registry) cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net) cc_test(net_op_test SRCS net_op_test.cc DEPS net)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/fully_connected_op.h"
#include <iostream>
namespace paddle {
namespace framework {
void FCOp::Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const override {
std::cout << "FC" << std::endl;
}
void FCOp::InferShape(const ScopePtr& scope) const override {}
void FCGradientOp::Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const override {
std::cout << "FCGrad" << std::endl;
}
void FCGradientOp::InferShape(const ScopePtr& scope) const override {}
REGISTER_OP(my_fc, paddle::framework::FCOp,
paddle::framework::FCOpProtoAndCheckerMaker);
REGISTER_OP(my_fc_grad, paddle::framework::FCGradientOp,
paddle::framework::FCGradientOpProtoAndCheckerMaker);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
class FCOp : public OperatorBase {
public:
void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const override {
std::cout << "FC" << std::endl;
};
void InferShape(const ScopePtr& scope) const override{};
};
class FCOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
FCOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("x", "input data");
AddInput("w", "weights");
AddInput("b", "bias");
AddOutput("y", "output data");
AddComment("Fully connnect op");
}
};
class FCGradientOp : public OperatorBase {
void Run(const ScopePtr& scope,
const platform::DeviceContext& dev_ctx) const override {
std::cout << "FCGrad" << std::endl;
};
void InferShape(const ScopePtr& scope) const override{};
};
// class FCGradientOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {};
} // namespace framework
} // namespace paddle
...@@ -15,10 +15,24 @@ ...@@ -15,10 +15,24 @@
*/ */
#include "paddle/framework/net.h" #include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
// NetPtr->reset(new PlainNet);
// NetPtr grad_ops = new PlainNet;
std::shared_ptr<PlainNet> grad_ops;
grad_ops.reset(new PlainNet);
for (auto& op : ForwardOps->ops_) {
auto op_grad = OpRegistry::CreateGradOp(op);
grad_ops->AddOp(op_grad);
}
grad_ops->CompleteAddOp();
return grad_ops;
}
void PlainNet::CompleteAddOp() { void PlainNet::CompleteAddOp() {
std::unordered_set<std::string> input_set; std::unordered_set<std::string> input_set;
std::unordered_set<std::string> output_set; std::unordered_set<std::string> output_set;
......
...@@ -99,5 +99,7 @@ class PlainNet : public Net { ...@@ -99,5 +99,7 @@ class PlainNet : public Net {
} }
}; };
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -3,18 +3,17 @@ ...@@ -3,18 +3,17 @@
#include <paddle/framework/op_registry.h> #include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h> #include <paddle/framework/operator.h>
namespace pd = paddle::framework; namespace paddle {
namespace framework {
static int infer_shape_cnt = 0; static int infer_shape_cnt = 0;
static int run_cnt = 0; static int run_cnt = 0;
class TestOp : public pd::OperatorBase { class TestOp : public OperatorBase {
public: public:
void InferShape(const paddle::framework::ScopePtr& scope) const override { void InferShape(const ScopePtr& scope) const override { ++infer_shape_cnt; }
++infer_shape_cnt; void Run(const ScopePtr& scope,
} const platform::DeviceContext& dev_ctx) const override {
void Run(const paddle::framework::ScopePtr& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {
++run_cnt; ++run_cnt;
} }
}; };
...@@ -32,36 +31,65 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected, ...@@ -32,36 +31,65 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
} }
} }
TEST(OpKernel, all) { class PlainNetTest : public testing::Test {
auto net = std::make_shared<paddle::framework::PlainNet>(); virtual void SetUp() {
ASSERT_NE(net, nullptr); net_ = std::make_shared<PlainNet>();
ASSERT_NE(net_, nullptr);
auto op1 = std::make_shared<TestOp>(); auto op1 = std::make_shared<TestOp>();
op1->inputs_ = {"x", "w1", "b1"}; op1->inputs_ = {"x", "w1", "b1"};
op1->outputs_ = {"y"}; op1->outputs_ = {"y"};
net->AddOp(op1); net_->AddOp(op1);
auto op2 = std::make_shared<TestOp>(); auto op2 = std::make_shared<TestOp>();
op2->inputs_ = {"y", "w2", "b2"}; op2->inputs_ = {"y", "w2", "b2"};
op2->outputs_ = {"z"}; op2->outputs_ = {"z"};
net->AddOp(op2); net_->AddOp(op2);
net_->CompleteAddOp();
}
virtual void TearDown() {}
net->CompleteAddOp(); void TestOpKernel() {
AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, net->inputs_); AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, net_->inputs_);
AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_); AssertSameVectorWithoutOrder({"y", "z"}, net_->outputs_);
auto tmp_idx_iter = net->attrs_.find("temporary_index"); auto tmp_idx_iter = net_->attrs_.find("temporary_index");
ASSERT_NE(net->attrs_.end(), tmp_idx_iter); ASSERT_NE(net_->attrs_.end(), tmp_idx_iter);
auto& tmp_idx = boost::get<std::vector<int>>(tmp_idx_iter->second); auto& tmp_idx = boost::get<std::vector<int>>(tmp_idx_iter->second);
ASSERT_EQ(1UL, tmp_idx.size()); ASSERT_EQ(1UL, tmp_idx.size());
ASSERT_EQ("y", net->outputs_[tmp_idx[0]]); ASSERT_EQ("y", net_->outputs_[tmp_idx[0]]);
auto scope = std::make_shared<pd::Scope>(); auto scope = std::make_shared<Scope>();
paddle::platform::CPUDeviceContext dev_ctx; platform::CPUDeviceContext dev_ctx;
net->InferShape(scope); net_->InferShape(scope);
net->Run(scope, dev_ctx); net_->Run(scope, dev_ctx);
ASSERT_EQ(2, infer_shape_cnt); ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt); ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), paddle::framework::EnforceNotMet); ASSERT_THROW(net_->AddOp(op2), EnforceNotMet);
}
void TestAddBackwardOp() {
auto grad_ops = AddBackwardOp(net_);
for (auto& op : grad_ops->ops_) {
op->DebugString();
}
}
private:
std::shared_ptr<PlainNet> net_;
};
TEST(OpKernel, all) {
PlainNetTest net;
net->TestOpKernel();
} }
TEST(AddBackwardOp, TestAddBackwardOp) {
PlainNetTest net;
net->TestAddBackwardOp();
}
} // namespace framework
} // namespace paddle
...@@ -13,12 +13,15 @@ ...@@ -13,12 +13,15 @@
limitations under the License. */ limitations under the License. */
#include "paddle/framework/net.h" #include "paddle/framework/net.h"
#include "paddle/framework/fully_connected_op.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class FakeFC : public Operator {}
TEST(AddBackwardOp, ALL)
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -188,8 +189,8 @@ class OpRegistry { ...@@ -188,8 +189,8 @@ class OpRegistry {
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
static void RegisterOp(const std::string& op_type) { static void RegisterOp(const std::string& op_type) {
creators()[op_type] = [] { return new OpType; }; creators()[op_type] = [] { return new OpType; };
OpProto& op_proto = protos()[op_type];
OpAttrChecker& op_checker = op_checkers()[op_type]; OpAttrChecker& op_checker = op_checkers()[op_type];
OpProto& op_proto = protos()[op_type];
ProtoMakerType(&op_proto, &op_checker); ProtoMakerType(&op_proto, &op_checker);
*op_proto.mutable_type() = op_type; *op_proto.mutable_type() = op_type;
PADDLE_ENFORCE( PADDLE_ENFORCE(
...@@ -198,6 +199,11 @@ class OpRegistry { ...@@ -198,6 +199,11 @@ class OpRegistry {
op_type, op_proto.InitializationErrorString()); op_type, op_proto.InitializationErrorString());
} }
template <typename OpType>
static void RegisterGradOp(const std::string& op_type) {
grad_creators()[op_type] = [] { return new OpType; };
}
static OperatorPtr CreateOp(const OpDesc& op_desc) { static OperatorPtr CreateOp(const OpDesc& op_desc) {
std::string op_type = op_desc.type(); std::string op_type = op_desc.type();
OperatorPtr op(creators().at(op_type)()); OperatorPtr op(creators().at(op_type)());
...@@ -216,6 +222,21 @@ class OpRegistry { ...@@ -216,6 +222,21 @@ class OpRegistry {
return op; return op;
} }
static OperatorPtr CreateGradOp(std::shared_ptr<OperatorBase> op) {
OperatorPtr op_grad(grad_creators().at(op->type_)());
op_grad->type_ = op->type_;
op_grad->inputs_.reserve(op->inputs_.size());
for (auto& input : op->inputs_) {
op_grad->inputs_.emplace_back(input);
op_grad->outputs_.emplace_back(input + "@grad");
}
for (auto& output : op->outputs_) {
op_grad->inputs_.emplace_back(output);
op_grad->inputs_.emplace_back(output + "@grad");
}
return op_grad;
}
static std::unordered_map<std::string, OpProto>& protos() { static std::unordered_map<std::string, OpProto>& protos() {
static std::unordered_map<std::string, OpProto> protos_; static std::unordered_map<std::string, OpProto> protos_;
return protos_; return protos_;
...@@ -231,6 +252,11 @@ class OpRegistry { ...@@ -231,6 +252,11 @@ class OpRegistry {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_; static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_; return op_checkers_;
}; };
static std::unordered_map<std::string, OpCreator>& grad_creators() {
static std::unordered_map<std::string, OpCreator> grad_creators_;
return grad_creators_;
}
}; };
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
...@@ -241,6 +267,14 @@ class OpRegisterHelper { ...@@ -241,6 +267,14 @@ class OpRegisterHelper {
} }
}; };
template <typename OpType>
class GradOpRegisterHelper {
public:
GradOpRegisterHelper(const char* op_type) {
OpRegistry::RegisterGradOp<OpType>(op_type);
}
};
/** /**
* check if MACRO is used in GLOBAL NAMESPACE. * check if MACRO is used in GLOBAL NAMESPACE.
*/ */
...@@ -260,6 +294,17 @@ class OpRegisterHelper { ...@@ -260,6 +294,17 @@ class OpRegisterHelper {
__op_register_##__op_type##__(#__op_type); \ __op_register_##__op_type##__(#__op_type); \
int __op_register_##__op_type##_handle__() { return 0; } int __op_register_##__op_type##_handle__() { return 0; }
/**
* Macro to Register Operator.
*/
#define REGISTER_GRADIENT_OP(__op_type, __op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##__op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__op_class> \
__op_register_##__op_type##__(#__op_type); \
int __op_register_##__op_type##_handle__() { return 0; }
/** /**
* Macro to Register OperatorKernel. * Macro to Register OperatorKernel.
*/ */
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册