Commit d0b25ac9 authored by Yu Yang

Fix some unittest error

Parent 8bf0ca0f
@@ -72,7 +72,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
     return EmptyOp();
   }
-  auto* net = new NetOp();
+  auto net = std::make_shared<NetOp>();
   if (forwardOp.IsNetOp()) {
     //! TODO(dzh)
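Note on the `std::make_shared<NetOp>()` change: once the function gains an early-return path (see the hunk at line 143 below), a raw `new NetOp()` would leak on any return that skips the final wrap. Owning the net from the first line removes that hazard. A minimal sketch of the pattern, using stand-in types rather than Paddle's real ones:

    #include <memory>

    struct OperatorBase { virtual ~OperatorBase() = default; };
    struct NetOp : OperatorBase {};

    std::shared_ptr<OperatorBase> Build(bool trivial) {
      auto net = std::make_shared<NetOp>();  // owned from the start
      if (trivial) {
        return nullptr;  // early return: net is released automatically, no leak
      }
      return net;  // shared_ptr<NetOp> converts implicitly to shared_ptr<OperatorBase>
    }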
@@ -84,7 +84,8 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
     auto& forwardNet = static_cast<const NetOp&>(forwardOp);
     // traverse subnet/op
-    for (auto it = forwardNet.ops_.end(); it != forwardNet.ops_.begin(); --it) {
+    for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
+         ++it) {
       auto fwd = *it;
       // for (auto& fwd : forwardNet.ops_) {
       //   auto bwd = Backward(*fwd, no_grad_names);
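Note on the loop fix: the old form starts at `end()` and dereferences it on the first pass, which is undefined behavior, and it also stops before visiting the first op. `rbegin()/rend()` is the standard reverse-traversal idiom. A self-contained illustration:

    #include <iostream>
    #include <vector>

    int main() {
      std::vector<int> ops = {1, 2, 3};
      // Shape of the old, buggy loop: *it reads ops.end() on the first
      // iteration (undefined behavior), and ops.front() is never visited:
      //   for (auto it = ops.end(); it != ops.begin(); --it) { use(*it); }

      // Correct reverse traversal, as in the patch:
      for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
        std::cout << *it << ' ';  // prints: 3 2 1
      }
      std::cout << '\n';
    }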
@@ -115,7 +116,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
         insert_postion.push_back(
             {dup_op.back(),
              OpRegistry::CreateOp(
-                 "Add", {dup_outputs}, {name},
+                 "add", {dup_outputs}, {name},
                  {{"input_format",
                    std::vector<int>{0, (int)dup_outputs.size()}}})});
       }
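Note on `"Add"` → `"add"`: operator types are looked up by exact string, so the key must match the registered (lowercase) name. An illustrative map-based lookup, assumed for exposition and not Paddle's actual registry code:

    #include <functional>
    #include <map>
    #include <stdexcept>
    #include <string>

    std::map<std::string, std::function<void()>> registry = {
        {"add", [] { /* construct the add op */ }},
    };

    void CreateOp(const std::string& type) {
      auto it = registry.find(type);  // exact, case-sensitive match
      if (it == registry.end()) {
        throw std::runtime_error("operator not registered: " + type);
      }
      it->second();
    }
    // CreateOp("add") succeeds; CreateOp("Add") throws.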
@@ -142,11 +143,15 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
         grad_output = OperatorBase::EMPTY_VAR_NAME();
       }
     }
+    if (net->ops_.empty()) {  // Currently no aux op is added to the network
+      return grad_op;
+    }
     net->AddOp(grad_op);
   }
   net->CompleteAddOp();
-  return std::shared_ptr<OperatorBase>(net);
+  return net;
 }

 extern std::shared_ptr<OperatorBase> Backward(
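Note on the early return: when no auxiliary op (such as the gradient-accumulating `add` above) was placed in `net`, wrapping the single gradient op in a one-element NetOp is pure overhead, so the op is returned directly. A self-contained sketch of the pattern, again with stand-in types:

    #include <memory>
    #include <vector>

    struct OperatorBase { virtual ~OperatorBase() = default; };
    struct NetOp : OperatorBase {
      std::vector<std::shared_ptr<OperatorBase>> ops_;
      void AddOp(std::shared_ptr<OperatorBase> op) { ops_.push_back(std::move(op)); }
    };

    std::shared_ptr<OperatorBase> Finish(std::shared_ptr<OperatorBase> grad_op,
                                         std::shared_ptr<NetOp> net) {
      if (net->ops_.empty()) {
        return grad_op;  // no aux ops: skip the NetOp wrapper entirely
      }
      net->AddOp(grad_op);
      return net;  // implicit upcast; no explicit shared_ptr construction needed
    }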
...
@@ -63,14 +63,22 @@ class FcOp : public NetOp {
  public:
   void Init() override {
     AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
-                               {Output("before_act")}, {}));
+                               {Output("mul_result")}, {}));
     auto b_name = Input("b");
+    std::string before_act = "mul_result";
     if (b_name != EMPTY_VAR_NAME()) {
-      AddOp(OpRegistry::CreateOp("rowwise_add", {Output("before_act"), b_name},
-                                 {Output("before_act")}, {}));
+      AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name},
+                                 {Output("add_result")}, {}));
+      before_act = "add_result";
+    } else {
+      auto out_varname = Output("add_result");
+      if (out_varname != EMPTY_VAR_NAME()) {
+        this->Rename(out_varname, EMPTY_VAR_NAME());
+      }
     }
-    AddOp(OpRegistry::CreateOp("sigmoid", {Output("before_act")},
-                               {Output("Out")}, {}));
+
+    AddOp(OpRegistry::CreateOp("sigmoid", {Output(before_act)}, {Output("Out")},
+                               {}));
     CompleteAddOp(false);
   }
 };
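Note on the FcOp rewrite: the test op previously reused one `before_act` buffer for every intermediate; it now gives each stage its own variable and, when `b` is absent, renames the unused `add_result` output to the empty sentinel. The dataflow, reduced to scalars purely for illustration (the real ops operate on tensors):

    #include <cmath>
    #include <cstdio>

    int main() {
      double x = 2.0, w = 0.5, b = 1.0;
      double mul_result = x * w;                         // "mul"
      double add_result = mul_result + b;                // "rowwise_add" (skipped without b)
      double out = 1.0 / (1.0 + std::exp(-add_result));  // "sigmoid"
      std::printf("out = %f\n", out);                    // out = 0.880797
    }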
@@ -82,7 +90,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
     AddInput("X", "x");
     AddInput("W", "w");
     AddInput("b", "b");
-    AddOutput("before_act", "before act").SetTemporary();
+    AddOutput("mul_result", "").SetTemporary();
+    AddOutput("add_result", "").SetTemporary();
     AddOutput("Out", "");
     AddComment("");
   }
@@ -153,7 +162,7 @@ TEST(Backward, simple_op_grad) {
 TEST(Backward, net_fc_backward_normal) {
   std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
-      "fc", {"X", "w", "b"}, {"out", "tmp_forward"}, {});
+      "fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {});
   ASSERT_NE(fwd, nullptr);
   std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
   ASSERT_TRUE(gop->IsNetOp());
@@ -176,7 +185,7 @@ TEST(Backward, net_fc_backward_normal) {
 TEST(Backward, net_fc_backward_not_have_b) {
   std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
       "fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()},
-      {"out", "tmp_forward"}, {});
+      {"mul_result", "add_result", "tmp"}, {});
   ASSERT_NE(fwd, nullptr);
   std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
   ASSERT_TRUE(gop->IsNetOp());
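Note on the no-bias case: passing `EMPTY_VAR_NAME()` for `b` makes `FcOp::Init` skip the `rowwise_add` sub-op and rename the now-unused `add_result` output away, so the generated backward net contains one gradient op less. The `Rename` call's semantics are not shown in this diff; assuming it rewrites every occurrence of a name in the op's name list, a hypothetical helper would look like:

    #include <algorithm>
    #include <string>
    #include <vector>

    // Hypothetical helper mirroring the assumed behavior of Rename():
    void Rename(std::vector<std::string>& names, const std::string& from,
                const std::string& to) {
      std::replace(names.begin(), names.end(), from, to);
    }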
@@ -196,9 +205,9 @@ TEST(Backward, net_fc_backward_not_have_b) {
 TEST(Backward, net_input_of_network_not_need_grad) {
   f::NetOp net;
   net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"},
-                                    {"hidden0", "tmp0"}, {}));
+                                    {"mul_tmp_0", "add_tmp_0", "hidden0"}, {}));
   net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"},
-                                    {"hidden1", "tmp1"}, {}));
+                                    {"mul_tmp_1", "add_tmp_1", "hidden1"}, {}));
   net.CompleteAddOp();
   auto bwd = Backward(net, {"X"});  // X@GRAD is not needed.
   ASSERT_TRUE(bwd->IsNetOp());
@@ -235,6 +244,7 @@ TEST(Backward, net_shared_weight) {
   ASSERT_TRUE(bwd->IsNetOp());
   auto bwd_net = static_cast<f::NetOp *>(bwd.get());
   ASSERT_EQ(3UL, bwd_net->ops_.size());
+  LOG(INFO) << bwd_net->DebugString();
   ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
 }
...
@@ -52,7 +52,7 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
   PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr");
   auto input_format = GetAttr<std::vector<int>>("input_format");
   auto offset = in_out_idxs_->at(name);
-  PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= inputs_.size(),
+  PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= (int)inputs_.size(),
                  "Input Out Of Range");
   return std::vector<std::string>{
@@ -78,7 +78,7 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
   PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr");
   auto output_format = GetAttr<std::vector<int>>("output_format");
   auto offset = in_out_idxs_->at(name);
-  PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= outputs_.size(),
+  PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= (int)outputs_.size(),
                  "Output Out of Range");
   return std::vector<std::string>{
       outputs_.begin() + output_format.at(offset),
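Note on the two `(int)` casts: `input_format`/`output_format` hold `int`s, while `inputs_.size()` is a `size_t`. A mixed comparison converts the signed side to unsigned (and trips `-Wsign-compare`), so a negative format entry would silently compare as a huge value; casting the size keeps the comparison signed. A short demonstration:

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> inputs = {1, 2, 3};
      int fmt = -1;  // a hypothetical corrupt format entry
      std::printf("%d\n", fmt <= inputs.size());       // 0: -1 became a huge size_t
      std::printf("%d\n", fmt <= (int)inputs.size());  // 1: signed comparison
    }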
...
@@ -101,6 +101,7 @@ class OperatorBase {
   //! Get a input with argument's name described in `op_proto`
   const std::string& Input(const std::string& name) const;
   //! Get a input which has multiple variables.
+  //! TODO add a vector_view to prevent memory copy.
   std::vector<std::string> Inputs(const std::string& name) const;
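Note on the new TODO: `Inputs()` copies the selected names into a fresh `std::vector<std::string>` on every call. The suggested `vector_view` would be a non-owning window over the existing `inputs_` storage. A hypothetical C++20 sketch using `std::span` (not available to this codebase, which predates it):

    #include <span>
    #include <string>
    #include <vector>

    // Hypothetical zero-copy variant: view the half-open range [begin, end)
    // of the stored names without allocating or copying strings.
    std::span<const std::string> InputsView(const std::vector<std::string>& inputs,
                                            int begin, int end) {
      return {inputs.data() + begin, static_cast<std::size_t>(end - begin)};
    }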
...