提交 d0b25ac9 编写于 作者: Y Yu Yang

Fix some unittest error

上级 8bf0ca0f
......@@ -72,7 +72,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
return EmptyOp();
}
auto* net = new NetOp();
auto net = std::make_shared<NetOp>();
if (forwardOp.IsNetOp()) {
//! TODO(dzh)
......@@ -84,7 +84,8 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
auto& forwardNet = static_cast<const NetOp&>(forwardOp);
// travesal subnet/op
for (auto it = forwardNet.ops_.end(); it != forwardNet.ops_.begin(); --it) {
for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
++it) {
auto fwd = *it;
// for (auto& fwd : forwardNet.ops_) {
// auto bwd = Backward(*fwd, no_grad_names);
......@@ -115,7 +116,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
insert_postion.push_back(
{dup_op.back(),
OpRegistry::CreateOp(
"Add", {dup_outputs}, {name},
"add", {dup_outputs}, {name},
{{"input_format",
std::vector<int>{0, (int)dup_outputs.size()}}})});
}
......@@ -142,11 +143,15 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
grad_output = OperatorBase::EMPTY_VAR_NAME();
}
}
if (net->ops_.empty()) { // Current no aux op is added to network
return grad_op;
}
net->AddOp(grad_op);
}
net->CompleteAddOp();
return std::shared_ptr<OperatorBase>(net);
return net;
}
extern std::shared_ptr<OperatorBase> Backward(
......
......@@ -63,14 +63,22 @@ class FcOp : public NetOp {
public:
void Init() override {
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
{Output("before_act")}, {}));
{Output("mul_result")}, {}));
auto b_name = Input("b");
std::string before_act = "mul_result";
if (b_name != EMPTY_VAR_NAME()) {
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("before_act"), b_name},
{Output("before_act")}, {}));
AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_result"), b_name},
{Output("add_result")}, {}));
before_act = "add_result";
} else {
auto out_varname = Output("add_result");
if (out_varname != EMPTY_VAR_NAME()) {
this->Rename(out_varname, EMPTY_VAR_NAME());
}
AddOp(OpRegistry::CreateOp("sigmoid", {Output("before_act")},
{Output("Out")}, {}));
}
AddOp(OpRegistry::CreateOp("sigmoid", {Output(before_act)}, {Output("Out")},
{}));
CompleteAddOp(false);
}
};
......@@ -82,7 +90,8 @@ class FcOpMaker : public OpProtoAndCheckerMaker {
AddInput("X", "x");
AddInput("W", "w");
AddInput("b", "b");
AddOutput("before_act", "before act").SetTemporary();
AddOutput("mul_result", "").SetTemporary();
AddOutput("add_result", "").SetTemporary();
AddOutput("Out", "");
AddComment("");
}
......@@ -153,7 +162,7 @@ TEST(Backward, simple_op_grad) {
TEST(Backward, net_fc_backward_normal) {
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
"fc", {"X", "w", "b"}, {"out", "tmp_forward"}, {});
"fc", {"X", "w", "b"}, {"mul_result", "add_result", "out"}, {});
ASSERT_NE(fwd, nullptr);
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp());
......@@ -176,7 +185,7 @@ TEST(Backward, net_fc_backward_normal) {
TEST(Backward, net_fc_backward_not_have_b) {
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
"fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()},
{"out", "tmp_forward"}, {});
{"mul_result", "add_result", "tmp"}, {});
ASSERT_NE(fwd, nullptr);
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp());
......@@ -196,9 +205,9 @@ TEST(Backward, net_fc_backward_not_have_b) {
TEST(Backward, net_input_of_network_not_need_grad) {
f::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"},
{"hidden0", "tmp0"}, {}));
{"mul_tmp_0", "add_tmp_0", "hidden0"}, {}));
net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"},
{"hidden1", "tmp1"}, {}));
{"mul_tmp_1", "add_tmp_1", "hidden1"}, {}));
net.CompleteAddOp();
auto bwd = Backward(net, {"X"}); // X@GRAD is not need.
ASSERT_TRUE(bwd->IsNetOp());
......@@ -235,6 +244,7 @@ TEST(Backward, net_shared_weight) {
ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
ASSERT_EQ(3UL, bwd_net->ops_.size());
LOG(INFO) << bwd_net->DebugString();
ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
}
......
......@@ -52,7 +52,7 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr");
auto input_format = GetAttr<std::vector<int>>("input_format");
auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= inputs_.size(),
PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= (int)inputs_.size(),
"Input Out Of Range");
return std::vector<std::string>{
......@@ -78,7 +78,7 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr");
auto output_format = GetAttr<std::vector<int>>("output_format");
auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= outputs_.size(),
PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= (int)outputs_.size(),
"Output Out of Range");
return std::vector<std::string>{
outputs_.begin() + output_format.at(offset),
......
......@@ -101,6 +101,7 @@ class OperatorBase {
//! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const;
//! Get a input which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Inputs(const std::string& name) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册