提交 2b565627 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3794 modify inference return type

Merge pull request !3794 from dinghao/master
......@@ -41,6 +41,7 @@ cmake-build-debug
*.pb.h
*.pb.cc
*.pb
*_grpc.py
# Object files
*.o
......
......@@ -24,20 +24,20 @@
namespace mindspore {
namespace inference {
enum Status { SUCCESS = 0, FAILED, INVALID_INPUTS };
class MS_API InferSession {
public:
InferSession() = default;
virtual ~InferSession() = default;
virtual bool InitEnv(const std::string &device_type, uint32_t device_id) = 0;
virtual bool FinalizeEnv() = 0;
virtual bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) = 0;
virtual bool UnloadModel(uint32_t model_id) = 0;
virtual Status InitEnv(const std::string &device_type, uint32_t device_id) = 0;
virtual Status FinalizeEnv() = 0;
virtual Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) = 0;
virtual Status UnloadModel(uint32_t model_id) = 0;
// override this method to avoid request/reply data copy
virtual bool ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) = 0;
virtual Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) = 0;
virtual bool ExecuteModel(uint32_t model_id, const std::vector<InferTensor> &inputs,
std::vector<InferTensor> &outputs) {
virtual Status ExecuteModel(uint32_t model_id, const std::vector<InferTensor> &inputs,
std::vector<InferTensor> &outputs) {
VectorInferTensorWrapRequest request(inputs);
VectorInferTensorWrapReply reply(outputs);
return ExecuteModel(model_id, request, reply);
......
......@@ -37,8 +37,8 @@ namespace mindspore::inference {
std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &device, uint32_t device_id) {
try {
auto session = std::make_shared<MSInferSession>();
bool ret = session->InitEnv(device, device_id);
if (!ret) {
Status ret = session->InitEnv(device, device_id);
if (ret != SUCCESS) {
return nullptr;
}
return session;
......@@ -84,21 +84,21 @@ std::shared_ptr<std::vector<char>> MSInferSession::ReadFile(const std::string &f
return buf;
}
bool MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
Status MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
auto graphBuf = ReadFile(file_name);
if (graphBuf == nullptr) {
MS_LOG(ERROR) << "Read model file failed, file name is " << file_name.c_str();
return false;
return FAILED;
}
auto graph = LoadModel(graphBuf->data(), graphBuf->size(), device_type_);
if (graph == nullptr) {
MS_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
return false;
return FAILED;
}
bool ret = CompileGraph(graph, model_id);
if (!ret) {
Status ret = CompileGraph(graph, model_id);
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Compile graph model failed, file name is " << file_name.c_str();
return false;
return FAILED;
}
MS_LOG(INFO) << "Load model from file " << file_name << " success";
......@@ -107,14 +107,14 @@ bool MSInferSession::LoadModelFromFile(const std::string &file_name, uint32_t &m
rtError_t rt_ret = rtCtxGetCurrent(&context_);
if (rt_ret != RT_ERROR_NONE || context_ == nullptr) {
MS_LOG(ERROR) << "the ascend device context is null";
return false;
return FAILED;
}
#endif
return true;
return SUCCESS;
}
bool MSInferSession::UnloadModel(uint32_t model_id) { return true; }
Status MSInferSession::UnloadModel(uint32_t model_id) { return SUCCESS; }
tensor::TensorPtr ServingTensor2MSTensor(const InferTensorBase &out_tensor) {
std::vector<int> shape;
......@@ -170,16 +170,16 @@ void MSTensor2ServingTensor(tensor::TensorPtr ms_tensor, InferTensorBase &out_te
out_tensor.set_data(ms_tensor->data_c(), ms_tensor->Size());
}
bool MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
Status MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) {
#ifdef ENABLE_D
if (context_ == nullptr) {
MS_LOG(ERROR) << "rtCtx is nullptr";
return false;
return FAILED;
}
rtError_t rt_ret = rtCtxSetCurrent(context_);
if (rt_ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "set Ascend rtCtx failed";
return false;
return FAILED;
}
#endif
......@@ -187,47 +187,47 @@ bool MSInferSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
for (size_t i = 0; i < request.size(); i++) {
if (request[i] == nullptr) {
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, input tensor is null, index " << i;
return false;
return FAILED;
}
auto input = ServingTensor2MSTensor(*request[i]);
if (input == nullptr) {
MS_LOG(ERROR) << "Tensor convert failed";
return false;
return FAILED;
}
inputs.push_back(input);
}
if (!CheckModelInputs(model_id, inputs)) {
MS_LOG(ERROR) << "Check Model " << model_id << " Inputs Failed";
return false;
return INVALID_INPUTS;
}
vector<tensor::TensorPtr> outputs = RunGraph(model_id, inputs);
if (outputs.empty()) {
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed";
return false;
return FAILED;
}
reply.clear();
for (const auto &tensor : outputs) {
auto out_tensor = reply.add();
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "Execute Model " << model_id << " Failed, add output tensor failed";
return false;
return FAILED;
}
MSTensor2ServingTensor(tensor, *out_tensor);
}
return true;
return SUCCESS;
}
bool MSInferSession::FinalizeEnv() {
Status MSInferSession::FinalizeEnv() {
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
return false;
return FAILED;
}
if (!ms_context->CloseTsd()) {
MS_LOG(ERROR) << "Inference CloseTsd failed!";
return false;
return FAILED;
}
return true;
return SUCCESS;
}
std::shared_ptr<FuncGraph> MSInferSession::LoadModel(const char *model_buf, size_t size, const std::string &device) {
......@@ -292,16 +292,16 @@ void MSInferSession::RegAllOp() {
return;
}
bool MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
Status MSInferSession::CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id) {
MS_ASSERT(session_impl_ != nullptr);
try {
auto graph_id = session_impl_->CompileGraph(NOT_NULL(funcGraphPtr));
py::gil_scoped_release gil_release;
model_id = graph_id;
return true;
return SUCCESS;
} catch (std::exception &e) {
MS_LOG(ERROR) << "Inference CompileGraph failed";
return false;
return FAILED;
}
}
......@@ -327,31 +327,31 @@ string MSInferSession::AjustTargetName(const std::string &device) {
}
}
bool MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
RegAllOp();
auto ms_context = MsContext::GetInstance();
ms_context->set_execution_mode(kGraphMode);
ms_context->set_device_id(device_id);
auto ajust_device = AjustTargetName(device);
if (ajust_device == "") {
return false;
return FAILED;
}
ms_context->set_device_target(device);
session_impl_ = session::SessionFactory::Get().Create(ajust_device);
if (session_impl_ == nullptr) {
MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available.";
return false;
return FAILED;
}
session_impl_->Init(device_id);
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
return false;
return FAILED;
}
if (!ms_context->OpenTsd()) {
MS_LOG(ERROR) << "Session init OpenTsd failed!";
return false;
return FAILED;
}
return true;
return SUCCESS;
}
bool MSInferSession::CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const {
......
......@@ -38,11 +38,11 @@ class MSInferSession : public InferSession {
MSInferSession();
~MSInferSession();
bool InitEnv(const std::string &device_type, uint32_t device_id) override;
bool FinalizeEnv() override;
bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
bool UnloadModel(uint32_t model_id) override;
bool ExecuteModel(uint32_t model_id, const RequestBase &inputs, ReplyBase &outputs) override;
Status InitEnv(const std::string &device_type, uint32_t device_id) override;
Status FinalizeEnv() override;
Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
Status UnloadModel(uint32_t model_id) override;
Status ExecuteModel(uint32_t model_id, const RequestBase &inputs, ReplyBase &outputs) override;
private:
std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
......@@ -57,7 +57,7 @@ class MSInferSession : public InferSession {
std::shared_ptr<std::vector<char>> ReadFile(const std::string &file);
static void RegAllOp();
string AjustTargetName(const std::string &device);
bool CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id);
Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr, uint32_t &model_id);
bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const;
std::vector<tensor::TensorPtr> RunGraph(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs);
};
......
......@@ -35,53 +35,53 @@ std::shared_ptr<InferSession> InferSession::CreateSession(const std::string &dev
}
}
bool AclSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
return model_process_.LoadModelFromFile(file_name, model_id);
Status AclSession::LoadModelFromFile(const std::string &file_name, uint32_t &model_id) {
return model_process_.LoadModelFromFile(file_name, model_id) ? SUCCESS : FAILED;
}
bool AclSession::UnloadModel(uint32_t model_id) {
Status AclSession::UnloadModel(uint32_t model_id) {
model_process_.UnLoad();
return true;
return SUCCESS;
}
bool AclSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
ReplyBase &reply) { // set d context
Status AclSession::ExecuteModel(uint32_t model_id, const RequestBase &request,
ReplyBase &reply) { // set d context
aclError rt_ret = aclrtSetCurrentContext(context_);
if (rt_ret != ACL_ERROR_NONE) {
MSI_LOG_ERROR << "set the ascend device context failed";
return false;
return FAILED;
}
return model_process_.Execute(request, reply);
return model_process_.Execute(request, reply) ? SUCCESS : FAILED;
}
bool AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
Status AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
device_type_ = device_type;
device_id_ = device_id;
auto ret = aclInit(nullptr);
if (ret != ACL_ERROR_NONE) {
MSI_LOG_ERROR << "Execute aclInit Failed";
return false;
return FAILED;
}
MSI_LOG_INFO << "acl init success";
ret = aclrtSetDevice(device_id_);
if (ret != ACL_ERROR_NONE) {
MSI_LOG_ERROR << "acl open device " << device_id_ << " failed";
return false;
return FAILED;
}
MSI_LOG_INFO << "open device " << device_id_ << " success";
ret = aclrtCreateContext(&context_, device_id_);
if (ret != ACL_ERROR_NONE) {
MSI_LOG_ERROR << "acl create context failed";
return false;
return FAILED;
}
MSI_LOG_INFO << "create context success";
ret = aclrtCreateStream(&stream_);
if (ret != ACL_ERROR_NONE) {
MSI_LOG_ERROR << "acl create stream failed";
return false;
return FAILED;
}
MSI_LOG_INFO << "create stream success";
......@@ -89,17 +89,17 @@ bool AclSession::InitEnv(const std::string &device_type, uint32_t device_id) {
ret = aclrtGetRunMode(&run_mode);
if (ret != ACL_ERROR_NONE) {
MSI_LOG_ERROR << "acl get run mode failed";
return false;
return FAILED;
}
bool is_device = (run_mode == ACL_DEVICE);
model_process_.SetIsDevice(is_device);
MSI_LOG_INFO << "get run mode success is device input/output " << is_device;
MSI_LOG_INFO << "Init acl success, device id " << device_id_;
return true;
return SUCCESS;
}
bool AclSession::FinalizeEnv() {
Status AclSession::FinalizeEnv() {
aclError ret;
if (stream_ != nullptr) {
ret = aclrtDestroyStream(stream_);
......@@ -129,7 +129,7 @@ bool AclSession::FinalizeEnv() {
MSI_LOG_ERROR << "finalize acl failed";
}
MSI_LOG_INFO << "end to finalize acl";
return true;
return SUCCESS;
}
AclSession::AclSession() = default;
......
......@@ -32,11 +32,11 @@ class AclSession : public InferSession {
public:
AclSession();
bool InitEnv(const std::string &device_type, uint32_t device_id) override;
bool FinalizeEnv() override;
bool LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
bool UnloadModel(uint32_t model_id) override;
bool ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) override;
Status InitEnv(const std::string &device_type, uint32_t device_id) override;
Status FinalizeEnv() override;
Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) override;
Status UnloadModel(uint32_t model_id) override;
Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) override;
private:
std::string device_type_;
......
......@@ -31,6 +31,7 @@
#include "core/version_control/version_controller.h"
#include "core/util/file_system_operation.h"
#include "core/serving_tensor.h"
#include "util/status.h"
using ms_serving::MSService;
using ms_serving::PredictReply;
......@@ -79,9 +80,9 @@ Status Session::Predict(const PredictRequest &request, PredictReply &reply) {
auto ret = session_->ExecuteModel(graph_id_, serving_request, serving_reply);
MSI_LOG(INFO) << "run Predict finished";
if (!ret) {
if (Status(ret) != SUCCESS) {
MSI_LOG(ERROR) << "execute model return failed";
return FAILED;
return Status(ret);
}
return SUCCESS;
}
......@@ -97,9 +98,9 @@ Status Session::Warmup(const MindSporeModelPtr model) {
MSI_TIME_STAMP_START(LoadModelFromFile)
auto ret = session_->LoadModelFromFile(file_name, graph_id_);
MSI_TIME_STAMP_END(LoadModelFromFile)
if (!ret) {
if (Status(ret) != SUCCESS) {
MSI_LOG(ERROR) << "Load graph model failed, file name is " << file_name.c_str();
return FAILED;
return Status(ret);
}
model_loaded_ = true;
MSI_LOG(INFO) << "Session Warmup finished";
......@@ -119,12 +120,22 @@ namespace {
static const uint32_t uint32max = 0x7FFFFFFF;
std::promise<void> exit_requested;
void ClearEnv() {
Session::Instance().Clear();
// inference::ExitInference();
}
void ClearEnv() { Session::Instance().Clear(); }
void HandleSignal(int sig) { exit_requested.set_value(); }
grpc::Status CreatGRPCStatus(Status status) {
switch (status) {
case SUCCESS:
return grpc::Status::OK;
case FAILED:
return grpc::Status::CANCELLED;
case INVALID_INPUTS:
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "The Predict Inputs do not match the Model Request!");
default:
return grpc::Status::CANCELLED;
}
}
} // namespace
// Service Implement
......@@ -134,8 +145,8 @@ class MSServiceImpl final : public MSService::Service {
MSI_TIME_STAMP_START(Predict)
auto res = Session::Instance().Predict(*request, *reply);
MSI_TIME_STAMP_END(Predict)
if (res != SUCCESS) {
return grpc::Status::CANCELLED;
if (res != inference::SUCCESS) {
return CreatGRPCStatus(res);
}
MSI_LOG(INFO) << "Finish call service Eval";
return grpc::Status::OK;
......
......@@ -18,7 +18,7 @@
namespace mindspore {
namespace serving {
using Status = uint32_t;
enum ServingStatus { SUCCESS = 0, FAILED };
enum ServingStatus { SUCCESS = 0, FAILED, INVALID_INPUTS };
} // namespace serving
} // namespace mindspore
......
......@@ -31,51 +31,51 @@ using ms_serving::TensorShape;
class MSClient {
public:
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
~MSClient() = default;
~MSClient() = default;
std::string Predict() {
// Data we are sending to the server.
PredictRequest request;
std::string Predict() {
// Data we are sending to the server.
PredictRequest request;
Tensor data;
TensorShape shape;
shape.add_dims(4);
*data.mutable_tensor_shape() = shape;
data.set_tensor_type(ms_serving::MS_FLOAT32);
std::vector<float> input_data{1, 2, 3, 4};
data.set_data(input_data.data(), input_data.size() * sizeof(float));
*request.add_data() = data;
*request.add_data() = data;
std::cout << "intput tensor size is " << request.data_size() << std::endl;
// Container for the data we expect from the server.
PredictReply reply;
Tensor data;
TensorShape shape;
shape.add_dims(4);
*data.mutable_tensor_shape() = shape;
data.set_tensor_type(ms_serving::MS_FLOAT32);
std::vector<float> input_data{1, 2, 3, 4};
data.set_data(input_data.data(), input_data.size() * sizeof(float));
*request.add_data() = data;
*request.add_data() = data;
std::cout << "intput tensor size is " << request.data_size() << std::endl;
// Container for the data we expect from the server.
PredictReply reply;
// Context for the client. It could be used to convey extra information to
// the server and/or tweak certain RPC behaviors.
ClientContext context;
// Context for the client. It could be used to convey extra information to
// the server and/or tweak certain RPC behaviors.
ClientContext context;
// The actual RPC.
Status status = stub_->Predict(&context, request, &reply);
std::cout << "Compute [1, 2, 3, 4] + [1, 2, 3, 4]" << std::endl;
// The actual RPC.
Status status = stub_->Predict(&context, request, &reply);
std::cout << "Compute [1, 2, 3, 4] + [1, 2, 3, 4]" << std::endl;
// Act upon its status.
if (status.ok()) {
std::cout << "Add result is";
for (size_t i = 0; i < reply.result(0).data().size() / sizeof(float); i++) {
std::cout << " " << (reinterpret_cast<const float *>(reply.mutable_result(0)->mutable_data()->data()))[i];
}
std::cout << std::endl;
// Act upon its status.
if (status.ok()) {
return "RPC OK";
} else {
std::cout << status.error_code() << ": " << status.error_message() << std::endl;
return "RPC failed";
}
return "RPC OK";
} else {
std::cout << status.error_code() << ": " << status.error_message() << std::endl;
return "RPC failed";
}
}
private:
std::unique_ptr<MSService::Stub> stub_;
std::unique_ptr<MSService::Stub> stub_;
};
int main(int argc, char **argv) {
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import sys
import grpc
import numpy as np
import ms_service_pb2
......@@ -19,7 +20,19 @@ import ms_service_pb2_grpc
def run():
channel = grpc.insecure_channel('localhost:5500')
if len(sys.argv) > 2:
sys.exit("input error")
channel_str = ""
if len(sys.argv) == 2:
split_args = sys.argv[1].split('=')
if len(split_args) > 1:
channel_str = split_args[1]
else:
channel_str = 'localhost:5500'
else:
channel_str = 'localhost:5500'
channel = grpc.insecure_channel(channel_str)
stub = ms_service_pb2_grpc.MSServiceStub(channel)
request = ms_service_pb2.PredictRequest()
......@@ -33,11 +46,17 @@ def run():
y.tensor_type = ms_service_pb2.MS_FLOAT32
y.data = (np.ones([4]).astype(np.float32)).tobytes()
result = stub.Predict(request)
print(result)
result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
print("ms client received: ")
print(result_np)
try:
result = stub.Predict(request)
print(result)
result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
print("ms client received: ")
print(result_np)
except grpc.RpcError as e:
print(e.details())
status_code = e.code()
print(status_code.name)
print(status_code.value)
if __name__ == '__main__':
run()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册