/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/scope.h" #include "paddle/platform/device_context.h" namespace paddle { namespace framework { /** * @brief Network is also a type of Operator * * It will manage the operators it has. * * Network is the container and controller of a set of operators. * A network object knows all Operators belonging to this network. Variables, * which are inputs and outputs of these operators, are created and managed by a * hierarchy of Scope objects. * * This is the base class of network, all the networks should implement the APIs * it defines. */ class Net : public OperatorBase { public: virtual void AddOp(const OperatorPtr& op) = 0; virtual void CompleteAddOp(bool calc) = 0; }; using NetPtr = std::shared_ptr; /** * @brief a basic implementation of Net. * * PlainNet is a very simple Net, it create a list of operators, and run them * sequentially following the order they added. */ class PlainNet : public Net { public: /** * Infer all the operators' input and output variables' shapes, will be called * before every mini-batch */ void InferShape(const ScopePtr& scope) const override { for (auto& op : ops_) { op->InferShape(scope); } } /** * @brief Run the network. * * Run all the operators with the `scope`, if no scope is provided, default * scope will be used instead. If no OpContext is provicded, default context * will be used. */ void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const override { for (auto& op : ops_) { op->Run(scope, dev_ctx); } } /** * @brief Add an operator by ptr */ void AddOp(const OperatorPtr& op) override { PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); ops_.push_back(op); } void CompleteAddOp(bool calculate = true) override; std::string DebugString() const override; std::vector ops_; private: bool add_op_done_{false}; template static bool Contains(T container, KeyType key) { return container.find(key) != container.end(); } }; } // namespace framework } // namespace paddle