Commit 0e337bec authored by fengjiayi

Merge branch 'feature/backward' of https://github.com/reyoung/Paddle into feature/backward

@@ -50,50 +50,72 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
  return net_op;
}
/**
 * @brief Backward an operator, implementation
 * @param forwardOp the forward operator
 * @param no_grad_names variable names whose gradients are not calculated,
 * e.g. X@GRAD is not needed.
 * @param uniq_id a unique index used inside BackwardImpl; it is shared
 * through recursive invocations.
 * @return The backward operator. In the simple case it is a plain operator;
 * in the complex case it is a NetOp.
 *
 * See Backward.h for details
 */
static std::shared_ptr<OperatorBase> BackwardImpl(
    const OperatorBase& forwardOp,
    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
  /**
   * If none of the input gradients of the forwarding operator needs to be
   * calculated, just return an EmptyOp. We do not return a null pointer
   * because an EmptyOp costs almost nothing to run, and it keeps the
   * caller's logic simple.
   */
  if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
               no_grad_names)) {
    return EmptyOp();
  }
  /**
   * If none of the output gradients of the forwarding operator needs to be
   * calculated, then none of its input gradients can be computed either; put
   * them all into the `no_grad_names` set and return an EmptyOp.
   */
  if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
               no_grad_names)) {
    for (auto& name : forwardOp.inputs_) {
      /// Mark all inputs as not needing a gradient.
      no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
    }
    return EmptyOp();
  }
  //! Returned gradient network
  auto net = std::make_shared<NetOp>();

  if (forwardOp.IsNetOp()) {
    /// Because forwardOp is a net op, it can be static_cast to NetOp.
    auto& forwardNet = static_cast<const NetOp&>(forwardOp);

    //! Map from an output gradient variable name to the indices, in the
    //! backward net, of the operators that generate it.
    std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
    size_t local_op_id = 0;

    /// Traverse forwardNet in reverse order.
    for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
         ++it, ++local_op_id) {
      auto fwd = *it;
      auto bwd = BackwardImpl(*fwd, no_grad_names, uniq_id);
      net->AddOp(bwd);
      for (auto& out : bwd->outputs_) {
        dup_output_ops[out].emplace_back(local_op_id);
      }
    }
    /// Get unique ID for this method.
    auto uid = uniq_id++;
    // TODO(dzh): more comment
    using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
    std::list<Pos> insert_position;
    for (auto& dup_output_op : dup_output_ops) {
      const std::string& name = dup_output_op.first;
      auto& dup_op = dup_output_op.second;
@@ -106,16 +128,18 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
                              std::to_string(i));
        net->ops_[op_offset]->Rename(name, dup_outputs.back());
      }
      insert_position.push_back(
          {dup_op.back(),
           OpRegistry::CreateOp(
               "add", {dup_outputs}, {name},
               {{"input_format",
                 std::vector<int>{0, (int)dup_outputs.size()}}})});
    }

    insert_position.sort(
        [](const Pos& l, const Pos& r) { return l.first > r.first; });

    for (auto& pos : insert_position) {
      net->InsertOp(pos.first, pos.second);
    }
@@ -148,6 +172,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
  return net;
}

//! See header for comments
extern std::shared_ptr<OperatorBase> Backward(
    const OperatorBase& forwardOp,
    const std::unordered_set<std::string>& no_grad_vars) {
...
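The duplicate-output handling above (the `dup_output_ops` map, the renames, and the inserted "add" op) is only lightly commented in the commit ("TODO(dzh): more comment"). The following is a minimal, self-contained editorial sketch of that bookkeeping using only the standard library; `SketchOp`, the `@RENAME@` pattern, and the printed schedule are illustrative stand-ins, not Paddle's actual classes or naming scheme.

```cpp
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a backward operator: just a type name and the
// variables it writes.
struct SketchOp {
  std::string type;
  std::vector<std::string> outputs;
};

int main() {
  // Backward ops in net order; two of them write "W@GRAD" (a shared weight).
  std::vector<SketchOp> net = {{"mul_grad", {"W@GRAD", "X@GRAD"}},
                               {"rowwise_add_grad", {"b@GRAD"}},
                               {"mul_grad", {"W@GRAD", "Y@GRAD"}}};

  // 1. Record which ops write each gradient variable (cf. dup_output_ops).
  std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;
  for (size_t i = 0; i < net.size(); ++i) {
    for (const auto& out : net[i].outputs) dup_output_ops[out].push_back(i);
  }

  // 2. For variables with more than one writer, rename each write to a unique
  //    temporary and schedule one summation back into the original name.
  //    Applying the insertions from the largest offset downward keeps the
  //    earlier offsets valid (cf. the descending sort of insert_position).
  size_t uid = 0;
  for (auto& kv : dup_output_ops) {
    auto& writers = kv.second;
    if (writers.size() <= 1) continue;
    for (size_t i = 0; i < writers.size(); ++i) {
      std::string renamed =
          kv.first + "@RENAME@" + std::to_string(uid) + "@" + std::to_string(i);
      for (auto& out : net[writers[i]].outputs) {
        if (out == kv.first) out = renamed;
      }
    }
    std::printf("schedule add(%zu inputs) -> %s at offset %zu\n",
                writers.size(), kv.first.c_str(), writers.back());
    ++uid;
  }
  return 0;
}
```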
@@ -13,6 +13,7 @@
   limitations under the License. */
#include "paddle/framework/backward.h"
#include <gtest/gtest.h>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
@@ -142,6 +143,7 @@ REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_GRADIENT_OP(fc, fc_grad, f::EmptyOp);
REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
@@ -160,6 +162,18 @@ TEST(Backward, simple_op_grad) {
  // LOG(INFO) << gop->Output("X" + "@GRAD");
}
TEST(Backward, simple_op_not_need_grad) {
  auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
  ASSERT_NE(fwd, nullptr);
  auto gop = f::Backward(*fwd, {"X"});
  ASSERT_EQ(std::find(gop->outputs_.begin(), gop->outputs_.end(),
                      "X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
            gop->outputs_.end());

  auto no_input_gop = f::Backward(*fwd, {"X", "b"});
  ASSERT_NE(no_input_gop, nullptr);
}
TEST(Backward, net_fc_backward_normal) {
  std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
      "fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {});
@@ -217,6 +231,8 @@ TEST(Backward, net_input_of_network_not_need_grad) {
      bwd_net->outputs_.begin(), bwd_net->outputs_.end());
  all_output.erase(f::OperatorBase::EMPTY_VAR_NAME());
  LOG(INFO) << bwd_net->DebugString();
  LOG(INFO) << bwd_net->ops_.size();

  for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
    ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()),
              all_output.end());
@@ -230,9 +246,9 @@ TEST(Backward, net_input_of_network_not_need_grad) {
  ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
  auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
  ASSERT_EQ(3UL, first_fc_grad->ops_.size());
  LOG(INFO) << first_fc_grad->DebugString();
  ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(),
            first_fc_grad->ops_[2]->Output(
                "X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
}
TEST(Backward, net_shared_weight) {
@@ -245,13 +261,14 @@ TEST(Backward, net_shared_weight) {
  ASSERT_TRUE(bwd->IsNetOp());
  auto bwd_net = static_cast<f::NetOp *>(bwd.get());
  ASSERT_EQ(3UL, bwd_net->ops_.size());
  LOG(INFO) << bwd_net->DebugString();
  ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
}
TEST(Backward, op_register_grad_not_for_network) {
  auto fwd =
      f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
                              {{"temporary_index", std::vector<int>{1}}});
  ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
}
@@ -299,9 +316,11 @@ TEST(Backward, op_part_of_output_are_not_need) {
TEST(Backward, op_part_of_input_are_not_need) {
  auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
  auto backward = f::Backward(*fwd, {"a"});
  ASSERT_TRUE(backward->IsNetOp());
  auto net = static_cast<f::NetOp *>(backward.get());
  ASSERT_EQ(net->ops_.size(), 1UL);

  auto &grad_mul = *net->ops_[0];
  ASSERT_EQ(grad_mul.type_, "mul_grad");
  ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
  ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
@@ -324,11 +343,11 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
                                    {"mul_out2", "tmp_out2", "out2"}, {}));
  net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"},
                                    {"mul_out3", "tmp_out3", "out3"}, {}));
  net.CompleteAddOp(false);
  auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
  ASSERT_TRUE(backward->IsNetOp());
  auto bwd_net = static_cast<f::NetOp *>(backward.get());
  ASSERT_EQ(bwd_net->ops_.size(), 3UL);

  auto &grad_fc = *bwd_net->ops_[0];
  ASSERT_EQ(grad_fc.type_, "fc_grad");
...
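For readers following the two early-return branches near the top of BackwardImpl: the AllInSet helper they rely on is called in this diff but not shown. Below is a minimal, self-contained editorial sketch of the intended check, assuming it is a plain suffix-and-membership test over `no_grad_names`; the name `AllInSetSketch` and the example variables are illustrative, not Paddle's own code.

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Sketch of the check: every name, with the gradient suffix appended, must
// already be present in the given set.
static bool AllInSetSketch(const std::vector<std::string>& names,
                           const std::string& suffix,
                           const std::unordered_set<std::string>& set) {
  for (const auto& name : names) {
    if (set.count(name + suffix) == 0) return false;
  }
  return true;
}

int main() {
  std::unordered_set<std::string> no_grad_names{"X@GRAD", "b@GRAD"};
  // Both inputs of rowwise_add(X, b) are marked no-grad, so BackwardImpl's
  // first branch would return an EmptyOp instead of a real gradient op.
  return AllInSetSketch({"X", "b"}, "@GRAD", no_grad_names) ? 0 : 1;
}
```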