提交 099bb53b 编写于 作者: Y Yu Yang

Merge branch 'feature/backward' of github.com:reyoung/Paddle into feature/backward

...@@ -249,14 +249,20 @@ TEST(Backward, part_of_output_are_not_need) { ...@@ -249,14 +249,20 @@ TEST(Backward, part_of_output_are_not_need) {
} }
TEST(Backward, part_of_input_are_not_need) { TEST(Backward, part_of_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {}); auto fwd = f::OpRegistry::CreateOp("mul", {"a", "b"}, {"out"}, {});
auto backward = f::Backward(*fwd, {"X"}); auto backward = f::Backward(*fwd, {"a"});
ASSERT_TRUE(backward->IsNetOp()); ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get()); auto net = static_cast<f::NetOp *>(backward.get());
ASSERT_EQ(1UL, net->ops_.size()); ASSERT_EQ(net->ops_.size(), 1UL);
auto &d_add = *net->ops_[0]; auto &grad_mul = *net->ops_[0];
ASSERT_EQ("rowwise_add_grad", d_add.type_); ASSERT_EQ(grad_mul.type_, "mul_grad");
ASSERT_EQ(f::OperatorBase::EMPTY_VAR_NAME(), ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL);
d_add.Output("X" + f::OperatorBase::GRAD_VAR_SUFFIX())); ASSERT_EQ(grad_mul.outputs_.size(), 2UL);
ASSERT_EQ(grad_mul.Output("A" + f::OperatorBase::GRAD_VAR_SUFFIX()),
f::OperatorBase::EMPTY_VAR_NAME());
ASSERT_EQ(grad_mul.Output("B" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"b" + f::OperatorBase::GRAD_VAR_SUFFIX());
ASSERT_EQ(grad_mul.Input("Out" + f::OperatorBase::GRAD_VAR_SUFFIX()),
"out" + f::OperatorBase::GRAD_VAR_SUFFIX());
} }
\ No newline at end of file
...@@ -49,6 +49,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) ...@@ -49,6 +49,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc)
op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu) op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
softmax_op net) softmax_op net)
......
...@@ -19,16 +19,16 @@ limitations under the License. */ ...@@ -19,16 +19,16 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class FillZerosLike : public framework::OperatorWithKernel { class FillZerosLikeOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape( void InferShape(
const std::vector<const framework::Tensor *> &inputs, const std::vector<const framework::Tensor *> &inputs,
const std::vector<framework::Tensor *> &outputs) const override { const std::vector<framework::Tensor *> &outputs) const override {
PADDLE_ENFORCE(inputs.size() == 1, PADDLE_ENFORCE(inputs.size() == 1,
"Input size of FillZerosLike must be one."); "Input size of FillZerosLikeOp must be one.");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one."); PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one.");
PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr, PADDLE_ENFORCE(inputs[0] != nullptr && outputs[0] != nullptr,
"Outputs of FillZerosLike must all be set."); "Outputs of FillZerosLikeOp must all be set.");
outputs[0]->Resize(inputs[0]->dims()); outputs[0]->Resize(inputs[0]->dims());
} }
}; };
...@@ -44,7 +44,7 @@ public: ...@@ -44,7 +44,7 @@ public:
Fill up a vriable with zeros. Fill up a vriable with zeros.
The output will have the same size with input. The output will have the same size with input.
)DOC") )DOC");
} }
}; };
} // namespace operators } // namespace operators
...@@ -53,6 +53,6 @@ The output will have the same size with input. ...@@ -53,6 +53,6 @@ The output will have the same size with input.
REGISTER_OP(fill_zeros_like, REGISTER_OP(fill_zeros_like,
paddle::operators::FillZerosLikeOp, paddle::operators::FillZerosLikeOp,
paddle::operators::FillZerosLikeOpMaker); paddle::operators::FillZerosLikeOpMaker);
EGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fill_zeros_like, fill_zeros_like,
paddle::operators::FillZerosLikeKernal<paddle::platform::CPUPlace, float>); paddle::operators::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册