Commit 0bd49a50 authored by Yan Chunwei, committed by GitHub

move net_op to operators/ (#3201)

* move net_op to operators
Parent d953611e
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -31,10 +31,7 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
 add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
 add_dependencies(framework_py_proto framework_py_proto_init)
-cc_library(net SRCS net.cc DEPS op_registry)
-cc_test(net_op_test SRCS net_op_test.cc DEPS net)
-cc_library(backward SRCS backward.cc DEPS net)
+cc_library(backward SRCS backward.cc DEPS net_op)
 cc_test(backward_test SRCS backward_test.cc DEPS backward)
 cc_library(paddle_pybind SHARED
   SRCS pybind.cc
......
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -14,8 +14,8 @@
 #include "paddle/framework/backward.h"
 #include <list>
-#include "paddle/framework/net.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/net_op.h"
 namespace paddle {
 namespace framework {
@@ -32,7 +32,7 @@ static bool AllInSet(const std::vector<std::string>& names,
 }
 static std::shared_ptr<OperatorBase> NOP() {
-  auto net_op = std::make_shared<NetOp>();
+  auto net_op = std::make_shared<operators::NetOp>();
   net_op->type_ = "@NOP@";
   net_op->CompleteAddOp();
   return net_op;
@@ -77,11 +77,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
   }
   // Returned gradient network
-  auto net = std::make_shared<NetOp>();
+  auto net = std::make_shared<operators::NetOp>();
   if (forwardOp.IsNetOp()) {
     // Because forwardOp is a net op, it can static_cast.
-    auto& forwardNet = static_cast<const NetOp&>(forwardOp);
+    auto& forwardNet = static_cast<const operators::NetOp&>(forwardOp);
     // Map from output gradient variable name to operator's indices in backward
     // net. That operator generates that variable.
......
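Every consumer of Backward() in this commit follows the pattern visible in BackwardRecursive() above (and exercised by backward_test.cc below): the function returns a generic OperatorBase, and callers that need the network structure check IsNetOp() before downcasting to the type at its new home, operators::NetOp. A minimal consumer sketch under that assumption; InspectGradientNet and its parameter name are illustrative, not part of the commit:

// Sketch only: mirrors the IsNetOp()/static_cast pattern used above.
#include <memory>
#include "paddle/framework/backward.h"
#include "paddle/operators/net_op.h"

void InspectGradientNet(const paddle::framework::OperatorBase& fwd) {
  // Backward() hands back the gradient network behind an OperatorBase.
  std::shared_ptr<paddle::framework::OperatorBase> bwd =
      paddle::framework::Backward(fwd, {} /* no_grad_vars */);
  if (bwd->IsNetOp()) {
    auto* net = static_cast<paddle::operators::NetOp*>(bwd.get());
    for (auto& grad_op : net->ops_) {
      (void)grad_op;  // one gradient operator per forward op
    }
  }
}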
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -15,8 +15,9 @@
 #include "paddle/framework/backward.h"
 #include <gtest/gtest.h>
-#include "paddle/framework/net.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/net_op.h"
+#include "paddle/operators/type_alias.h"
 namespace paddle {
 namespace framework {
@@ -70,7 +71,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
   }
 };
-class FcOp : public NetOp {
+class FcOp : public ops::NetOp {
  public:
  void Init() override {
    AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
@@ -182,7 +183,8 @@ TEST(Backward, simple_op_not_need_grad) {
   auto no_input_gop = f::Backward(*fwd, {"X", "b"});
   ASSERT_NE(no_input_gop, nullptr);
   ASSERT_TRUE(no_input_gop->IsNetOp());
-  ASSERT_EQ(0UL, std::static_pointer_cast<f::NetOp>(no_input_gop)->ops_.size());
+  ASSERT_EQ(0UL,
+            std::static_pointer_cast<ops::NetOp>(no_input_gop)->ops_.size());
 }
 TEST(Backward, net_fc_backward_normal) {
@@ -191,7 +193,7 @@ TEST(Backward, net_fc_backward_normal) {
   ASSERT_NE(fwd, nullptr);
   std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
   ASSERT_TRUE(gop->IsNetOp());
-  auto net = static_cast<f::NetOp *>(gop.get());
+  auto net = static_cast<ops::NetOp *>(gop.get());
   ASSERT_NO_THROW(net->DebugString());
@@ -214,7 +216,7 @@ TEST(Backward, net_fc_backward_not_have_b) {
   ASSERT_NE(fwd, nullptr);
   std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
   ASSERT_TRUE(gop->IsNetOp());
-  auto net = static_cast<f::NetOp *>(gop.get());
+  auto net = static_cast<ops::NetOp *>(gop.get());
   ASSERT_NO_THROW(net->DebugString());
@@ -228,7 +230,7 @@ TEST(Backward, net_fc_backward_not_have_b) {
 }
 TEST(Backward, net_input_of_network_not_need_grad) {
-  f::NetOp net;
+  ops::NetOp net;
   net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"},
                                     {"mul_tmp_0", "add_tmp_0", "hidden0"}, {}));
   net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"},
@@ -236,7 +238,7 @@ TEST(Backward, net_input_of_network_not_need_grad) {
   net.CompleteAddOp();
   auto bwd = Backward(net, {"X"});  // X@GRAD is not need.
   ASSERT_TRUE(bwd->IsNetOp());
-  auto bwd_net = static_cast<f::NetOp *>(bwd.get());
+  auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
   std::unordered_set<std::string> all_output = std::unordered_set<std::string>(
       bwd_net->outputs_.begin(), bwd_net->outputs_.end());
@@ -253,7 +255,7 @@ TEST(Backward, net_input_of_network_not_need_grad) {
   ASSERT_EQ(2UL, bwd_net->ops_.size());
   ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
-  auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
+  auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
   ASSERT_EQ(3UL, first_fc_grad->ops_.size());
   ASSERT_EQ(
       f::OperatorBase::EMPTY_VAR_NAME(),
@@ -261,14 +263,14 @@ TEST(Backward, net_input_of_network_not_need_grad) {
 }
 TEST(Backward, net_shared_weight) {
-  f::NetOp net;
+  ops::NetOp net;
   net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {}));
   net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {}));
   net.CompleteAddOp();
   auto bwd = f::Backward(net, {});
   ASSERT_TRUE(bwd->IsNetOp());
-  auto bwd_net = static_cast<f::NetOp *>(bwd.get());
+  auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
   ASSERT_EQ(3UL, bwd_net->ops_.size());
   ASSERT_EQ("add", bwd_net->ops_[2]->type_);
 }
@@ -285,7 +287,7 @@ TEST(Backward, op_all_input_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
   auto backward = f::Backward(*fwd, {"X", "b"});
   ASSERT_TRUE(backward->IsNetOp());
-  auto net = static_cast<f::NetOp *>(backward.get());
+  auto net = static_cast<ops::NetOp *>(backward.get());
   ASSERT_TRUE(net->ops_.empty());
 }
@@ -293,7 +295,7 @@ TEST(Backward, op_all_output_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
   auto backward = f::Backward(*fwd, {"Out"});
   ASSERT_TRUE(backward->IsNetOp());
-  auto net = static_cast<f::NetOp *>(backward.get());
+  auto net = static_cast<ops::NetOp *>(backward.get());
   ASSERT_TRUE(net->ops_.empty());
 }
@@ -301,7 +303,7 @@ TEST(Backward, op_part_of_output_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
   auto backward = f::Backward(*fwd, {"Z"});
   ASSERT_TRUE(backward->IsNetOp());
-  auto net = static_cast<f::NetOp *>(backward.get());
+  auto net = static_cast<ops::NetOp *>(backward.get());
   ASSERT_EQ(net->ops_.size(), 2UL);
   auto &fill_zero = *net->ops_[0];
@@ -341,7 +343,7 @@ TEST(Backward, op_part_of_input_are_not_need) {
 }
 TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
-  f::NetOp net;
+  ops::NetOp net;
   net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"},
                                     {"mul_out1", "add_out1", "out1"}, {}));
   net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"},
@@ -351,7 +353,7 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
   net.CompleteAddOp();
   auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
   ASSERT_TRUE(backward->IsNetOp());
-  auto bwd_net = static_cast<f::NetOp *>(backward.get());
+  auto bwd_net = static_cast<ops::NetOp *>(backward.get());
   ASSERT_EQ(bwd_net->ops_.size(), 3UL);
   auto &grad_fc = *bwd_net->ops_[0];
   EXPECT_EQ(grad_fc.inputs_.size(),
......
--- a/paddle/framework/pybind.cc
+++ b/paddle/framework/pybind.cc
@@ -17,11 +17,12 @@ limitations under the License. */
 #include <vector>
 #include "paddle/framework/backward.h"
-#include "paddle/framework/net.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
 #include "paddle/framework/scope.h"
 #include "paddle/framework/tensor_py.h"
+#include "paddle/operators/net_op.h"
+#include "paddle/operators/type_alias.h"
 #include "paddle/platform/enforce.h"
 #include "paddle/platform/place.h"
 #include "pybind11/numpy.h"
@@ -118,7 +119,9 @@ All parameter, weight, gradient are variables in Paddle.
            [](Variable &self) -> Tensor * { return self.GetMutable<Tensor>(); },
            py::return_value_policy::reference)
       .def("get_net",
-           [](Variable &self) -> NetOp * { return self.GetMutable<NetOp>(); },
+           [](Variable &self) -> ops::NetOp * {
+             return self.GetMutable<ops::NetOp>();
+           },
            py::return_value_policy::reference);
   py::class_<Scope>(m, "Scope", "")
@@ -196,22 +199,24 @@ All parameter, weight, gradient are variables in Paddle.
   ExposeOperator(operator_base);
-  py::class_<NetOp, std::shared_ptr<NetOp>> net(m, "Net");
+  py::class_<ops::NetOp, std::shared_ptr<ops::NetOp>> net(m, "Net");
   net.def_static("create",
-                 []() -> std::shared_ptr<NetOp> {
-                   auto retv = std::make_shared<NetOp>();
+                 []() -> std::shared_ptr<ops::NetOp> {
+                   auto retv = std::make_shared<ops::NetOp>();
                    retv->type_ = "plain_net";
                    return retv;
                  })
-      .def("add_op", &NetOp::AddOp)
-      .def("add_op",
-           [](NetOp &self, const std::shared_ptr<NetOp> &net) -> void {
-             self.AddOp(std::static_pointer_cast<OperatorBase>(net));
-           })
-      .def("complete_add_op", &NetOp::CompleteAddOp)
+      .def("add_op", &ops::NetOp::AddOp)
+      .def(
+          "add_op",
+          [](ops::NetOp &self, const std::shared_ptr<ops::NetOp> &net) -> void {
+            self.AddOp(std::static_pointer_cast<OperatorBase>(net));
+          })
+      .def("complete_add_op", &ops::NetOp::CompleteAddOp)
       .def("complete_add_op",
-           [](std::shared_ptr<NetOp> &self) { self->CompleteAddOp(); });
+           [](std::shared_ptr<ops::NetOp> &self) { self->CompleteAddOp(); });
   ExposeOperator(net);
   m.def("unique_integer", UniqueIntegerGenerator);
......
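The Net class bound here is a thin veneer over the same C++ calls. For orientation, a direct C++ counterpart of the create/add_op/complete_add_op sequence, sketched with the "mul" operator used by the tests above; BuildPlainNet is an illustrative name, not part of the commit:

// Sketch only: C++ equivalent of Net.create() + add_op + complete_add_op.
#include <memory>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"

std::shared_ptr<paddle::operators::NetOp> BuildPlainNet() {
  auto net = std::make_shared<paddle::operators::NetOp>();
  net->type_ = "plain_net";  // the same tag Net.create() sets above
  // AddOp accepts any OperatorBase, so nets can nest other nets too.
  net->AddOp(paddle::framework::OpRegistry::CreateOp(
      "mul", {"X", "W"}, {"Out"}, {}));
  net->CompleteAddOp();  // per net_op_test.cc, AddOp() after this throws
  return net;
}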
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -41,6 +41,9 @@ function(op_library TARGET)
   endif()
 endfunction()
+cc_library(net_op SRCS net_op.cc DEPS op_registry)
+cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
 op_library(add_op SRCS add_op.cc add_op.cu)
 cc_test(add_op_test SRCS add_op_test.cc DEPS add_op)
@@ -59,6 +62,6 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
 op_library(fc_op
   SRCS fc_op.cc
-  DEPS mul_op rowwise_add_op sigmoid_op softmax_op net)
+  DEPS mul_op rowwise_add_op sigmoid_op softmax_op net_op)
-op_library(recurrent_op SRCS recurrent_op.cc DEPS op_desc tensor op_registry operator net)
+op_library(recurrent_op SRCS recurrent_op.cc DEPS op_desc tensor op_registry operator net_op)
 cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)

--- a/paddle/framework/net.cc
+++ b/paddle/operators/net_op.cc
@@ -14,11 +14,11 @@
    limitations under the License.
 */
-#include "paddle/framework/net.h"
+#include "paddle/operators/net_op.h"
 #include "paddle/framework/op_registry.h"
 namespace paddle {
-namespace framework {
+namespace operators {
 void NetOp::CompleteAddOp(bool calc) {
   add_op_done_ = true;
@@ -74,5 +74,5 @@ std::string NetOp::DebugString() const {
 bool NetOp::IsNetOp() const { return true; }
-}  // namespace framework
+}  // namespace operators
 }  // namespace paddle

--- a/paddle/framework/net.h
+++ b/paddle/operators/net_op.h
@@ -14,15 +14,17 @@ limitations under the License. */
 #pragma once
-#include <paddle/framework/op_desc.pb.h>
-#include <paddle/framework/operator.h>
+#include "paddle/framework/op_desc.pb.h"
 #include "paddle/framework/op_proto.pb.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/framework/operator.h"
 #include "paddle/framework/scope.h"
+#include "paddle/operators/type_alias.h"
 #include "paddle/platform/device_context.h"
 namespace paddle {
-namespace framework {
+namespace operators {
 /**
  * @brief Network is also a type of Operator
  *
@@ -37,13 +39,13 @@ namespace framework {
  * This is the base class of network, all the networks should implement the APIs
  * it defines.
  */
-class NetOp : public OperatorBase {
+class NetOp : public framework::OperatorBase {
  public:
   /**
    * Infer all the operators' input and output variables' shapes, will be called
    * before every mini-batch
    */
-  void InferShape(const Scope& scope) const override {
+  void InferShape(const framework::Scope& scope) const override {
     for (auto& op : ops_) {
       op->InferShape(scope);
     }
@@ -56,7 +58,7 @@ class NetOp : public OperatorBase {
    * scope will be used instead. If no OpContext is provicded, default context
    * will be used.
    */
-  void Run(const Scope& scope,
+  void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const override {
     for (auto& op : ops_) {
       op->Run(scope, dev_ctx);
@@ -88,7 +90,7 @@ class NetOp : public OperatorBase {
   std::vector<std::shared_ptr<OperatorBase>> ops_;
  private:
   bool add_op_done_{false};
   template <typename T, typename KeyType>
@@ -97,5 +99,5 @@ class NetOp : public OperatorBase {
   }
 };
-}  // namespace framework
+}  // namespace operators
 }  // namespace paddle
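Composite operators subclass NetOp, as the FcOp fixture in backward_test.cc above does. A hypothetical subclass under the new namespace, for illustration only; MulAddOp and its variable names are invented here, while "mul" and "rowwise_add" are operators registered elsewhere in this tree:

// Sketch only: a two-op network in the style of the FcOp test fixture.
// OpRegistry resolves unqualified via type_alias.h, included by net_op.h.
#include "paddle/operators/net_op.h"

namespace paddle {
namespace operators {

class MulAddOp : public NetOp {
 public:
  void Init() override {
    AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
                               {Output("mul_out")}, {}));
    AddOp(OpRegistry::CreateOp("rowwise_add", {Output("mul_out"), Input("b")},
                               {Output("Out")}, {}));
    // `false` is the `calc` flag of CompleteAddOp(bool) shown in net_op.cc.
    CompleteAddOp(false);
  }
};

}  // namespace operators
}  // namespace paddle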
#include "paddle/operators/net_op.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <paddle/framework/net.h>
#include <paddle/framework/op_registry.h> #include "paddle/framework/op_registry.h"
#include <paddle/framework/operator.h> #include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
namespace framework { namespace operators {
static int infer_shape_cnt = 0; static int infer_shape_cnt = 0;
static int run_cnt = 0; static int run_cnt = 0;
class TestOp : public OperatorBase { class TestOp : public OperatorBase {
public: public:
void InferShape(const framework::Scope& scope) const override { void InferShape(const framework::Scope& scope) const override {
++infer_shape_cnt; ++infer_shape_cnt;
} }
...@@ -21,7 +23,7 @@ class TestOp : public OperatorBase { ...@@ -21,7 +23,7 @@ class TestOp : public OperatorBase {
}; };
class EmptyOp : public OperatorBase { class EmptyOp : public OperatorBase {
public: public:
void InferShape(const Scope& scope) const override {} void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
...@@ -73,7 +75,7 @@ TEST(OpKernel, all) { ...@@ -73,7 +75,7 @@ TEST(OpKernel, all) {
ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet); ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
} }
TEST(Net, insert_op) { TEST(NetOp, insert_op) {
NetOp net; NetOp net;
auto op1 = std::make_shared<EmptyOp>(); auto op1 = std::make_shared<EmptyOp>();
op1->inputs_ = {"x", "w1", "b1"}; op1->inputs_ = {"x", "w1", "b1"};
...@@ -85,5 +87,5 @@ TEST(Net, insert_op) { ...@@ -85,5 +87,5 @@ TEST(Net, insert_op) {
ASSERT_EQ(3UL, net.ops_.size()); ASSERT_EQ(3UL, net.ops_.size());
} }
} // namespace framework } // namespace operators
} // namespace paddle } // namespace paddle

--- a/paddle/operators/recurrent_op.cc
+++ b/paddle/operators/recurrent_op.cc
@@ -18,8 +18,8 @@
 #include <cstring>
 #include <sstream>
-#include "paddle/framework/net.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/net_op.h"
 #include "paddle/platform/enforce.h"
 namespace paddle {
......
--- a/paddle/operators/recurrent_op_test.cc
+++ b/paddle/operators/recurrent_op_test.cc
@@ -11,14 +11,15 @@
    limitations under the License.
 */
+#include "paddle/operators/recurrent_op.h"
 #include <glog/logging.h>
 #include <gtest/gtest.h>
-#include "paddle/framework/net.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
 #include "paddle/framework/tensor.h"
-#include "paddle/operators/recurrent_op.h"
+#include "paddle/operators/net_op.h"
 namespace paddle {
 namespace operators {
......
--- a/paddle/operators/type_alias.h
+++ b/paddle/operators/type_alias.h
@@ -15,8 +15,8 @@
 #pragma once
 #include "paddle/framework/eigen.h"
-#include "paddle/framework/net.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/net_op.h"
 namespace paddle {
 namespace operators {
@@ -44,15 +44,16 @@ template <typename T,
           typename IndexType = Eigen::DenseIndex>
 using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
 using Tensor = framework::Tensor;
+using Scope = framework::Scope;
 using OperatorWithKernel = framework::OperatorWithKernel;
+using OperatorBase = framework::OperatorBase;
 using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
 using OpProto = framework::OpProto;
 using OpAttrChecker = framework::OpAttrChecker;
 using CPUPlace = platform::CPUPlace;
 using GPUPlace = platform::GPUPlace;
-using NetOp = framework::NetOp;
 using OpRegistry = framework::OpRegistry;
-using OperatorBase = framework::OperatorBase;
 }  // namespace operators
 }  // namespace paddle
......
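With net.h gone, operator code reaches framework types through these aliases; the ops:: shorthand seen in backward_test.cc and pybind.cc is presumably a namespace alias for paddle::operators defined alongside them. A tiny illustration of what the header provides; Demo is a placeholder name, not part of the commit:

// Sketch only: what type_alias.h buys in operator code.
#include "paddle/operators/type_alias.h"

namespace paddle {
namespace operators {

void Demo() {
  NetOp net;                 // NetOp now lives natively in paddle::operators
  OperatorBase& base = net;  // OperatorBase/Scope/... alias framework types
  (void)base;
}

}  // namespace operators
}  // namespace paddle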