Commit 02cde244 authored by fengjiayi, committed by GitHub

Merge pull request #2949 from dzhwinter/backward

Backward
@@ -49,6 +49,7 @@ message AttrProto {
message VarProto {
  required string name = 1;
  required string comment = 2;
+  required bool is_tensor = 3;
};
message OpProto {
...
@@ -19,8 +19,10 @@ cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
-cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
-cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
+cc_library(grad_op_creator SRCS grad_op_creator.cc DEPS op_proto operator)
+cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_creator)
+cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
+cc_test(grad_op_creator_test SRCS grad_op_creator_test.cc DEPS grad_op_creator op_registry add_op)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
@@ -28,5 +30,6 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
add_dependencies(framework_py_proto framework_py_proto_init)
proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
+# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
-cc_test(net_op_test SRCS net_op_test.cc DEPS net)
+cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
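// Create() builds the gradient operator for op_ in three steps:
//   1. BuildOpInOutArgList() records every input/output slot of the forward op,
//      its index range, and whether it must be forwarded to the gradient op
//      (needed_in_grad_).
//   2. The gradient operator object itself comes from the creator registered
//      via REGISTER_GRADIENT_OP, looked up in OpRegistry::grad_creators().
//   3. CompleteGradOp() fills the gradient op's inputs_, outputs_, attrs_ and
//      in_out_idxs_ from the collected argument list.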
OperatorBase* GradOpCreator::Create() {
BuildOpInOutArgList();
OperatorBase* grad_op = OpRegistry::grad_creators().at(op_->type_)();
CompleteGradOp(grad_op);
return grad_op;
}
OpInOutArg* GradOpCreator::BuildArg(const VarProto& var,
const VarIndexMap& var_map,
const std::vector<int>& format,
InOutType type) {
int idx = var_map.at(var.name());
int begin_idx = format.empty() ? idx : format.at(idx);
int end_idx = format.empty() ? idx + 1 : format.at(idx + 1);
return new OpInOutArg(var.name(), type, !var.ignore_gradient(), begin_idx,
end_idx);
}
void GradOpCreator::BuildOpInOutArgList() {
const OpProto& op_proto = OpRegistry::protos().at(op_->type_);
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_));
const std::vector<int>& in_format =
op_->attrs_.count("input_format")
? op_->GetAttr<std::vector<int>>("input_format")
: std::vector<int>();
const std::vector<int>& out_format =
op_->attrs_.count("output_format")
? op_->GetAttr<std::vector<int>>("output_format")
: std::vector<int>();
for (const auto& var : op_proto.inputs()) {
arg_list_.emplace_back(
std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, in_format, IN)));
}
for (const auto& var : op_proto.outputs()) {
arg_list_.emplace_back(
std::shared_ptr<OpInOutArg>(BuildArg(var, var_map, out_format, OUT)));
}
}
void GradOpCreator::AddArgIntoGradOp(const OpInOutArg* arg,
std::vector<std::string>& in_out,
std::vector<int>& format,
VarIndexMap* varmap, int& idx,
bool is_grad) const {
std::string var_name = arg->proto_name_;
if (is_grad) {
var_name += OperatorBase::GRAD_VAR_SUFFIX();
}
(*varmap)[var_name] = idx++;
size_t pre_sz = in_out.size();
auto base_it =
arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin();
std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_,
std::back_inserter(in_out));
if (is_grad) {
for (size_t i = pre_sz; i < in_out.size(); ++i) {
in_out[i] += OperatorBase::GRAD_VAR_SUFFIX();
}
}
format.push_back(in_out.size());
}
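// CompleteGradOp() lays out the gradient op's argument lists:
//   inputs_  : every forward input/output with needed_in_grad_ set, plus the
//              gradients of the forward outputs (names suffixed with "@GRAD");
//   outputs_ : the gradients of the forward inputs (names suffixed with "@GRAD").
// For add_two this yields inputs {x, y, out, out@GRAD} and outputs
// {x@GRAD, y@GRAD}, as checked in grad_op_creator_test.cc below.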
void GradOpCreator::CompleteGradOp(OperatorBase* grad_op) const {
grad_op->type_ = op_->type_ + "@GRAD"; // not necessary
grad_op->attrs_ = op_->attrs_;
grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format");
VarIndexMap* grad_varmap = new VarIndexMap();
int in_idx = 0;
int out_idx = 0;
std::vector<int> in_format({0});
std::vector<int> out_format({0});
for (const auto& arg : arg_list_) {
// op_'s inputs_ and outputs_
if (arg->needed_in_grad_) {
AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
in_idx, false);
}
if (arg->type_ == IN) {
// gradients of op_'s inputs_
AddArgIntoGradOp(arg.get(), grad_op->outputs_, out_format, grad_varmap,
out_idx, true);
} else {
// gradients of op_'s outputs_
AddArgIntoGradOp(arg.get(), grad_op->inputs_, in_format, grad_varmap,
in_idx, true);
}
}
grad_op->attrs_["input_format"] = in_format;
grad_op->attrs_["output_format"] = out_format;
grad_op->in_out_idxs_.reset(grad_varmap);
}
} // namespace framework
} // namespace paddle
#pragma once
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
class OpRegistry;
enum InOutType { IN, OUT };
struct OpInOutArg {
OpInOutArg(const std::string& proto_name, const InOutType& type,
bool needed_in_grad, size_t begin_idx, size_t end_idx)
: proto_name_(proto_name),
type_(type),
needed_in_grad_(needed_in_grad),
begin_idx_(begin_idx),
end_idx_(end_idx) {}
std::string proto_name_;
InOutType type_;
bool needed_in_grad_;
size_t begin_idx_;
size_t end_idx_;
};
class GradOpCreator {
using VarIndexMap = std::unordered_map<std::string, int>;
public:
GradOpCreator(const OperatorBase* op) : op_(op) {}
OperatorBase* Create();
private:
OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
const std::vector<int>& format, InOutType type);
void BuildOpInOutArgList();
void AddArgIntoGradOp(const OpInOutArg* arg, std::vector<std::string>& in_out,
std::vector<int>& format, VarIndexMap* varmap, int& idx,
bool is_grad) const;
void CompleteGradOp(OperatorBase* grad_op) const;
const OperatorBase* op_;
std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
};
} // namespace framework
} // namespace paddle
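A note on the index ranges above, inferred from GradOpCreator::BuildArg: when the forward op carries an "input_format"/"output_format" attribute, the slot registered at proto index i covers the variable names in [format[i], format[i+1]); without that attribute every slot covers exactly one name.

// Hypothetical example: with input_format = {0, 1, 3}, the slot at proto index 1
// is a multiple-input slot spanning inputs_[1] and inputs_[2]
// (begin_idx_ = 1, end_idx_ = 3).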
#include "paddle/framework/grad_op_creator.h"
#include <gtest/gtest.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
USE_OP(add_two);
namespace paddle {
namespace framework {
TEST(GradOpCreator, AddTwo) {
std::shared_ptr<OperatorBase> add_op(
OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op);
EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
EXPECT_EQ(grad_add_op->Input("X"), "x");
EXPECT_EQ(grad_add_op->Input("Y"), "y");
EXPECT_EQ(grad_add_op->Input("Out"), "out");
EXPECT_EQ(grad_add_op->Input("Out@GRAD"), "out@GRAD");
EXPECT_EQ(grad_add_op->Output("X@GRAD"), "x@GRAD");
EXPECT_EQ(grad_add_op->Output("Y@GRAD"), "y@GRAD");
}
} // namespace framework
} // namespace paddle
\ No newline at end of file
@@ -15,14 +15,24 @@
*/
#include "paddle/framework/net.h"
+#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
+std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) {
+  auto grad_ops = std::make_shared<PlainNet>();
+  for (auto& op : ForwardOps->ops_) {
+    auto op_grad = OpRegistry::CreateGradOp(op);
+    grad_ops->AddOp(op_grad);
+  }
+  grad_ops->CompleteAddOp();
+  return grad_ops;
+}
void PlainNet::CompleteAddOp(bool calc) {
  add_op_done_ = true;
  if (!calc) return;
  std::unordered_set<std::string> input_set;
  std::unordered_set<std::string> output_set;
  std::unordered_set<std::string> temp_output;
...
@@ -100,5 +100,7 @@ class PlainNet : public Net {
  }
};
+std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);
} // namespace framework
} // namespace paddle
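A minimal usage sketch of the new AddBackwardOp (it mirrors TEST(AddBackwardOp, TestGradOp) further below; the op and variable names are illustrative, and the involved gradient ops must have been registered with REGISTER_GRADIENT_OP):

  auto net = std::make_shared<paddle::framework::PlainNet>();
  net->AddOp(paddle::framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
  // Returns a new PlainNet that holds one gradient op per forward op.
  auto grad_net = paddle::framework::AddBackwardOp(net);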
@@ -3,17 +3,24 @@
#include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>
-namespace pd = paddle::framework;
+USE_OP(add_two);
+USE_OP(mul);
+USE_OP(sigmoid);
+USE_OP(softmax);
+namespace paddle {
+namespace framework {
static int infer_shape_cnt = 0;
static int run_cnt = 0;
-class TestOp : public pd::OperatorBase {
+class TestOp : public OperatorBase {
 public:
-  void InferShape(const std::shared_ptr<pd::Scope>& scope) const override {
+  void InferShape(
+      const std::shared_ptr<framework::Scope>& scope) const override {
    ++infer_shape_cnt;
  }
-  void Run(const std::shared_ptr<pd::Scope>& scope,
+  void Run(const std::shared_ptr<framework::Scope>& scope,
           const paddle::platform::DeviceContext& dev_ctx) const override {
    ++run_cnt;
  }
@@ -33,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
}
TEST(OpKernel, all) {
-  auto net = std::make_shared<paddle::framework::PlainNet>();
+  auto net = std::make_shared<PlainNet>();
  ASSERT_NE(net, nullptr);
  auto op1 = std::make_shared<TestOp>();
@@ -55,13 +62,37 @@ TEST(OpKernel, all) {
  ASSERT_EQ(1UL, tmp_idx.size());
  ASSERT_EQ("y", net->outputs_[tmp_idx[0]]);
-  auto scope = std::make_shared<pd::Scope>();
-  paddle::platform::CPUDeviceContext dev_ctx;
+  auto scope = std::make_shared<Scope>();
+  platform::CPUDeviceContext dev_ctx;
  net->InferShape(scope);
  net->Run(scope, dev_ctx);
  ASSERT_EQ(2, infer_shape_cnt);
  ASSERT_EQ(2, run_cnt);
  ASSERT_THROW(net->AddOp(op2), std::runtime_error);
}
+TEST(AddBackwardOp, TestGradOp) {
+  auto net = std::make_shared<PlainNet>();
+  ASSERT_NE(net, nullptr);
+  net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
+  net->AddOp(
+      framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
+  net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
+  auto grad_ops = AddBackwardOp(net);
+  for (auto& op : grad_ops->ops_) {
+    op->DebugString();
+  }
+}
+// TODO(zhihong): add fc grad without registering.
+// TEST(AddBackwardOp, TestNoGradOp) {
+//   auto net = std::make_shared<PlainNet>();
+//   ASSERT_NE(net, nullptr);
+//   net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"},
+//   {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) {
+//     op->DebugString();
+//   }
+// }
+} // namespace framework
+} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
class FakeFC : public Operator {}
} // namespace framework
} // namespace paddle
@@ -84,6 +84,11 @@ message VarProto {
  // "temporary_index": [1]
  // }
  optional bool temporary = 4 [default=false];
+  // The gradient of operator can be ignored immediately
+  // e.g. operator AddOp, y = x1 + x2, the gradient of dy/dx1, dy/dx2
+  // can be ignored for the future optimized on graph.
+  optional bool ignore_gradient = 6;
}
// Op protocol message for 3rd-party language binding.
@@ -105,4 +110,5 @@ message OpProto {
  // The type of that Op.
  required string type = 5;
}
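For illustration, an operator maker could flag such a variable when describing its op (a hypothetical snippet inside an OpProtoAndCheckerMaker constructor, using the AddInput overload extended later in this diff; the variable name and comment text are made up):

  AddInput("b", "the bias of this operator", /*multiple=*/false, /*ignore_gradient=*/true);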
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
@@ -6,9 +20,9 @@
#include <unordered_map>
#include <unordered_set>
#include "paddle/framework/attr_checker.h"
+#include "paddle/framework/grad_op_creator.h"
#include "paddle/framework/op_desc.pb.h"
-#include "paddle/framework/op_proto.pb.h"
-#include "paddle/framework/operator.h"
+#include "paddle/framework/scope.h"
namespace paddle {
namespace framework {
@@ -73,25 +87,29 @@ class OpProtoAndCheckerMaker {
 protected:
  void AddInput(const std::string& name, const std::string& comment,
-               bool multiple = false) {
+               bool multiple = false, bool ignore_gradient = false) {
    auto input = proto_->mutable_inputs()->Add();
    *input->mutable_name() = name;
    *input->mutable_comment() = comment;
+   input->set_ignore_gradient(ignore_gradient);
    input->set_multiple(multiple);
    if (multiple) {
      SetHasMultipleInput();
    }
  }
-  void AddInputs(const std::string& name, const std::string& comment) {
-    AddInput(name, comment, true);
+  void AddInputs(const std::string& name, const std::string& comment,
+                 bool ignore_gradient = false) {
+    AddInput(name, comment, true, ignore_gradient);
  }
  void AddOutput(const std::string& name, const std::string& comment,
-                bool temporary = false, bool multiple = false) {
+                bool temporary = false, bool multiple = false,
+                bool ignore_gradient = false) {
    auto output = proto_->mutable_outputs()->Add();
    *output->mutable_name() = name;
    *output->mutable_comment() = comment;
+   output->set_ignore_gradient(ignore_gradient);
    output->set_multiple(multiple);
    if (multiple) {
      SetHasMultipleOutput();
@@ -103,8 +121,8 @@ class OpProtoAndCheckerMaker {
  }
  void AddOutputs(const std::string& name, const std::string& comment,
-                  bool temporary = false) {
-    AddOutput(name, comment, temporary, true);
+                  bool temporary = false, bool ignore_gradient = false) {
+    AddOutput(name, comment, temporary, true, ignore_gradient);
  }
  template <typename T>
@@ -205,8 +223,8 @@ class OpRegistry {
  template <typename OpType, typename ProtoMakerType>
  static void RegisterOp(const std::string& op_type) {
    creators()[op_type] = [] { return new OpType; };
-   OpProto& op_proto = protos()[op_type];
    OpAttrChecker& op_checker = op_checkers()[op_type];
+   OpProto& op_proto = protos()[op_type];
    auto maker = ProtoMakerType(&op_proto, &op_checker);
    maker.Validate();
    *op_proto.mutable_type() = op_type;
@@ -227,18 +245,24 @@ class OpRegistry {
    }
  }
+  template <typename OpType>
+  static void RegisterGradOp(const std::string& op_type) {
+    grad_creators()[op_type] = [] { return new OpType; };
+  }
  static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
                                                const VarNameList& inputs,
                                                const VarNameList& outputs,
                                                const AttributeMap& attrs) {
    auto op_create_it = creators().find(type);
    PADDLE_ENFORCE(op_create_it != creators().end(),
-                  "Operator %s cannot be found", type);
+                  "Operator %s cannot be found.", type);
    auto op = op_create_it->second();
    op->type_ = type;
    op->inputs_ = inputs;
    op->outputs_ = outputs;
    op->attrs_ = attrs;
    op_checkers().at(type).Check(op->attrs_);
@@ -274,18 +298,41 @@ class OpRegistry {
    return CreateOp(op_desc.type(), inputs, outputs, attrs);
  }
+  static std::shared_ptr<OperatorBase> CreateGradOp(
+      std::shared_ptr<OperatorBase> op) {
+    GradOpCreator creator(op.get());
+    std::shared_ptr<OperatorBase> grad_op(creator.Create());
+    grad_op->Init();
+    return grad_op;
+  }
  static std::unordered_map<std::string, OpProto>& protos() {
    static std::unordered_map<std::string, OpProto> protos_;
    return protos_;
  };
- private:
+  static std::unordered_map<std::string, OpCreator>& grad_creators() {
+    static std::unordered_map<std::string, OpCreator> grad_creators_;
+    return grad_creators_;
+  }
  static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
  VarIndexMaps() {
    static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
    return maps_;
  }
+ private:
+  static std::unordered_map<std::string, OpCreator>& creators() {
+    static std::unordered_map<std::string, OpCreator> creators_;
+    return creators_;
+  }
+  static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
+    static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
+    return op_checkers_;
+  };
  static void GenerateTempVariableName(OperatorBase* op) {
    static std::atomic<size_t> gUniqId(0UL);
    for (auto& outname : op->outputs_) {
@@ -296,16 +343,6 @@ class OpRegistry {
      }
    }
  }
-  static std::unordered_map<std::string, OpCreator>& creators() {
-    static std::unordered_map<std::string, OpCreator> creators_;
-    return creators_;
-  }
-  static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
-    static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
-    return op_checkers_;
-  };
};
template <typename OpType, typename ProtoMakerType>
@@ -316,6 +353,14 @@ class OpRegisterHelper {
  }
};
+template <typename OpType>
+class GradOpRegisterHelper {
+ public:
+  GradOpRegisterHelper(const char* op_type) {
+    OpRegistry::RegisterGradOp<OpType>(op_type);
+  }
+};
/**
 * check if MACRO is used in GLOBAL NAMESPACE.
 */
@@ -335,6 +380,17 @@ class OpRegisterHelper {
      __op_register_##__op_type##__(#__op_type); \
  int __op_register_##__op_type##_handle__() { return 0; }
+/**
+ * Macro to Register Gradient Operator.
+ */
+#define REGISTER_GRADIENT_OP(__op_type, __op_class) \
+  STATIC_ASSERT_GLOBAL_NAMESPACE( \
+      __reg_gradient_op__##__op_type, \
+      "REGISTER_GRADIENT_OP must be in global namespace"); \
+  static ::paddle::framework::GradOpRegisterHelper<__op_class> \
+      __op_gradient_register_##__op_type##__(#__op_type); \
+  int __op_gradient_register_##__op_type##_handle__() { return 0; }
/**
 * Macro to Register OperatorKernel.
 */
...
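For reference, the new registration macro is used like this for add_two later in this diff:

  REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);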
@@ -62,6 +62,11 @@ class OperatorBase {
  /// but it will be convert to a unique name in scope after OpCreator.
  static std::string TMP_VAR_NAME() { return "@TEMP@"; }
+  /// If a variable's name has a certain suffix, it means that the
+  /// variable is the gradient of another varibale.
+  /// e.g. Variable "x@GRAD" is the gradient of varibale "x".
+  static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }
  virtual ~OperatorBase() {}
  template <typename T>
...
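A small sketch of how the suffix composes gradient variable names (this mirrors what GradOpCreator::AddArgIntoGradOp does when is_grad is true):

  std::string x_grad = "x" + paddle::framework::OperatorBase::GRAD_VAR_SUFFIX();  // "x@GRAD"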
@@ -49,9 +49,22 @@ The equation is: Out = X + Y
)DOC");
  }
};
+class AddOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "AddOpGrad";
+    return "";
+  }
+};
} // namespace operators
} // namespace paddle
REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
+REGISTER_GRADIENT_OP(add_two, paddle::operators::AddOpGrad);
REGISTER_OP_CPU_KERNEL(
    add_two, paddle::operators::AddKernel<paddle::platform::CPUPlace, float>);
@@ -16,8 +16,13 @@ limitations under the License. */
#define private public
#include <paddle/framework/op_registry.h>
USE_OP(add_two);
+// USE_OP(add_two_grad);
TEST(AddOp, GetOpProto) {
  auto& protos = paddle::framework::OpRegistry::protos();
  auto it = protos.find("add_two");
  ASSERT_NE(it, protos.end());
+  auto& grad_creators = paddle::framework::OpRegistry::grad_creators();
+  auto it1 = grad_creators.find("add_two");
+  ASSERT_NE(it1, grad_creators.end());
}
@@ -52,9 +52,22 @@ The equation is: Out = X * Y
  }
};
+class MulOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "MulGrad";
+    return "";
+  }
+};
} // namespace operators
} // namespace paddle
REGISTER_OP(mul, paddle::operators::MulOp, paddle::operators::MulOpMaker);
+REGISTER_GRADIENT_OP(mul, paddle::operators::MulOpGrad);
REGISTER_OP_CPU_KERNEL(
    mul, paddle::operators::MulKernel<paddle::platform::CPUPlace, float>);
@@ -39,12 +39,25 @@ public:
  }
};
+class SigmoidOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "SigmoidGrad";
+    return "";
+  }
+};
} // namespace operators
} // namespace paddle
REGISTER_OP(sigmoid,
            paddle::operators::SigmoidOp,
            paddle::operators::SigmoidOpMaker);
+REGISTER_GRADIENT_OP(sigmoid, paddle::operators::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL(
    sigmoid,
    paddle::operators::SigmoidKernel<paddle::platform::CPUPlace, float>);
@@ -42,11 +42,23 @@ public:
  }
};
+class SoftmaxOpGrad : public framework::OperatorWithKernel {
+protected:
+  void InferShape(
+      const std::vector<const framework::Tensor *> &inputs,
+      const std::vector<framework::Tensor *> &outputs) const override {}
+  std::string DebugString() const override {
+    LOG(INFO) << "SoftmaxOpGrad";
+    return "";
+  }
+};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker);
+REGISTER_GRADIENT_OP(softmax, paddle::operators::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax,
                       ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);