net.h 1.1 KB
Newer Older
李寅 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#ifndef MACE_CORE_NET_H_
#define MACE_CORE_NET_H_

#include "mace/core/common.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/operator.h"
#include "mace/core/workspace.h"

namespace mace {

/// Abstract base class for a network built from a NetDef.
///
/// Subclasses must implement Run(). The name returned by Name() is held in
/// the protected member name_; nothing in this header assigns it
/// (presumably the out-of-line constructor does — TODO confirm).
class NetBase {
 public:
  /// @param net_def  Network definition, shared (read-only) with the caller.
  /// @param ws       Non-owning workspace pointer; lifetime requirements are
  ///                 not visible in this header (presumably it must outlive
  ///                 the net — verify against the constructor definition).
  /// @param type     Device type the net is created for.
  NetBase(const std::shared_ptr<const NetDef> &net_def,
          Workspace *ws,
          DeviceType type);

  virtual ~NetBase() noexcept = default;

  /// Executes the network. The meaning of the returned flag is defined by
  /// subclasses (presumably true on success — TODO confirm).
  virtual bool Run() = 0;

  /// @return A reference to this net's stored name (name_).
  const string &Name() const { return name_; }

 protected:
  // Name reported by Name().
  string name_;

  // NOTE(review): macro comes from the included project headers; by its name
  // it presumably suppresses copy construction/assignment — verify.
  DISABLE_COPY_AND_ASSIGN(NetBase);
};

class SimpleNet : public NetBase {
 public:
  SimpleNet(const std::shared_ptr<const NetDef>& net_def, Workspace* ws, DeviceType type);

  virtual bool Run() override;

 protected:
  vector<unique_ptr<OperatorBase> > operators_;

 DISABLE_COPY_AND_ASSIGN(SimpleNet);
};

/// Creates a network from a NetDef passed by const reference.
/// @param net_def  Network definition.
/// @param ws       Non-owning workspace pointer.
/// @param type     Device type to create the net for.
/// @return Owning pointer to the created net. Whether it can be null on
///         failure is not visible in this header — TODO confirm.
unique_ptr<NetBase> CreateNet(const NetDef &net_def,
                              Workspace *ws,
                              DeviceType type);

/// Creates a network from a NetDef shared via shared_ptr.
/// @param net_def  Shared, read-only network definition.
/// @param ws       Non-owning workspace pointer.
/// @param type     Device type to create the net for.
/// @return Owning pointer to the created net (failure behaviour not visible
///         here — TODO confirm).
unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def,
                              Workspace *ws,
                              DeviceType type);

} //  namespace mace

#endif // MACE_CORE_NET_H_