diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index f7e5753ac2c238627b94050059620f87966ab4ed..8c34a77c20787e59376556c9d8f8ccbfa34d074b 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -7,4 +7,8 @@ cc_test(scope_test SRCS scope_test.cc) cc_test(enforce_test SRCS enforce_test.cc) proto_library(attr_type SRCS attr_type.proto) proto_library(op_proto SRCS op_proto.proto) + cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto attr_type protobuf) + +proto_library(net_proto SRCS net_proto.proto) +cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto) diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 88bdf0bb68bffaa758cc8db723ed0377838516e5..b3064e4f90b773b7c0bcaaec66d1d47103e87793 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -14,6 +14,8 @@ #pragma once +#include "paddle/framework/net_proto.pb.h" +#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/scope.h" namespace paddle { @@ -27,31 +29,11 @@ typedef int OpIndex; * keep updating if the concepts related are implemented. */ -// Operator's runtime context. -struct OpContext { - int dev_id; - DevType dev_type{kCPU}; - enum DevType { kCPU, kGPU }; -}; - -// Proto definitions, use `struct`s for simpility. -struct VarDesc { - std::string type; - std::vector dims; -}; -struct OpDesc { - std::string type; - std::vector inputs; - std::vector outputs; -}; -struct struct NetDesc { - std::vector ops; -}; class Operator { public: Operator(const OpDesc &def) {} - Error InferShape() {} - Error Run() {} + bool InferShape() {} + bool Run() {} }; /** @@ -73,7 +55,7 @@ class Net { /** * @brief Infer shapes of all inputs and outputs of operators. */ - virtual Error InferShape(Scope *scope) override; + virtual bool InferShape(Scope *scope) override; /** * @brief Run the network. * @@ -82,8 +64,8 @@ class Net { * environment for ops. `begin` and `end` specify the scope of `ops_` to run, * If no positive indexes are provided, all operators in `ops_` will run. */ - virtual Error Run(Scope *scope, OpContext *context, OpIndex begin = -1, - OpIndex end = -1) const = 0; + virtual bool Run(Scope *scope, OpContext *context, OpIndex begin = -1, + OpIndex end = -1) const = 0; /** * @brief Add an Operator according to `def`. @@ -93,12 +75,12 @@ class Net { /** * @brief Add optimizer operators acctording to `attrs`. */ - virtual Error AddOptimizerOps(const OptAttrs &attrs) = 0; + virtual bool AddOptimizerOps(const OptAttrs &attrs) = 0; /** * @brief Add backward operators. */ - virtual Error AddBackwardOps() = 0; + virtual bool AddBackwardOps() = 0; /** * @brief Create a network. @@ -126,7 +108,7 @@ class PlainNet : public Net { * Infer all the operators' input and output varialbes' shapes, will be called * before every mini-batch */ - virtual Error InferShape(Scope *scope) override; + virtual bool InferShape(Scope *scope) override; /** * @brief Run the network. @@ -135,8 +117,8 @@ class PlainNet : public Net { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - virtual Error Run(Scope *scope = nullptr, OpContext *context = nullptr, - OpIndex begin = -1, OpIndex end = -1) const override; + virtual bool Run(Scope *scope = nullptr, OpContext *context = nullptr, + OpIndex begin = -1, OpIndex end = -1) const override; /** * @brief Add an operator to this network. @@ -146,12 +128,12 @@ class PlainNet : public Net { /** * @brief Add all optimizer operators related into the network. */ - virtual Error AddOptimizerOps(const OptAttrs &attrs) override; + virtual bool AddOptimizerOps(const OptAttrs &attrs) override; /** * @brief Add all backward operators related into the network. */ - virtual Error AddBackwardOps() override; + virtual bool AddBackwardOps() override; protected: /** @@ -159,7 +141,7 @@ class PlainNet : public Net { * * Create operators accordding to `def`, will be called by the constructor. */ - Error BuildNet(const NetDesc &def); + bool BuildNet(const NetDesc &def); /** * @brief Add an operator into this network. diff --git a/paddle/framework/net_proto.proto b/paddle/framework/net_proto.proto new file mode 100644 index 0000000000000000000000000000000000000000..e9aed8f349b80925fc492707393b03835e73fcdd --- /dev/null +++ b/paddle/framework/net_proto.proto @@ -0,0 +1,16 @@ +syntax="proto2"; +package paddle.framework; + +import "op_proto.proto" + +message NetDesc { + // network identification + optional string name = 1; + // operator contains in network + repeated OpProto operators = 2; + // network type to run with. e.g "plainNet", "DAG" + optional string type = 3; + // num worker always + optional int32 num_workers = 4; +} +