// tiny_runtime.cc
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <dlfcn.h>
#include <omp.h>

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <vector>

#include "cinn_runtime.h"

extern "C" {
// Default task count used by cinn_backend_parallel_launch() when it is called
// with num_task == 0; changed through set_maxconcurrency().
// std::thread::hardware_concurrency() returns 0 when the core count is not
// computable, so fall back to a single worker instead of zero.
int max_num_workers =
    static_cast<int>(std::max(1u, std::thread::hardware_concurrency()));

// State of a program loaded by load_program().
// TODO: move to a standalone file.
struct param_context_t {
  int major_v = 0;  // major version read from the image header
  int minor_v = 0;  // minor version read from the image header
  // Backing storage for the whole program image. The image starts at an
  // aligned offset inside this vector; buffers and instruction arguments point
  // into it, so it must not be resized after loading.
  std::vector<uint8_t> buf;
  // Host allocations for buffers that carry no persistent data in the image.
  std::vector<std::vector<uint8_t>> temporary;
  // Buffer name -> pod value wrapping the matching cinn_buffer_t.
  std::map<std::string, cinn_pod_value_t> name2podvalue;
  // Per-instruction kernel symbol name, argument count and argument array.
  // These three vectors are parallel: index i describes instruction i.
  std::vector<std::string> instructions;
  std::vector<int> inst_argc;
  std::vector<cinn_pod_value_t *> inst_argv;
};

void *load_program(const char *paramfile) {
  FILE *f = fopen(paramfile, "r");
  fseek(f, 0, SEEK_END);
  int fsize = ftell(f);
  rewind(f);
  if (fsize < 32) {
    fclose(f);
    return nullptr;
  }

  std::unique_ptr<param_context_t> ctx(new param_context_t{});
  int alignment = std::max(alignof(cinn_pod_value_t), alignof(cinn_buffer_t));
  ctx->buf.resize(fsize + alignment);
  uint8_t *buf = ctx->buf.data();
  if ((uintptr_t)buf % alignment) {
    buf = buf + alignment - ((uintptr_t)buf % alignment);
  }
  fread(buf, 1, fsize, f);
  fclose(f);

  if (std::string(buf, buf + 4) != "CINN") {
    // TODO LOG fatal
    return nullptr;
  }
  // TODO check param file version
  ctx->major_v = *(int *)(buf + 4);
  ctx->minor_v = *(int *)(buf + 8);

68 69
  int *namelist_pos = (int *)(buf + 16);
  int *podvalue_pos = (int *)(buf + *namelist_pos);
70
  int *persistent_pos = (int *)(buf + *podvalue_pos);
71
  int *inst_pos = (int *)(buf + *persistent_pos);
72 73 74 75 76 77 78 79
  if (fsize < *inst_pos) {
    return nullptr;
  }

  int namelen = namelist_pos[1];
  std::vector<const char *> namev(namelen);
  std::map<std::string, int> name2index;
  for (int i = 0; i < namelen; i++) {
80 81
    int offset = (namelist_pos + 2)[i];
    namev[i] = (char *)(buf + offset);
82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
    name2index[namev[i]] = i;
  }

  cinn_buffer_t *cb = (cinn_buffer_t *)(buf + podvalue_pos[1]);
  for (int i = 0; i < namelen; i++) {
    // currently only CPU device is supported, so just use malloc
    if (cb[i].memory) {
      cb[i].memory = buf + (uintptr_t)cb[i].memory;
    } else {
      int alignment = cb[i].align;
      if (alignment == 0) {
        alignment = 4;
      }
      ctx->temporary.emplace_back(alignment + cb[i].memory_size);
      uint8_t *tbuf = ctx->temporary.back().data();
      if ((uintptr_t)tbuf % alignment) {
        tbuf = tbuf + alignment - ((uintptr_t)tbuf % alignment);
      }
      cb[i].memory = tbuf;
    }
    ctx->name2podvalue[namev[i]] = cinn_pod_value_t(cb + i);
  }
  for (int i = 0; i < inst_pos[1]; i++) {
    const char *inst = (const char *)(buf + inst_pos[2 + i * 3 + 0]);
    ctx->instructions.push_back(inst);
    int instargc = inst_pos[2 + i * 3 + 1];
    ctx->inst_argc.push_back(instargc);
109 110
    cinn_pod_value_t *argv =
        (cinn_pod_value_t *)(buf + inst_pos[2 + i * 3 + 2]);
111 112 113 114 115 116 117 118 119 120 121 122
    for (int i = 0; i < instargc; i++) {
      int idx = (uintptr_t)((cinn_buffer_t *)argv[i]);
      cinn_value_t tmp_v;
      tmp_v.v_handle = &cb[idx];
      argv[i].set_value(tmp_v);
    }
    ctx->inst_argv.push_back(argv);
  }
  return ctx.release();
}

int set_maxconcurrency(int c) {
123
  int old_c = max_num_workers;
124 125 126 127 128 129 130 131 132
  max_num_workers = c;
  return old_c;
}

typedef void (*func_t)(cinn_pod_value_t *, int);
void run_program(void *ctx) {
  param_context_t *pc = (param_context_t *)ctx;
  for (int i = 0; i < pc->instructions.size(); i++) {
    const char *sym = pc->instructions[i].c_str();
133 134
    void *p = dlsym(RTLD_DEFAULT, sym);
    func_t f = (func_t)p;
135 136 137 138 139 140 141 142 143 144 145 146 147
    f(pc->inst_argv[i], pc->inst_argc[i]);
  }
}

// Looks up the pod value registered under buffer name `tname` in a context
// returned by load_program().
// Returns a pointer to the entry stored in the context's name map, or nullptr
// when the name is unknown or either argument is null.
cinn_pod_value_t *get_pod_value(void *ctx, const char *tname) {
  if (ctx == nullptr || tname == nullptr) return nullptr;
  param_context_t *pc = static_cast<param_context_t *>(ctx);
  // A single find() avoids the second lookup that operator[] would perform.
  auto it = pc->name2podvalue.find(tname);
  return it == pc->name2podvalue.end() ? nullptr : &it->second;
}

typedef int (*FCINNParallelLambda)(int task_id, int num_task, void *datas);
148 149 150
int cinn_backend_parallel_launch(FCINNParallelLambda flambda,
                                 void *datas,
                                 int num_task) {
151 152 153 154 155 156 157 158 159 160 161
  int num_workers = max_num_workers;
  if (num_task == 0) num_task = num_workers;
  omp_set_num_threads(num_task);
#pragma omp parallel num_threads(num_task)
  {
    int thread_num = omp_get_thread_num();
    (*flambda)(thread_num, num_task, datas);
  }
  return 0;
}
}