net.cc 2.7 KB
Newer Older
李寅 已提交
1 2 3 4 5
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/core/net.h"
6
#include "mace/utils/utils.h"
7
#include "mace/core/runtime/opencl/opencl_runtime.h"
李寅 已提交
8 9 10

namespace mace {

L
Liangliang He 已提交
11 12
NetBase::NetBase(const std::shared_ptr<const NetDef> &net_def,
                 Workspace *ws,
李寅 已提交
13
                 DeviceType type)
L
Liangliang He 已提交
14
    : name_(net_def->name()) {}
李寅 已提交
15

L
Liangliang He 已提交
16 17
SimpleNet::SimpleNet(const std::shared_ptr<const NetDef> &net_def,
                     Workspace *ws,
18
                     DeviceType type)
19
    : NetBase(net_def, ws, type), device_type_(type){
李寅 已提交
20 21
  VLOG(1) << "Constructing SimpleNet " << net_def->name();
  for (int idx = 0; idx < net_def->op_size(); ++idx) {
L
Liangliang He 已提交
22
    const auto &operator_def = net_def->op(idx);
李寅 已提交
23 24
    VLOG(1) << "Creating operator " << operator_def.name() << ":"
            << operator_def.type();
L
Liangliang He 已提交
25
    std::unique_ptr<OperatorBase> op{nullptr};
李寅 已提交
26 27
    OperatorDef temp_def(operator_def);
    op = CreateOperator(temp_def, ws, type);
L
liuqi 已提交
28
    if (op) {
L
liuqi 已提交
29 30
      operators_.emplace_back(std::move(op));
    }
李寅 已提交
31 32
  }
}
L
Liangliang He 已提交
33
bool SimpleNet::Run(RunMetadata *run_metadata) {
李寅 已提交
34
  VLOG(1) << "Running net " << name_;
L
Liangliang He 已提交
35
  for (auto &op : operators_) {
李寅 已提交
36 37
    VLOG(1) << "Running operator " << op->debug_def().name() << "("
            << op->debug_def().type() << ").";
L
Liangliang He 已提交
38
    OperatorStats *op_stats = nullptr;
39 40 41 42 43
    if (run_metadata) {
      op_stats = run_metadata->add_op_stats();
      op_stats->set_operator_name(op->debug_def().name());
      op_stats->set_type(op->debug_def().type());
      op_stats->set_all_start_micros(NowInMicroSec());
L
Liangliang He 已提交
44 45
      op_stats->set_op_start_rel_micros(NowInMicroSec() -
                                        op_stats->all_start_micros());
46
    }
李寅 已提交
47 48 49 50
    if (!op->Run()) {
      LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
      return false;
    }
51
    if (op_stats) {
L
Liangliang He 已提交
52 53 54 55
      op_stats->set_op_end_rel_micros(NowInMicroSec() -
                                      op_stats->all_start_micros());
      op_stats->set_all_end_rel_micros(NowInMicroSec() -
                                       op_stats->all_start_micros());
56
    }
李寅 已提交
57 58
    VLOG(1) << "Op " << op->debug_def().name()
            << " has shape: " << internal::MakeString(op->Output(0)->shape());
李寅 已提交
59
  }
60
  return true;
李寅 已提交
61 62
}

L
Liangliang He 已提交
63 64
unique_ptr<NetBase> CreateNet(const NetDef &net_def,
                              Workspace *ws,
李寅 已提交
65 66 67 68 69
                              DeviceType type) {
  std::shared_ptr<NetDef> tmp_net_def(new NetDef(net_def));
  return CreateNet(tmp_net_def, ws, type);
}

L
Liangliang He 已提交
70 71
unique_ptr<NetBase> CreateNet(const std::shared_ptr<const NetDef> &net_def,
                              Workspace *ws,
72
                              DeviceType type) {
李寅 已提交
73 74 75 76
  unique_ptr<NetBase> net(new SimpleNet(net_def, ws, type));
  return net;
}

L
Liangliang He 已提交
77
}  //  namespace mace