diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 0ce92968200dc459c088a02bbbe90da0d880de05..2d9e099dc0c2b7534aef84903fda8d1984b453d9 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -5,7 +5,7 @@ namespace framework { PlainNet::PlainNet(const NetDesc& def) {} -virtual Error PlainNet::InferShape() { +Error PlainNet::InferShape(Scope* scope) { for (auto& op : ops_) { // wrong shape auto err = op.InferShape(); @@ -15,9 +15,11 @@ virtual Error PlainNet::InferShape() { return Error(); } -virtual Error PlainNet::Run(Scope* scope = nullptr, - OpContext* context = nullptr, OpIndex begin = -1, - OpIndex end = -1) const {} +Error PlainNet::Run(Scope* scope, OpContext* context, OpIndex begin, + OpIndex end) const { + // TODO Add implementation here. + return Error(); +} } // namespace framework } // namespace paddle diff --git a/paddle/framework/net.h b/paddle/framework/net.h index b3064e4f90b773b7c0bcaaec66d1d47103e87793..76e0ed9330716949e31f207b86bd5a188c3f98b8 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -17,6 +17,7 @@ #include "paddle/framework/net_proto.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/scope.h" +#include "paddle/utils/Error.h" namespace paddle { namespace framework { @@ -29,11 +30,16 @@ typedef int OpIndex; * keep updating if the concepts related are implemented. */ +struct OpDesc; +struct OpDef; +struct OpContext; +struct OpAttrs {}; + class Operator { public: Operator(const OpDesc &def) {} - bool InferShape() {} - bool Run() {} + Error InferShape() { return Error(); } + Error Run() { return Error(); } }; /** @@ -55,7 +61,7 @@ class Net { /** * @brief Infer shapes of all inputs and outputs of operators. */ - virtual bool InferShape(Scope *scope) override; + virtual Error InferShape(Scope *scope) = 0; /** * @brief Run the network. * @@ -64,28 +70,30 @@ class Net { * environment for ops. `begin` and `end` specify the scope of `ops_` to run, * If no positive indexes are provided, all operators in `ops_` will run. */ - virtual bool Run(Scope *scope, OpContext *context, OpIndex begin = -1, - OpIndex end = -1) const = 0; + virtual Error Run(Scope *scope, OpContext *context, OpIndex begin = -1, + OpIndex end = -1) const = 0; /** * @brief Add an Operator according to `def`. */ - virtual OpIndex AddOp(const proto::OpDef &def) = 0; + virtual OpIndex AddOp(const OpDef &def) = 0; /** * @brief Add optimizer operators acctording to `attrs`. */ - virtual bool AddOptimizerOps(const OptAttrs &attrs) = 0; + virtual Error AddOptimizerOps(const OpAttrs &attrs) = 0; /** * @brief Add backward operators. */ - virtual bool AddBackwardOps() = 0; + virtual Error AddBackwardOps() = 0; /** * @brief Create a network. */ static std::unique_ptr Create(const NetDesc &def = NetDesc()); + + virtual ~Net() = 0; }; /** @@ -108,7 +116,7 @@ class PlainNet : public Net { * Infer all the operators' input and output varialbes' shapes, will be called * before every mini-batch */ - virtual bool InferShape(Scope *scope) override; + virtual Error InferShape(Scope *scope) override; /** * @brief Run the network. @@ -117,23 +125,23 @@ class PlainNet : public Net { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - virtual bool Run(Scope *scope = nullptr, OpContext *context = nullptr, - OpIndex begin = -1, OpIndex end = -1) const override; + virtual Error Run(Scope *scope = nullptr, OpContext *context = nullptr, + OpIndex begin = -1, OpIndex end = -1) const override; /** * @brief Add an operator to this network. */ - virtual OpIndex AddOp(const proto::OpDef &def) override; + virtual OpIndex AddOp(const OpDef &def) override; /** * @brief Add all optimizer operators related into the network. */ - virtual bool AddOptimizerOps(const OptAttrs &attrs) override; + virtual Error AddOptimizerOps(const OpAttrs &attrs) override; /** * @brief Add all backward operators related into the network. */ - virtual bool AddBackwardOps() override; + virtual Error AddBackwardOps() override; protected: /** @@ -141,7 +149,7 @@ class PlainNet : public Net { * * Create operators accordding to `def`, will be called by the constructor. */ - bool BuildNet(const NetDesc &def); + Error BuildNet(const NetDesc &def); /** * @brief Add an operator into this network. @@ -151,9 +159,9 @@ class PlainNet : public Net { * `outputs` are keys of mutable output variables. An `OpIndex` will be * returned to indicate the offset of the new operator in `ops_`. */ - OpIndex AddOp(const std::string &type, const std::vector &inputs, - const std::vector &outputs, - const OprAttr &attrs = OprAttr()); + OpIndex AddOp(const std::string &type, const std::vector &inputs, + const std::vector &outputs, + const OpAttrs &attrs = OpAttrs()); private: // the operators owned by `Network`. diff --git a/paddle/framework/net_proto.proto b/paddle/framework/net_proto.proto index e9aed8f349b80925fc492707393b03835e73fcdd..2d042457e33065514f987f1157cb96de5f6cd5de 100644 --- a/paddle/framework/net_proto.proto +++ b/paddle/framework/net_proto.proto @@ -1,7 +1,7 @@ syntax="proto2"; package paddle.framework; -import "op_proto.proto" +import "op_proto.proto"; message NetDesc { // network identification @@ -13,4 +13,3 @@ message NetDesc { // num worker always optional int32 num_workers = 4; } -