Commit 5713266f authored by D dongzhihong

Merge remote-tracking branch 'reyoung/feature/backward' into feature/backward

@@ -108,6 +108,16 @@ class FillZeroOpMaker : public OpProtoAndCheckerMaker {
AddComment("");
}
};
class AddOpMaker : public OpProtoAndCheckerMaker {
public:
AddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "x").SetMultiple();
AddOutput("Y", "y");
AddComment("");
}
};
} // namespace framework
} // namespace paddle
@@ -123,11 +133,14 @@ REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker);
REGISTER_OP(add, f::EmptyOp, f::AddOpMaker);
REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp);
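// The operators registered above use f::EmptyOp as a stand-in kernel; the
// tests below only check how Backward() wires the gradient graph, not any
// actual computation.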
TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
ASSERT_NE(fwd, nullptr);
auto gop = f::OpRegistry::CreateGradOp(*fwd);
ASSERT_EQ(1UL, gop->inputs_.size());
ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]); ASSERT_EQ("Out" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->inputs_[0]);
ASSERT_EQ("rowwise_add_grad", gop->type_); ASSERT_EQ("rowwise_add_grad", gop->type_);
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]); ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(), gop->outputs_[0]);
@@ -138,14 +151,99 @@ TEST(Backward, simple_grad) {
// LOG(INFO) << gop->Output("X" + "@GRAD");
}
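// fc is a composite op (per the op_library DEPS: mul, rowwise_add, sigmoid),
// so its backward net is expected to hold those ops' grads in reverse order.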
TEST(Backward, net_fc_backward_normal) {
std::shared_ptr<f::OperatorBase> fwd =
f::OpRegistry::CreateOp("fc", {"X", "w", "b"}, {"out"}, {});
ASSERT_NE(fwd, nullptr);
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp());
auto net = static_cast<f::NetOp *>(gop.get());
ASSERT_NO_THROW(net->DebugString());
ASSERT_EQ(3UL, net->ops_.size());
f::OperatorBase &d_sigmoid = *net->ops_[0];
ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
f::OperatorBase &d_add = *net->ops_[1];
ASSERT_EQ("rowwise_add_grad", d_add.type_);
f::OperatorBase &d_mul = *net->ops_[2];
ASSERT_EQ("mul_grad", d_mul.type_);
}
TEST(Backward, net_fc_backward_not_have_b) {
std::shared_ptr<f::OperatorBase> fwd = f::OpRegistry::CreateOp(
"fc", {"X", "w", f::OperatorBase::EMPTY_VAR_NAME()}, {"out"}, {});
ASSERT_NE(fwd, nullptr);
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp());
auto net = static_cast<f::NetOp *>(gop.get());
ASSERT_NO_THROW(net->DebugString());
ASSERT_EQ(2UL, net->ops_.size());
f::OperatorBase &d_sigmoid = *net->ops_[0];
ASSERT_EQ("sigmoid_grad", d_sigmoid.type_);
f::OperatorBase &d_mul = *net->ops_[1];
ASSERT_EQ("mul_grad", d_mul.type_);
}
TEST(Backward, net_input_of_network_not_need_grad) {
f::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"}, {"hidden0"}, {}));
net.AddOp(
f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"}, {"hidden1"}, {}));
net.CompleteAddOp();
auto bwd = f::Backward(net, {"X"}); // X@GRAD is not needed.
ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
std::unordered_set<std::string> all_output = std::unordered_set<std::string>(
bwd_net->outputs_.begin(), bwd_net->outputs_.end());
all_output.erase(f::OperatorBase::EMPTY_VAR_NAME());
for (auto &out : {"W1", "b1", "hidden0", "W2", "b2"}) {
ASSERT_NE(all_output.find(out + f::OperatorBase::GRAD_VAR_SUFFIX()),
all_output.end());
}
// X@GRAD is not generated.
ASSERT_EQ(all_output.find("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
all_output.end());
ASSERT_EQ(2UL, bwd_net->ops_.size());
ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
ASSERT_EQ(3UL, first_fc_grad->ops_.size());
ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(),
first_fc_grad->ops_[2]->Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
}
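// W is shared by both mul ops, so its two partial gradients must be combined;
// the test expects Backward() to append a summing op (add_grad) at the end.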
TEST(Backward, net_shared_weight) {
f::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {}));
net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {}));
net.CompleteAddOp();
auto bwd = f::Backward(net, {});
ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
ASSERT_EQ(3UL, bwd_net->ops_.size());
ASSERT_EQ("add_grad", bwd_net->ops_[2]->type_);
}
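// fc is registered as a network of ops, so requesting a single grad op for it
// from the registry is expected to throw EnforceNotMet.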
TEST(Backward, op_register_grad_not_for_network) {
auto fwd =
f::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Out", "tmp_out"},
{{"temporary_index", std::vector<int>{1}}});
ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
}
TEST(Backward, op_all_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
auto backward = f::Backward(*fwd, {"X", "b"});
ASSERT_TRUE(backward->IsNetOp());
@@ -153,7 +251,7 @@ TEST(Backward, all_input_are_not_need) {
ASSERT_TRUE(net->ops_.empty());
}
TEST(Backward, op_all_output_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
auto backward = f::Backward(*fwd, {"Out"});
ASSERT_TRUE(backward->IsNetOp());
@@ -161,22 +259,78 @@ TEST(Backward, all_output_are_not_need) {
ASSERT_TRUE(net->ops_.empty());
}
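// When only some outputs receive gradients, Backward() is expected to insert
// a fill_zeros_like op that feeds a zero gradient for the ignored output.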
TEST(Backward, op_part_of_output_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
auto backward = f::Backward(*fwd, {"Z"});
ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(net->ops_.size(), 2UL);
auto &fill_zero = *net->ops_[0];
ASSERT_EQ("fill_zeros_like", fill_zero.type_);
ASSERT_EQ(1UL, fill_zero.inputs_.size());
ASSERT_EQ("Z", fill_zero.inputs_[0]);
ASSERT_EQ(1UL, fill_zero.outputs_.size());
ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(), fill_zero.outputs_[0]);
auto &d_many_out = *net->ops_[1];
ASSERT_EQ("many_output_op_grad", d_many_out.type_);
ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG
ASSERT_EQ("Z" + f::OperatorBase::ZERO_VAR_SUFFIX(),
d_many_out.Input("z" + f::OperatorBase::GRAD_VAR_SUFFIX()));
ASSERT_EQ("Y" + f::OperatorBase::GRAD_VAR_SUFFIX(),
d_many_out.Input("y" + f::OperatorBase::GRAD_VAR_SUFFIX()));
ASSERT_EQ("X" + f::OperatorBase::GRAD_VAR_SUFFIX(),
d_many_out.Output("x" + f::OperatorBase::GRAD_VAR_SUFFIX()));
}
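// An input listed in the no-grad set keeps its grad output slot, but the slot
// is expected to be bound to EMPTY_VAR_NAME instead of a real variable.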
TEST(Backward, op_part_of_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
auto backward = f::Backward(*fwd, {"a"});
ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(net->ops_.size(), 1UL);
auto &grad_mul = *net->ops_[0];
ASSERT_EQ(grad_mul.type_, "mul_grad");
ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
ASSERT_EQ(grad_mul.Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()),
f::OperatorBase::EMPTY_VAR_NAME());
ASSERT_EQ(grad_mul.Output("B" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"b" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_mul.Input("A"), "a");
ASSERT_EQ(grad_mul.Input("B"), "b");
ASSERT_EQ(grad_mul.Input("Out"), "out");
}
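// Marking the intermediate variable out2 as no-grad should cut the chain:
// only the last fc gets a grad op, and its X@GRAD slot stays empty.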
TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
f::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"}, {"out1"}, {}));
net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"}, {"out2"}, {}));
net.AddOp(f::OpRegistry::CreateOp("fc", {"out2", "w3", "b3"}, {"out3"}, {}));
net.CompleteAddOp(false);
auto backward = f::Backward(net, {"out2"});
ASSERT_TRUE(backward->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(bwd_net->ops_.size(), 1UL);
auto &grad_fc = *bwd_net->ops_[0];
ASSERT_EQ(grad_fc.type_, "fc_grad");
ASSERT_EQ(grad_fc.inputs_.size(), 3UL + 1UL + 1UL);
ASSERT_EQ(grad_fc.outputs_.size(), 3UL);
ASSERT_EQ(grad_fc.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()),
f::OperatorBase::EMPTY_VAR_NAME());
ASSERT_EQ(grad_fc.Output("W" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"w3" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_fc.Output("b" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"b3" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_fc.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out3" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_fc.Input("X"), "out2");
ASSERT_EQ(grad_fc.Input("W"), "w3");
ASSERT_EQ(grad_fc.Input("b"), "b3");
ASSERT_EQ(grad_fc.Input("Out"), "out3");
}
\ No newline at end of file
@@ -49,6 +49,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc)
op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
softmax_op net)
...
@@ -19,16 +19,16 @@ limitations under the License. */
namespace paddle {
namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel {
protected:
void InferShape(
const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 1,
"Input size of FillZerosLikeOp must be one.");
PADDLE_ENFORCE(outputs.size() == 1,
"Output size of FillZerosLikeOp must be one.");
PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr,
"Outputs of FillZerosLikeOp must all be set.");
outputs[0]->Resize(inputs[0]->dims());
}
};
@@ -44,7 +44,7 @@ public:
Fill up a variable with zeros.
The output will have the same size as the input.
)DOC");
}
};
} // namespace operators
@@ -53,6 +53,6 @@ The output will have the same size with input.
REGISTER_OP(fill_zeros_like,
paddle::operators::FillZerosLikeOp,
paddle::operators::FillZerosLikeOpMaker);
REGISTER_OP_CPU_KERNEL(
fill_zeros_like,
paddle::operators::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);
\ No newline at end of file