server.cc 6.7 KB
Newer Older
D
dinghao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "core/server.h"
#include <atomic>
#include <csignal>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
D
dinghao 已提交
25
#include <future>
26
#include <chrono>
D
dinghao 已提交
27

28
#include "include/infer_log.h"
D
dinghao 已提交
29 30 31 32
#include "serving/ms_service.grpc.pb.h"
#include "core/util/option_parser.h"
#include "core/version_control/version_controller.h"
#include "core/util/file_system_operation.h"
33
#include "core/serving_tensor.h"
D
dinghao 已提交
34 35 36 37 38 39 40

using ms_serving::MSService;
using ms_serving::PredictReply;
using ms_serving::PredictRequest;

namespace mindspore {
namespace serving {
41 42 43 44 45 46 47 48

// Records a steady_clock start point in a local named time_start_<name>.
// The expansion already ends with ';', so call sites omit the trailing semicolon.
#define MSI_TIME_STAMP_START(name) const auto time_start_##name = std::chrono::steady_clock::now();
// Logs the milliseconds elapsed since the matching MSI_TIME_STAMP_START(name).
// Expands to a braced scope so the temporary duration does not leak into the caller.
#define MSI_TIME_STAMP_END(name)                                                              \
  {                                                                                           \
    const std::chrono::duration<double, std::milli> elapsed_ms =                              \
      std::chrono::steady_clock::now() - time_start_##name;                                   \
    MSI_LOG_INFO << #name " Time Cost " << elapsed_ms.count() << "ms ---------------------"; \
  }
D
dinghao 已提交
49 50

/// Creates the inference session for `device` / `device_id`.
///
/// On success the device type is remembered in device_type_; on failure the error is
/// logged and device_type_ is left untouched.
/// @return SUCCESS when InferSession::CreateSession returned a non-null session, else FAILED.
Status Session::CreatDeviceSession(const std::string &device, uint32_t device_id) {
  session_ = inference::InferSession::CreateSession(device, device_id);
  if (session_ != nullptr) {
    device_type_ = device;
    return SUCCESS;
  }
  MSI_LOG(ERROR) << "Creat Session Failed";
  return FAILED;
}

/// Returns the process-wide Session object.
/// The function-local static is constructed on the first call (initialization of such
/// statics is thread-safe since C++11) and lives until program exit.
Session &Session::Instance() {
  static Session singleton;
  return singleton;
}

65 66 67
/// Runs one inference request against the currently loaded model.
///
/// @param request incoming gRPC request, wrapped in a ServingRequest for the session.
/// @param reply   gRPC reply, wrapped in a ServingReply and handed to ExecuteModel
///                (presumably filled with the model outputs — confirm in ServingReply).
/// @return SUCCESS when ExecuteModel succeeded; FAILED when no model is loaded, the
///         session is missing, or execution failed.
Status Session::Predict(const PredictRequest &request, PredictReply &reply) {
  // Take the lock before inspecting model_loaded_/session_: Warmup() writes model_loaded_
  // while holding mutex_, so reading these fields unlocked would be a data race.
  std::lock_guard<std::mutex> lock(mutex_);
  if (!model_loaded_) {
    MSI_LOG(ERROR) << "the model has not loaded";
    return FAILED;
  }
  if (session_ == nullptr) {
    MSI_LOG(ERROR) << "the inference session has not be initialized";
    return FAILED;
  }
  MSI_LOG(INFO) << "run Predict";

  ServingRequest serving_request(request);
  ServingReply serving_reply(reply);

  auto ret = session_->ExecuteModel(graph_id_, serving_request, serving_reply);
  MSI_LOG(INFO) << "run Predict finished";
  if (!ret) {
    MSI_LOG(ERROR) << "execute model return failed";
    return FAILED;
  }
  return SUCCESS;
}

/// Loads `model` into the inference session, replacing whatever graph was loaded before.
///
/// The model file is `<model path>/<model name>`. model_loaded_ is cleared before loading
/// and set again only when LoadModelFromFile succeeds.
/// @param model model descriptor; must be non-null.
/// @return SUCCESS on a successful load; FAILED when the session was not created, `model`
///         is null, or loading failed.
Status Session::Warmup(const MindSporeModelPtr model) {
  // Hold the lock for the whole check-and-load so Predict() never sees a half-updated state.
  std::lock_guard<std::mutex> lock(mutex_);
  if (session_ == nullptr) {
    MSI_LOG(ERROR) << "The CreatDeviceSession should be called, before warmup";
    return FAILED;
  }
  // The original dereferenced `model` unconditionally; fail cleanly instead of crashing.
  if (model == nullptr) {
    MSI_LOG(ERROR) << "The model to warmup is null";
    return FAILED;
  }
  std::string file_name = model->GetModelPath() + '/' + model->GetModelName();
  model_loaded_ = false;
  MSI_TIME_STAMP_START(LoadModelFromFile)
  auto ret = session_->LoadModelFromFile(file_name, graph_id_);
  MSI_TIME_STAMP_END(LoadModelFromFile)
  if (!ret) {
    MSI_LOG(ERROR) << "Load graph model failed, file name is " << file_name;
    return FAILED;
  }
  model_loaded_ = true;
  MSI_LOG(INFO) << "Session Warmup finished";
  return SUCCESS;
}

/// Unloads the model, finalizes the inference environment and drops the session.
///
/// Takes mutex_ so it cannot tear the session down underneath a Predict() or Warmup()
/// that is using it. Afterwards model_loaded_ is false, so later Predict() calls fail
/// fast with "model has not loaded".
/// @return always SUCCESS.
Status Session::Clear() {
  std::lock_guard<std::mutex> lock(mutex_);
  if (session_ != nullptr) {
    session_->UnloadModel(graph_id_);
    session_->FinalizeEnv();
    session_ = nullptr;
  }
  model_loaded_ = false;
  return SUCCESS;
}

namespace {
D
dinghao 已提交
119 120 121
static const uint32_t uint32max = 0x7FFFFFFF;
std::promise<void> exit_requested;

D
dinghao 已提交
122 123
void ClearEnv() {
  Session::Instance().Clear();
124
  // inference::ExitInference();
D
dinghao 已提交
125
}
D
dinghao 已提交
126
void HandleSignal(int sig) { exit_requested.set_value(); }
D
dinghao 已提交
127 128 129 130 131 132 133

}  // namespace

// Service implementation: gRPC handlers for ms_serving::MSService.
class MSServiceImpl final : public MSService::Service {
  // Runs inference for one request through the Session singleton and logs the elapsed time.
  // Calls are serialized by this service's mutex_.
  grpc::Status Predict(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
    std::lock_guard<std::mutex> lock(mutex_);
    MSI_TIME_STAMP_START(Predict)
    auto res = Session::Instance().Predict(*request, *reply);
    MSI_TIME_STAMP_END(Predict)
    if (res != SUCCESS) {
      // Keep the CANCELLED code existing clients already observe. Attach a message so the
      // failure is not indistinguishable from a client-side cancellation.
      return grpc::Status(grpc::StatusCode::CANCELLED, "Predict failed, see server log for details");
    }
    MSI_LOG(INFO) << "Finish call service Eval";
    return grpc::Status::OK;
  }

