提交 f77c63b8 编写于 作者: F fengjiayi

Merge branch 'feature/backward' of https://github.com/reyoung/Paddle into feature/backward

...@@ -33,4 +33,4 @@ cc_library(net SRCS net.cc DEPS op_registry) ...@@ -33,4 +33,4 @@ cc_library(net SRCS net.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op) cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)
cc_library(backward SRCS backward.cc DEPS net) cc_library(backward SRCS backward.cc DEPS net)
cc_test(backward_test SRCS backward_test.cc DEPS net) cc_test(backward_test SRCS backward_test.cc DEPS backward)
...@@ -12,8 +12,11 @@ ...@@ -12,8 +12,11 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/backward.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -24,10 +27,9 @@ class EmptyOp : public OperatorBase { ...@@ -24,10 +27,9 @@ class EmptyOp : public OperatorBase {
const platform::DeviceContext &dev_ctx) const override {} const platform::DeviceContext &dev_ctx) const override {}
}; };
class RowwiseAddOp : public EmptyOp {}; class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
class RowwiseAddOpMaker : public OpProtoAndCheckerMaker {
public: public:
RowwiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input X of Add").IgnoreGradient(); AddInput("X", "Input X of Add").IgnoreGradient();
AddInput("b", "Bias of Add").IgnoreGradient(); AddInput("b", "Bias of Add").IgnoreGradient();
...@@ -36,15 +38,143 @@ class RowwiseAddOpMaker : public OpProtoAndCheckerMaker { ...@@ -36,15 +38,143 @@ class RowwiseAddOpMaker : public OpProtoAndCheckerMaker {
} }
}; };
class RowwiseAddGradOp : public EmptyOp {}; class MulOpMaker : public OpProtoAndCheckerMaker {
public:
MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("A", "A");
AddInput("B", "B");
AddOutput("Out", "Out");
AddComment("Mul");
}
};
class SigmoidOpMaker : public OpProtoAndCheckerMaker {
public:
SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "X");
AddOutput("Y", "Y");
AddComment("Sigmoid");
}
};
class FcOp : public NetOp {
public:
void Init() override {
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
{Output("before_act")}, {}));
auto b_name = Input("b");
if (b_name != EMPTY_VAR_NAME()) {
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("before_act"), b_name},
{Output("before_act")}, {}));
}
AddOp(OpRegistry::CreateOp("sigmoid", {Output("before_act")},
{Output("Out")}, {}));
CompleteAddOp(false);
}
};
class FcOpMaker : public OpProtoAndCheckerMaker {
public:
FcOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x");
AddInput("W", "w");
AddInput("b", "b");
AddOutput("before_act", "before act").SetTemporary();
AddOutput("Out", "");
AddComment("");
}
};
class ManyOutputOpMaker : public OpProtoAndCheckerMaker {
public:
ManyOutputOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("x", "x");
AddOutput("y", "y");
AddOutput("z", "z");
AddComment("");
}
};
class FillZeroOpMaker : public OpProtoAndCheckerMaker {
public:
FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("x", "x");
AddOutput("out", "out");
AddComment("");
}
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
namespace f = paddle::framework; namespace f = paddle::framework;
REGISTER_OP(rowwise_add, f::RowwiseAddOp, f::RowwiseAddOpMaker); using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::RowwiseAddGradOp); REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker);
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp);
REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker);
REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp);
REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp);
REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
TEST(Backward, simple_grad) { TEST(Backward, simple_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr); ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd);
ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
ASSERT_EQ("rowwise_add_grad", gop->type_);
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]);
ASSERT_EQ("b" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[1]);
// LOG(INFO) << gop->Output("X" + "@GRAD");
}
TEST(Backward, not_for_network) {
auto fwd =
f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
{{"temporary_index", std::vector<int>{1}}});
ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
}
TEST(Backward, all_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
auto backward = f::Backward(*fwd, {"X", "b"});
ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_TRUE(net->ops_.empty());
}
TEST(Backward, all_output_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
auto backward = f::Backward(*fwd, {"Out"});
ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_TRUE(net->ops_.empty());
}
TEST(Backward, part_of_output_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
auto backward = f::Backward(*fwd, {"Z"});
ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(net->ops_.size(), 2);
auto &fill_zero = *net->ops_[0];
ASSERT_EQ("fill_zeros_like", fill_zero.type_);
ASSERT_EQ(1, fill_zero.inputs_.size());
ASSERT_EQ("Z", fill_zero.inputs_[0]);
ASSERT_EQ(1, fill_zero.outputs_.size());
ASSERT_EQ("Z@ZERO", fill_zero.outputs_[0]);
auto &d_many_out = *net->ops_[1];
ASSERT_EQ("many_output_op_grad", d_many_out.type_);
ASSERT_EQ(1 + 2 + 2, d_many_out.inputs_.size()); // I/O/OG
ASSERT_EQ("Z@ZERO", d_many_out.Input("z@GRAD"));
} }
\ No newline at end of file
...@@ -20,7 +20,7 @@ namespace framework { ...@@ -20,7 +20,7 @@ namespace framework {
OperatorBase* GradOpBuilder::Build() { OperatorBase* GradOpBuilder::Build() {
BuildOpInOutArgList(); BuildOpInOutArgList();
std::string grad_op_type = OpRegistry::grad_ops().at(op_->type_); std::string grad_op_type = OpRegistry::grad_ops().at(op_.type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(); OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type; grad_op->type_ = grad_op_type;
CompleteGradOp(grad_op); CompleteGradOp(grad_op);
...@@ -39,15 +39,15 @@ OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var, ...@@ -39,15 +39,15 @@ OpInOutArg* GradOpBuilder::BuildArg(const VarProto& var,
} }
void GradOpBuilder::BuildOpInOutArgList() { void GradOpBuilder::BuildOpInOutArgList() {
const OpProto& op_proto = OpRegistry::protos().at(op_->type_); const OpProto& op_proto = OpRegistry::protos().at(op_.type_);
const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_->type_)); const auto& var_map = *(OpRegistry::VarIndexMaps().at(op_.type_));
const std::vector<int>& in_format = const std::vector<int>& in_format =
op_->attrs_.count("input_format") op_.attrs_.count("input_format")
? op_->GetAttr<std::vector<int>>("input_format") ? op_.GetAttr<std::vector<int>>("input_format")
: std::vector<int>(); : std::vector<int>();
const std::vector<int>& out_format = const std::vector<int>& out_format =
op_->attrs_.count("output_format") op_.attrs_.count("output_format")
? op_->GetAttr<std::vector<int>>("output_format") ? op_.GetAttr<std::vector<int>>("output_format")
: std::vector<int>(); : std::vector<int>();
for (const auto& var : op_proto.inputs()) { for (const auto& var : op_proto.inputs()) {
arg_list_.emplace_back( arg_list_.emplace_back(
...@@ -70,8 +70,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg, ...@@ -70,8 +70,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
} }
(*varmap)[var_name] = idx++; (*varmap)[var_name] = idx++;
size_t pre_sz = in_out.size(); size_t pre_sz = in_out.size();
auto base_it = auto base_it = arg->type_ == IN ? op_.inputs_.begin() : op_.outputs_.begin();
arg->type_ == IN ? op_->inputs_.begin() : op_->outputs_.begin();
std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_, std::copy(base_it + arg->begin_idx_, base_it + arg->end_idx_,
std::back_inserter(in_out)); std::back_inserter(in_out));
if (is_grad) { if (is_grad) {
...@@ -83,7 +82,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg, ...@@ -83,7 +82,7 @@ void GradOpBuilder::AddArgIntoGradOp(const OpInOutArg* arg,
} }
void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const { void GradOpBuilder::CompleteGradOp(OperatorBase* grad_op) const {
grad_op->attrs_ = op_->attrs_; grad_op->attrs_ = op_.attrs_;
grad_op->attrs_.erase("input_format"); grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format"); grad_op->attrs_.erase("output_format");
VarIndexMap* grad_varmap = new VarIndexMap(); VarIndexMap* grad_varmap = new VarIndexMap();
......
...@@ -29,7 +29,7 @@ class GradOpBuilder { ...@@ -29,7 +29,7 @@ class GradOpBuilder {
using VarIndexMap = std::unordered_map<std::string, int>; using VarIndexMap = std::unordered_map<std::string, int>;
public: public:
GradOpBuilder(const OperatorBase* op) : op_(op) {} GradOpBuilder(const OperatorBase& op) : op_(op) {}
OperatorBase* Build(); OperatorBase* Build();
private: private:
...@@ -40,7 +40,7 @@ class GradOpBuilder { ...@@ -40,7 +40,7 @@ class GradOpBuilder {
std::vector<int>& format, VarIndexMap* varmap, int& idx, std::vector<int>& format, VarIndexMap* varmap, int& idx,
bool is_grad) const; bool is_grad) const;
void CompleteGradOp(OperatorBase* grad_op) const; void CompleteGradOp(OperatorBase* grad_op) const;
const OperatorBase* op_; const OperatorBase& op_;
std::vector<std::shared_ptr<OpInOutArg>> arg_list_; std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
}; };
......
...@@ -11,7 +11,7 @@ namespace framework { ...@@ -11,7 +11,7 @@ namespace framework {
TEST(GradOpBuilder, AddTwo) { TEST(GradOpBuilder, AddTwo) {
std::shared_ptr<OperatorBase> add_op( std::shared_ptr<OperatorBase> add_op(
OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {})); OpRegistry::CreateOp("add_two", {"x", "y"}, {"out"}, {}));
std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(add_op); std::shared_ptr<OperatorBase> grad_add_op = OpRegistry::CreateGradOp(*add_op);
EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4); EXPECT_EQ(static_cast<int>(grad_add_op->inputs_.size()), 4);
EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2); EXPECT_EQ(static_cast<int>(grad_add_op->outputs_.size()), 2);
EXPECT_EQ(grad_add_op->Input("X"), "x"); EXPECT_EQ(grad_add_op->Input("X"), "x");
......
...@@ -303,11 +303,10 @@ class OpRegistry { ...@@ -303,11 +303,10 @@ class OpRegistry {
return CreateOp(op_desc.type(), inputs, outputs, attrs); return CreateOp(op_desc.type(), inputs, outputs, attrs);
} }
static std::shared_ptr<OperatorBase> CreateGradOp( static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) {
std::shared_ptr<OperatorBase> op) { PADDLE_ENFORCE(!op.IsNetOp(),
PADDLE_ENFORCE(!op->IsNetOp(),
"Use framework::Backward to get backward ops"); "Use framework::Backward to get backward ops");
GradOpBuilder builder(op.get()); GradOpBuilder builder(op);
std::shared_ptr<OperatorBase> grad_op(builder.Build()); std::shared_ptr<OperatorBase> grad_op(builder.Build());
grad_op->Init(); grad_op->Init();
return grad_op; return grad_op;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册