提交 0bd49a50 编写于 作者: Y Yan Chunwei 提交者: GitHub

move net_op to operators/ (#3201)

* move net_op to operators
上级 d953611e
......@@ -31,10 +31,7 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
cc_library(net SRCS net.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net)
cc_library(backward SRCS backward.cc DEPS net)
cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward)
cc_library(paddle_pybind SHARED
SRCS pybind.cc
......
......@@ -14,8 +14,8 @@
#include "paddle/framework/backward.h"
#include <list>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace framework {
......@@ -32,7 +32,7 @@ static bool AllInSet(const std::vector<std::string>& names,
}
static std::shared_ptr<OperatorBase> NOP() {
auto net_op = std::make_shared<NetOp>();
auto net_op = std::make_shared<operators::NetOp>();
net_op->type_ = "@NOP@";
net_op->CompleteAddOp();
return net_op;
......@@ -77,11 +77,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
}
// Returned gradient network
auto net = std::make_shared<NetOp>();
auto net = std::make_shared<operators::NetOp>();
if (forwardOp.IsNetOp()) {
// Because forwardOp is a net op, it can static_cast.
auto& forwardNet = static_cast<const NetOp&>(forwardOp);
auto& forwardNet = static_cast<const operators::NetOp&>(forwardOp);
// Map from output gradient variable name to operator's indices in backward
// net. That operator generates that variable.
......
......@@ -15,8 +15,9 @@
#include "paddle/framework/backward.h"
#include <gtest/gtest.h>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/type_alias.h"
namespace paddle {
namespace framework {
......@@ -70,7 +71,7 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
}
};
class FcOp : public NetOp {
class FcOp : public ops::NetOp {
public:
void Init() override {
AddOp(OpRegistry::CreateOp("mul", {Input("X"), Input("W")},
......@@ -182,7 +183,8 @@ TEST(Backward, simple_op_not_need_grad) {
auto no_input_gop = f::Backward(*fwd, {"X", "b"});
ASSERT_NE(no_input_gop, nullptr);
ASSERT_TRUE(no_input_gop->IsNetOp());
ASSERT_EQ(0UL, std::static_pointer_cast<f::NetOp>(no_input_gop)->ops_.size());
ASSERT_EQ(0UL,
std::static_pointer_cast<ops::NetOp>(no_input_gop)->ops_.size());
}
TEST(Backward, net_fc_backward_normal) {
......@@ -191,7 +193,7 @@ TEST(Backward, net_fc_backward_normal) {
ASSERT_NE(fwd, nullptr);
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp());
auto net = static_cast<f::NetOp *>(gop.get());
auto net = static_cast<ops::NetOp *>(gop.get());
ASSERT_NO_THROW(net->DebugString());
......@@ -214,7 +216,7 @@ TEST(Backward, net_fc_backward_not_have_b) {
ASSERT_NE(fwd, nullptr);
std::shared_ptr<f::OperatorBase> gop = f::Backward(*fwd, {});
ASSERT_TRUE(gop->IsNetOp());
auto net = static_cast<f::NetOp *>(gop.get());
auto net = static_cast<ops::NetOp *>(gop.get());
ASSERT_NO_THROW(net->DebugString());
......@@ -228,7 +230,7 @@ TEST(Backward, net_fc_backward_not_have_b) {
}
TEST(Backward, net_input_of_network_not_need_grad) {
f::NetOp net;
ops::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("fc", {"X", "W1", "b1"},
{"mul_tmp_0", "add_tmp_0", "hidden0"}, {}));
net.AddOp(f::OpRegistry::CreateOp("fc", {"hidden0", "W2", "b2"},
......@@ -236,7 +238,7 @@ TEST(Backward, net_input_of_network_not_need_grad) {
net.CompleteAddOp();
auto bwd = Backward(net, {"X"}); // X@GRAD is not need.
ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
std::unordered_set<std::string> all_output = std::unordered_set<std::string>(
bwd_net->outputs_.begin(), bwd_net->outputs_.end());
......@@ -253,7 +255,7 @@ TEST(Backward, net_input_of_network_not_need_grad) {
ASSERT_EQ(2UL, bwd_net->ops_.size());
ASSERT_TRUE(bwd_net->ops_[1]->IsNetOp());
auto first_fc_grad = static_cast<f::NetOp *>(bwd_net->ops_[1].get());
auto first_fc_grad = static_cast<ops::NetOp *>(bwd_net->ops_[1].get());
ASSERT_EQ(3UL, first_fc_grad->ops_.size());
ASSERT_EQ(
f::OperatorBase::EMPTY_VAR_NAME(),
......@@ -261,14 +263,14 @@ TEST(Backward, net_input_of_network_not_need_grad) {
}
TEST(Backward, net_shared_weight) {
f::NetOp net;
ops::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {}));
net.AddOp(f::OpRegistry::CreateOp("mul", {"Out", "W"}, {"FinalOut"}, {}));
net.CompleteAddOp();
auto bwd = f::Backward(net, {});
ASSERT_TRUE(bwd->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(bwd.get());
auto bwd_net = static_cast<ops::NetOp *>(bwd.get());
ASSERT_EQ(3UL, bwd_net->ops_.size());
ASSERT_EQ("add", bwd_net->ops_[2]->type_);
}
......@@ -285,7 +287,7 @@ TEST(Backward, op_all_input_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
auto backward = f::Backward(*fwd, {"X", "b"});
ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
auto net = static_cast<ops::NetOp *>(backward.get());
ASSERT_TRUE(net->ops_.empty());
}
......@@ -293,7 +295,7 @@ TEST(Backward, op_all_output_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("rowwise_add", {"X", "b"}, {"Out"}, {});
auto backward = f::Backward(*fwd, {"Out"});
ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
auto net = static_cast<ops::NetOp *>(backward.get());
ASSERT_TRUE(net->ops_.empty());
}
......@@ -301,7 +303,7 @@ TEST(Backward, op_part_of_output_are_not_need) {
auto fwd = f::OpRegistry::CreateOp("many_output_op", {"X"}, {"Y", "Z"}, {});
auto backward = f::Backward(*fwd, {"Z"});
ASSERT_TRUE(backward->IsNetOp());
auto net = static_cast<f::NetOp *>(backward.get());
auto net = static_cast<ops::NetOp *>(backward.get());
ASSERT_EQ(net->ops_.size(), 2UL);
auto &fill_zero = *net->ops_[0];
......@@ -341,7 +343,7 @@ TEST(Backward, op_part_of_input_are_not_need) {
}
TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
f::NetOp net;
ops::NetOp net;
net.AddOp(f::OpRegistry::CreateOp("fc", {"x1", "w1", "b1"},
{"mul_out1", "add_out1", "out1"}, {}));
net.AddOp(f::OpRegistry::CreateOp("fc", {"out1", "w2", "b2"},
......@@ -351,7 +353,7 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
net.CompleteAddOp();
auto backward = f::Backward(net, {"mul_out2", "tmp_out2", "out2"});
ASSERT_TRUE(backward->IsNetOp());
auto bwd_net = static_cast<f::NetOp *>(backward.get());
auto bwd_net = static_cast<ops::NetOp *>(backward.get());
ASSERT_EQ(bwd_net->ops_.size(), 3UL);
auto &grad_fc = *bwd_net->ops_[0];
EXPECT_EQ(grad_fc.inputs_.size(),
......
......@@ -17,11 +17,12 @@ limitations under the License. */
#include <vector>
#include "paddle/framework/backward.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor_py.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/type_alias.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "pybind11/numpy.h"
......@@ -118,7 +119,9 @@ All parameter, weight, gradient are variables in Paddle.
[](Variable &self) -> Tensor * { return self.GetMutable<Tensor>(); },
py::return_value_policy::reference)
.def("get_net",
[](Variable &self) -> NetOp * { return self.GetMutable<NetOp>(); },
[](Variable &self) -> ops::NetOp * {
return self.GetMutable<ops::NetOp>();
},
py::return_value_policy::reference);
py::class_<Scope>(m, "Scope", "")
......@@ -196,22 +199,24 @@ All parameter, weight, gradient are variables in Paddle.
ExposeOperator(operator_base);
py::class_<NetOp, std::shared_ptr<NetOp>> net(m, "Net");
py::class_<ops::NetOp, std::shared_ptr<ops::NetOp>> net(m, "Net");
net.def_static("create",
[]() -> std::shared_ptr<NetOp> {
auto retv = std::make_shared<NetOp>();
[]() -> std::shared_ptr<ops::NetOp> {
auto retv = std::make_shared<ops::NetOp>();
retv->type_ = "plain_net";
return retv;
})
.def("add_op", &NetOp::AddOp)
.def("add_op",
[](NetOp &self, const std::shared_ptr<NetOp> &net) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
})
.def("complete_add_op", &NetOp::CompleteAddOp)
.def("add_op", &ops::NetOp::AddOp)
.def(
"add_op",
[](ops::NetOp &self, const std::shared_ptr<ops::NetOp> &net) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
})
.def("complete_add_op", &ops::NetOp::CompleteAddOp)
.def("complete_add_op",
[](std::shared_ptr<NetOp> &self) { self->CompleteAddOp(); });
[](std::shared_ptr<ops::NetOp> &self) { self->CompleteAddOp(); });
ExposeOperator(net);
m.def("unique_integer", UniqueIntegerGenerator);
......
......@@ -41,6 +41,9 @@ function(op_library TARGET)
endif()
endfunction()
cc_library(net_op SRCS net_op.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
op_library(add_op SRCS add_op.cc add_op.cu)
cc_test(add_op_test SRCS add_op_test.cc DEPS add_op)
......@@ -59,6 +62,6 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library(fc_op
SRCS fc_op.cc
DEPS mul_op rowwise_add_op sigmoid_op softmax_op net)
op_library(recurrent_op SRCS recurrent_op.cc DEPS op_desc tensor op_registry operator net)
DEPS mul_op rowwise_add_op sigmoid_op softmax_op net_op)
op_library(recurrent_op SRCS recurrent_op.cc DEPS op_desc tensor op_registry operator net_op)
cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
......@@ -14,11 +14,11 @@
limitations under the License.
*/
#include "paddle/framework/net.h"
#include "paddle/operators/net_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace framework {
namespace operators {
void NetOp::CompleteAddOp(bool calc) {
add_op_done_ = true;
......@@ -74,5 +74,5 @@ std::string NetOp::DebugString() const {
bool NetOp::IsNetOp() const { return true; }
} // namespace framework
} // namespace operators
} // namespace paddle
......@@ -14,15 +14,17 @@ limitations under the License. */
#pragma once
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/operator.h>
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/operators/type_alias.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
namespace operators {
/**
* @brief Network is also a type of Operator
*
......@@ -37,13 +39,13 @@ namespace framework {
* This is the base class of network, all the networks should implement the APIs
* it defines.
*/
class NetOp : public OperatorBase {
public:
class NetOp : public framework::OperatorBase {
public:
/**
* Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch
*/
void InferShape(const Scope& scope) const override {
void InferShape(const framework::Scope& scope) const override {
for (auto& op : ops_) {
op->InferShape(scope);
}
......@@ -56,7 +58,7 @@ class NetOp : public OperatorBase {
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
void Run(const Scope& scope,
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
for (auto& op : ops_) {
op->Run(scope, dev_ctx);
......@@ -88,7 +90,7 @@ class NetOp : public OperatorBase {
std::vector<std::shared_ptr<OperatorBase>> ops_;
private:
private:
bool add_op_done_{false};
template <typename T, typename KeyType>
......@@ -97,5 +99,5 @@ class NetOp : public OperatorBase {
}
};
} // namespace framework
} // namespace operators
} // namespace paddle
#include "paddle/operators/net_op.h"
#include <gtest/gtest.h>
#include <paddle/framework/net.h>
#include <paddle/framework/op_registry.h>
#include <paddle/framework/operator.h>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace framework {
namespace operators {
static int infer_shape_cnt = 0;
static int run_cnt = 0;
class TestOp : public OperatorBase {
public:
public:
void InferShape(const framework::Scope& scope) const override {
++infer_shape_cnt;
}
......@@ -21,7 +23,7 @@ class TestOp : public OperatorBase {
};
class EmptyOp : public OperatorBase {
public:
public:
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
......@@ -73,7 +75,7 @@ TEST(OpKernel, all) {
ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
}
TEST(Net, insert_op) {
TEST(NetOp, insert_op) {
NetOp net;
auto op1 = std::make_shared<EmptyOp>();
op1->inputs_ = {"x", "w1", "b1"};
......@@ -85,5 +87,5 @@ TEST(Net, insert_op) {
ASSERT_EQ(3UL, net.ops_.size());
}
} // namespace framework
} // namespace operators
} // namespace paddle
......@@ -18,8 +18,8 @@
#include <cstring>
#include <sstream>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/platform/enforce.h"
namespace paddle {
......
......@@ -11,14 +11,15 @@
limitations under the License.
*/
#include "paddle/operators/recurrent_op.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/recurrent_op.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
......
......@@ -15,8 +15,8 @@
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
......@@ -44,15 +44,16 @@ template <typename T,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
using Scope = framework::Scope;
using OperatorWithKernel = framework::OperatorWithKernel;
using OperatorBase = framework::OperatorBase;
using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
using OpProto = framework::OpProto;
using OpAttrChecker = framework::OpAttrChecker;
using CPUPlace = platform::CPUPlace;
using GPUPlace = platform::GPUPlace;
using NetOp = framework::NetOp;
using OpRegistry = framework::OpRegistry;
using OperatorBase = framework::OperatorBase;
} // namespace operators
} // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册