diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 6ab81b5589307efe83f725537195671794cd6dcb..e920af3d1ac511f360ba32630c7812a939c27428 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -249,14 +249,20 @@ TEST(Backward, part_of_output_are_not_need) {
 }
 
 TEST(Backward, part_of_input_are_not_need) {
-  auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
-  auto backward = f::Backward(*fwd, {"X"});
+  auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
+  auto backward = f::Backward(*fwd, {"a"});
   ASSERT_TRUE(backward->IsNetOp());
   auto net = static_cast<f::NetOp *>(backward.get());
-  ASSERT_EQ(1UL, net->ops_.size());
-
-  auto &d_add = *net->ops_[0];
-  ASSERT_EQ("rowwise_add_grad", d_add.type_);
-  ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(),
-            d_add.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX()));
+  ASSERT_EQ(net->ops_.size(), 1UL);
+
+  auto &grad_mul = *net->ops_[0];
+  ASSERT_EQ(grad_mul.type_, "mul_grad");
+  ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
+  ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
+  ASSERT_EQ(grad_mul.Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            f::OperatorBase::EMPTY_VAR_NAME());
+  ASSERT_EQ(grad_mul.Output("B" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            "b" + f::OperatorBase::GRAD_VAR_SUFFIX());
+  ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
+            "out" + f::OperatorBase::GRAD_VAR_SUFFIX());
 }
\ No newline at end of file
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 0a14dc21144153f9a45d5227e54102983c6c2659..644460ee4735a8acc7e5ca32d7983d945fa13826 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -49,6 +49,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
 op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc)
 op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
 op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
+op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
 op_library(fc_op
     SRCS fc_op.cc
     DEPS mul_op rowwise_add_op sigmoid_op softmax_op net)
diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc
index 3df3a2cfab6d0c69d660eb78c103738eb0ccc627..d641bc4adaf8c7a84f5dab37632108d929a64730 100644
--- a/paddle/operators/fill_zeros_like_op.cc
+++ b/paddle/operators/fill_zeros_like_op.cc
@@ -19,16 +19,16 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-class FillZerosLike : public framework::OperatorWithKernel {
+class FillZerosLikeOp : public framework::OperatorWithKernel {
 protected:
   void InferShape(
       const std::vector<const framework::Tensor *> &inputs,
       const std::vector<framework::Tensor *> &outputs) const override {
     PADDLE_ENFORCE(inputs.size() == 1,
-                   "Input size of FillZerosLike must be one.");
+                   "Input size of FillZerosLikeOp must be one.");
     PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one.");
     PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr,
-                   "Outputs of FillZerosLike must all be set.");
+                   "Outputs of FillZerosLikeOp must all be set.");
     outputs[0]->Resize(inputs[0]->dims());
   }
 };
@@ -44,7 +44,7 @@ public:
 Fill up a vriable with zeros.
 
 The output will have the same size with input.
-)DOC")
+)DOC");
   }
 };
 }  // namespace operators
@@ -53,6 +53,6 @@
 REGISTER_OP(fill_zeros_like,
             paddle::operators::FillZerosLikeOp,
             paddle::operators::FillZerosLikeOpMaker);
-EGISTER_OP_CPU_KERNEL(
-    fill_zeros_like,
-    paddle::operators::FillZerosLikeKernal<paddle::platform::CPUPlace, float>);
\ No newline at end of file
+REGISTER_OP_CPU_KERNEL(
+    fill_zeros_like,
+    paddle::operators::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);