net.cc 4.9 KB
Newer Older
L
Liangliang He 已提交
1
// Copyright 2018 Xiaomi, Inc.  All rights reserved.
李寅 已提交
2
//
L
Liangliang He 已提交
3 4 5
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
李寅 已提交
6
//
L
Liangliang He 已提交
7 8 9 10 11 12 13
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
李寅 已提交
14

李寅 已提交
15 16
#include <utility>

李寅 已提交
17
#include "mace/core/net.h"
L
Liangliang He 已提交
18
#include "mace/utils/memory_logging.h"
19 20
#include "mace/utils/timer.h"
#include "mace/utils/utils.h"
李寅 已提交
21 22 23

namespace mace {

24 25
// Base-class constructor shared by net implementations. It retains only the
// operator registry and the net's name (taken from |net_def|); |ws| and |type|
// are not used by this constructor.
//
// |op_registry| is taken by value. The top-level const has been dropped in
// this definition (top-level cv-qualifiers on parameters are not part of the
// function's signature, so this still matches the declaration) so that the
// pointer can be moved into the member. That avoids a second atomic
// reference-count increment/decrement pair.
NetBase::NetBase(std::shared_ptr<const OperatorRegistry> op_registry,
                 const std::shared_ptr<const NetDef> net_def,
                 Workspace *ws,
                 DeviceType type)
    : op_registry_(std::move(op_registry)), name_(net_def->name()) {}
李寅 已提交
29

30
// Builds a net that runs its operators one after another, in the order they
// appear in |net_def|. Each operator definition is handed (as a copy) to
// |op_registry| together with |ws|, |type| and |mode|. Any definition for which
// the registry returns a null operator is skipped without error.
// NOTE(review): presumably null means "op not applicable to |mode|"; confirm
// against OperatorRegistry::CreateOperator.
SerialNet::SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry,
                     const std::shared_ptr<const NetDef> net_def,
                     Workspace *ws,
                     DeviceType type,
                     const NetMode mode)
    : NetBase(op_registry, net_def, ws, type), device_type_(type) {
  MACE_LATENCY_LOGGER(1, "Constructing SerialNet ", net_def->name());
  // |net_def| points to a const NetDef, so the op count cannot change while
  // iterating; read it once.
  const int op_count = net_def->op_size();
  for (int i = 0; i < op_count; ++i) {
    const auto &def = net_def->op(i);
    VLOG(3) << "Creating operator " << def.name() << "(" << def.type() << ")";
    OperatorDef def_copy(def);
    std::unique_ptr<OperatorBase> created(
        op_registry->CreateOperator(def_copy, ws, type, mode));
    if (created == nullptr) {
      continue;  // Registry produced no operator for this definition.
    }
    operators_.emplace_back(std::move(created));
  }
}
49

50
// Runs every operator in order and stops at the first one whose Run() returns
// false (logged as an error; this function then returns false).
//
// If |run_metadata| is non-null, one OperatorStats entry is appended per
// executed operator. The entry holds the op's name, type, declared output
// shapes, timing (CallStats) and, for Conv2D / FusedConv2D / DepthwiseConv2d /
// Pooling, its strides, padding type/values, dilations and kernel shape. Stats
// are recorded before the op's return value is checked, so a failing op still
// gets an entry.
//
// On DeviceType::OPENCL, an op is run with a StatsFuture and the future is
// waited on when stats are requested or when the op is the last one in the
// net. NOTE(review): presumably this is because OpenCL ops complete
// asynchronously, so waiting gives accurate timings and ready outputs; confirm.
bool SerialNet::Run(RunMetadata *run_metadata) {
  MACE_MEMORY_LOGGING_GUARD();
  MACE_LATENCY_LOGGER(1, "Running net");
  for (auto iter = operators_.begin(); iter != operators_.end(); ++iter) {
    auto &op = *iter;
    MACE_LATENCY_LOGGER(2, "Running operator ", op->debug_def().name(), "(",
                        op->debug_def().type(), ")");
    const bool is_last_op = std::next(iter) == operators_.end();
    const bool future_wait = device_type_ == DeviceType::OPENCL &&
                             (run_metadata != nullptr || is_last_op);

    bool ret = false;
    CallStats call_stats{};
    if (future_wait) {
      StatsFuture future;
      ret = op->Run(&future);
      // Only ask the future to fill in timing when the caller wants stats.
      future.wait_fn(run_metadata != nullptr ? &call_stats : nullptr);
    } else if (run_metadata != nullptr) {
      // No future available: time the (synchronous) call on the host.
      call_stats.start_micros = NowMicros();
      ret = op->Run(nullptr);
      call_stats.end_micros = NowMicros();
    } else {
      ret = op->Run(nullptr);
    }

    if (run_metadata != nullptr) {
      std::vector<int> strides;
      int padding_type = -1;
      std::vector<int> paddings;
      std::vector<int> dilations;
      std::vector<index_t> kernels;
      // Kept as a copy: debug_def()'s return category is not visible here, so
      // binding a reference could dangle.
      const std::string type = op->debug_def().type();

      const bool is_pooling = type == "Pooling";
      if (type == "Conv2D" || type == "FusedConv2D" ||
          type == "DepthwiseConv2d" || is_pooling) {
        strides = op->GetRepeatedArgument<int>("strides");
        padding_type = op->GetSingleArgument<int>("padding", -1);
        paddings = op->GetRepeatedArgument<int>("padding_values");
        dilations = op->GetRepeatedArgument<int>("dilations");
        if (is_pooling) {
          kernels = op->GetRepeatedArgument<index_t>("kernels");
        } else {
          // Convolutions: the kernel shape is the shape of input #1 (filter).
          kernels = op->Input(1)->shape();
        }
      }

      OperatorStats op_stats = {op->debug_def().name(), op->debug_def().type(),
                                op->debug_def().output_shape(),
                                {strides, padding_type, paddings, dilations,
                                 kernels}, call_stats};
      // op_stats is a local that is not used afterwards; move it in.
      run_metadata->op_stats.emplace_back(std::move(op_stats));
    }

    if (!ret) {
      LOG(ERROR) << "Operator failed: " << op->debug_def().name();
      return false;
    }

    VLOG(3) << "Operator " << op->debug_def().name()
            << " has shape: " << MakeString(op->Output(0)->shape());
  }

  return true;
}

121 122 123 124 125 126
// Convenience overload taking the network definition by const reference.
// It copies |net_def| into a shared_ptr-owned NetDef and forwards to the
// shared_ptr overload of CreateNet.
std::unique_ptr<NetBase> CreateNet(
    const std::shared_ptr<const OperatorRegistry> op_registry,
    const NetDef &net_def,
    Workspace *ws,
    DeviceType type,
    const NetMode mode) {
  // make_shared does a single allocation for the object and control block and
  // avoids a naked new. The temporary is converted to
  // shared_ptr<const NetDef> by move, so no extra refcount traffic occurs.
  return CreateNet(op_registry, std::make_shared<NetDef>(net_def), ws, type,
                   mode);
}

131 132 133 134 135 136
// Factory for the network implementation: always builds a SerialNet from the
// given registry, definition, workspace, device type and mode, and hands
// ownership to the caller through a NetBase pointer.
std::unique_ptr<NetBase> CreateNet(
    const std::shared_ptr<const OperatorRegistry> op_registry,
    const std::shared_ptr<const NetDef> net_def,
    Workspace *ws,
    DeviceType type,
    const NetMode mode) {
  return std::unique_ptr<NetBase>(
      new SerialNet(op_registry, net_def, ws, type, mode));
}

142
}  // namespace mace