  // Logs the call and returns OK. The request and reply are not used.
  grpc::Status Test(grpc::ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
    MSI_LOG(INFO) << "TestService call";
    return grpc::Status::OK;
  }

  std::mutex mutex_;
};

/// Creates the device session, loads the model, and runs the gRPC server until SIGINT or
/// SIGTERM arrives.
///
/// Shutdown order matters: the gRPC server is shut down (Shutdown() with no deadline waits
/// for in-flight RPCs) and its thread joined BEFORE the session is cleared. The original
/// cleared first, which could unload the model underneath a running Predict.
/// @return SUCCESS after a signal-triggered shutdown, otherwise the status of the failing step.
Status Server::BuildAndStart() {
  // Route exit signals to HandleSignal so the wait below returns and we shut down cleanly.
  signal(SIGINT, HandleSignal);
  signal(SIGTERM, HandleSignal);

  auto option_args = Options::Instance().GetArgs();
  if (option_args == nullptr) {
    MSI_LOG(ERROR) << "get serving options failed";
    return FAILED;
  }
  std::string server_address = "0.0.0.0:" + std::to_string(option_args->grpc_port);
  std::string model_path = option_args->model_path;
  std::string model_name = option_args->model_name;
  std::string device_type = option_args->device_type;
  auto device_id = option_args->device_id;

  Status res = Session::Instance().CreatDeviceSession(device_type, device_id);
  if (res != SUCCESS) {
    MSI_LOG(ERROR) << "creat session failed";
    ClearEnv();
    return res;
  }
  VersionController version_controller(option_args->poll_model_wait_seconds, model_path, model_name);
  res = version_controller.Run();
  if (res != SUCCESS) {
    MSI_LOG(ERROR) << "load model failed";
    ClearEnv();
    return res;
  }

  MSServiceImpl ms_service;
  grpc::EnableDefaultHealthCheckService(true);
  grpc::reflection::InitProtoReflectionServerBuilderPlugin();
  // Disallow SO_REUSEPORT so a second server cannot silently bind the same port.
  auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
  grpc::ServerBuilder serverBuilder;
  serverBuilder.SetOption(std::move(option));
  serverBuilder.SetMaxMessageSize(uint32max);
  serverBuilder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
  serverBuilder.RegisterService(&ms_service);
  std::unique_ptr<grpc::Server> server(serverBuilder.BuildAndStart());
  if (server == nullptr) {
    MSI_LOG(ERROR) << "The serving server create failed";
    ClearEnv();
    return FAILED;
  }

  // Obtain the future before spawning the thread, so a throwing get_future() cannot leave
  // a joinable std::thread behind (which would call std::terminate).
  auto exit_future = exit_requested.get_future();
  std::thread serving_thread([&server]() { server->Wait(); });
  MSI_LOG(INFO) << "MS Serving listening on " << server_address;
  exit_future.wait();

  // Stop accepting RPCs and drain in-flight ones first, then release the model/session.
  server->Shutdown();
  serving_thread.join();
  ClearEnv();
  return SUCCESS;
}
}  // namespace serving
}  // namespace mindspore