diff --git a/paddle/framework/fully_connected_op.cc b/paddle/framework/fully_connected_op.cc deleted file mode 100644 index 28be46366fffa54ad74c2c1b6f05ff08ae069813..0000000000000000000000000000000000000000 --- a/paddle/framework/fully_connected_op.cc +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/framework/fully_connected_op.h" -#include -namespace paddle { -namespace framework { - -void FCOp::Run(const ScopePtr& scope, - const platform::DeviceContext& dev_ctx) const override { - std::cout << "FC" << std::endl; -} - -void FCOp::InferShape(const ScopePtr& scope) const override {} - -void FCGradientOp::Run(const ScopePtr& scope, - const platform::DeviceContext& dev_ctx) const override { - std::cout << "FCGrad" << std::endl; -} - -void FCGradientOp::InferShape(const ScopePtr& scope) const override {} - -REGISTER_OP(my_fc, paddle::framework::FCOp, - paddle::framework::FCOpProtoAndCheckerMaker); -REGISTER_OP(my_fc_grad, paddle::framework::FCGradientOp, - paddle::framework::FCGradientOpProtoAndCheckerMaker); -} // namespace framework -} // namespace paddle diff --git a/paddle/framework/fully_connected_op.h b/paddle/framework/fully_connected_op.h index 948116f653f44ae9bd6bdcc4c84c59adc7d86350..f049eda9bbd1f7c6322cc09e95737cd351594a42 100644 --- a/paddle/framework/fully_connected_op.h +++ b/paddle/framework/fully_connected_op.h @@ -47,6 +47,8 @@ class FCGradientOp : public OperatorBase { }; // class FCGradientOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {}; +REGISTER_OP(my_fc, FCOp, FCOpProtoAndCheckerMaker); +REGISTER_GRADIENT_OP(my_fc_grad, FCGradientOp); } // namespace framework } // namespace paddle diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc index 814f397c7da0b44256ba6e41476c229552ff2710..18151c56d9acb3b10d5949f92b3e093d38c796e0 100644 --- a/paddle/framework/net_op_test.cc +++ b/paddle/framework/net_op_test.cc @@ -2,6 +2,7 @@ #include #include #include +#include "paddle/framework/fully_connected_op.h" namespace paddle { namespace framework { @@ -31,68 +32,51 @@ void AssertSameVectorWithoutOrder(const std::vector& expected, } } -class PlainNetTest : public testing::Test { - public: - virtual void SetUp() { - net_ = std::make_shared(); - ASSERT_NE(net_, nullptr); - - auto op1 = std::make_shared(); - op1->inputs_ = {"x", "w1", "b1"}; - op1->outputs_ = {"y"}; - net_->AddOp(op1); - - auto op2 = std::make_shared(); - op2->inputs_ = {"y", "w2", "b2"}; - op2->outputs_ = {"z"}; - net_->AddOp(op2); - net_->CompleteAddOp(); - } - - virtual void TearDown() {} - - virtual void TestBody() {} - - void TestOpKernel() { - AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, net_->inputs_); - AssertSameVectorWithoutOrder({"y", "z"}, net_->outputs_); - auto tmp_idx_iter = net_->attrs_.find("temporary_index"); - ASSERT_NE(net_->attrs_.end(), tmp_idx_iter); - auto& tmp_idx = boost::get>(tmp_idx_iter->second); - ASSERT_EQ(1UL, tmp_idx.size()); - ASSERT_EQ("y", net_->outputs_[tmp_idx[0]]); - - auto scope = std::make_shared(); - platform::CPUDeviceContext dev_ctx; - - net_->InferShape(scope); - net_->Run(scope, dev_ctx); - ASSERT_EQ(2, infer_shape_cnt); - ASSERT_EQ(2, run_cnt); - - auto op2 = std::make_shared(); - ASSERT_THROW(net_->AddOp(op2), EnforceNotMet); - } - - void TestAddBackwardOp() { - auto grad_ops = AddBackwardOp(net_); - for (auto& op : grad_ops->ops_) { - op->DebugString(); - } - } - - private: - std::shared_ptr net_; -}; - TEST(OpKernel, all) { - PlainNetTest net; - net.TestOpKernel(); + auto net = std::make_shared(); + ASSERT_NE(net, nullptr); + + auto op1 = std::make_shared(); + op1->inputs_ = {"x", "w1", "b1"}; + op1->outputs_ = {"y"}; + net->AddOp(op1); + + auto op2 = std::make_shared(); + op2->inputs_ = {"y", "w2", "b2"}; + op2->outputs_ = {"z"}; + net->AddOp(op2); + + net->CompleteAddOp(); + AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, net->inputs_); + AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_); + auto tmp_idx_iter = net->attrs_.find("temporary_index"); + ASSERT_NE(net->attrs_.end(), tmp_idx_iter); + auto& tmp_idx = boost::get>(tmp_idx_iter->second); + ASSERT_EQ(1UL, tmp_idx.size()); + ASSERT_EQ("y", net->outputs_[tmp_idx[0]]); + + auto scope = std::make_shared(); + platform::CPUDeviceContext dev_ctx; + + net->InferShape(scope); + net->Run(scope, dev_ctx); + ASSERT_EQ(2, infer_shape_cnt); + ASSERT_EQ(2, run_cnt); + + ASSERT_THROW(net->AddOp(op2), EnforceNotMet); } -TEST(AddBackwardOp, TestAddBackwardOp) { - PlainNetTest net; - net.TestAddBackwardOp(); +TEST(AddBackwardOp, TestGradOp) { + auto net = std::make_shared(); + ASSERT_NE(net, nullptr); + auto op1 = std::make_shared(); + op1->inputs_ = {"x", "w1", "b1"}; + op1->outputs_ = {"y"}; + net->AddOp(op1); + auto grad_ops = AddBackwardOp(net); + for (auto& op : grad_ops->ops_) { + op->DebugString(); + } } } // namespace framework diff --git a/paddle/framework/net_test.cc b/paddle/framework/net_test.cc deleted file mode 100644 index 5afc0d9204b8a62bf5fcdca748ef275da8200139..0000000000000000000000000000000000000000 --- a/paddle/framework/net_test.cc +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#include "paddle/framework/net.h" -#include "paddle/framework/fully_connected_op.h" -#include "paddle/framework/op_registry.h" - -#include - -namespace paddle { -namespace framework { - -TEST(AddBackwardOp, ALL) - -} // namespace framework -} // namespace paddle