From 40573cd56f723ebde6328ccd5dabe4a363c9f3db Mon Sep 17 00:00:00 2001 From: Superjom Date: Mon, 3 Jul 2017 14:41:43 +0800 Subject: [PATCH] add net headers --- paddle/framework/net.cc | 23 +++++ paddle/framework/net.h | 182 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 205 insertions(+) create mode 100644 paddle/framework/net.cc create mode 100644 paddle/framework/net.h diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc new file mode 100644 index 0000000000..0ce9296820 --- /dev/null +++ b/paddle/framework/net.cc @@ -0,0 +1,23 @@ +#include "paddle/framework/net.h" + +namespace paddle { +namespace framework { + +PlainNet::PlainNet(const NetDesc& def) {} + +virtual Error PlainNet::InferShape() { + for (auto& op : ops_) { + // wrong shape + auto err = op.InferShape(); + if (!err) return err; + } + // ok + return Error(); +} + +virtual Error PlainNet::Run(Scope* scope = nullptr, + OpContext* context = nullptr, OpIndex begin = -1, + OpIndex end = -1) const {} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/net.h b/paddle/framework/net.h new file mode 100644 index 0000000000..88bdf0bb68 --- /dev/null +++ b/paddle/framework/net.h @@ -0,0 +1,182 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/scope.h" + +namespace paddle { +namespace framework { + +// operator's index stored in a network. +typedef int OpIndex; +/** + * NOTE following codes are some definitions of unimplemented concepts. + * We write some basic implementation to make Net compilable. These APIs will + * keep updating if the concepts related are implemented. + */ + +// Operator's runtime context. +struct OpContext { + int dev_id; + DevType dev_type{kCPU}; + enum DevType { kCPU, kGPU }; +}; + +// Proto definitions, use `struct`s for simpility. +struct VarDesc { + std::string type; + std::vector dims; +}; +struct OpDesc { + std::string type; + std::vector inputs; + std::vector outputs; +}; +struct struct NetDesc { + std::vector ops; +}; +class Operator { + public: + Operator(const OpDesc &def) {} + Error InferShape() {} + Error Run() {} +}; + +/** + * @brief Network that manage the operators it has. + * + * Network is the container and controller of a set of operators, user can build + * a real network from a NetDesc which is a protobuf message and use + * Network.Run() * to run all the operators in the network. + + * A network object knows all Operators belonging to this network. Variables, + * which are inputs and outputs of these operators, are created and managed by a + * hierarchy of Scope objects. + * + * This is the base class of network, all the networks should implement the apis + * it defines. + */ +class Net { + public: + /** + * @brief Infer shapes of all inputs and outputs of operators. + */ + virtual Error InferShape(Scope *scope) override; + /** + * @brief Run the network. + * + * Run all the operators and return success(true) or not, with all the + * variables are located in `scope`. `context` describes the detail execution + * environment for ops. `begin` and `end` specify the scope of `ops_` to run, + * If no positive indexes are provided, all operators in `ops_` will run. + */ + virtual Error Run(Scope *scope, OpContext *context, OpIndex begin = -1, + OpIndex end = -1) const = 0; + + /** + * @brief Add an Operator according to `def`. + */ + virtual OpIndex AddOp(const proto::OpDef &def) = 0; + + /** + * @brief Add optimizer operators acctording to `attrs`. + */ + virtual Error AddOptimizerOps(const OptAttrs &attrs) = 0; + + /** + * @brief Add backward operators. + */ + virtual Error AddBackwardOps() = 0; + + /** + * @brief Create a network. + */ + static std::unique_ptr Create(const NetDesc &def = NetDesc()); +}; + +/** + * @brief a basic implementation of Net. + * + * PlainNet is a very simple Net, it create a list of operators, and run them + * sequentially following the order they added. + */ +class PlainNet : public Net { + public: + /** + * @brief Initialize a PlainNet. + * + * Initialize from a network describe by `def`. NetDesc is the definition of + * a network. + */ + PlainNet(const NetDesc &def); + + /** + * Infer all the operators' input and output varialbes' shapes, will be called + * before every mini-batch + */ + virtual Error InferShape(Scope *scope) override; + + /** + * @brief Run the network. + * + * Run all the operators with the `scope`, if no scope is provided, default + * scope will be used instead. If no OpContext is provicded, default context + * will be used. + */ + virtual Error Run(Scope *scope = nullptr, OpContext *context = nullptr, + OpIndex begin = -1, OpIndex end = -1) const override; + + /** + * @brief Add an operator to this network. + */ + virtual OpIndex AddOp(const proto::OpDef &def) override; + + /** + * @brief Add all optimizer operators related into the network. + */ + virtual Error AddOptimizerOps(const OptAttrs &attrs) override; + + /** + * @brief Add all backward operators related into the network. + */ + virtual Error AddBackwardOps() override; + + protected: + /** + * @brief Build the network. + * + * Create operators accordding to `def`, will be called by the constructor. + */ + Error BuildNet(const NetDesc &def); + + /** + * @brief Add an operator into this network. + * + * Add a operator which is identified as `type` and has attributes described + * in `attrs`, the `inputs` are the keys of readonly input variables, + * `outputs` are keys of mutable output variables. An `OpIndex` will be + * returned to indicate the offset of the new operator in `ops_`. + */ + OpIndex AddOp(const std::string &type, const std::vector &inputs, + const std::vector &outputs, + const OprAttr &attrs = OprAttr()); + + private: + // the operators owned by `Network`. + std::vector ops_; +}; + +} // namespace framework +} // namespace paddle -- GitLab