提交 ba5fa502 编写于 作者: H hexia

restructure

上级 e09d50e4
......@@ -106,6 +106,7 @@ endif() # NOT ENABLE_ACL
if (ENABLE_SERVING)
add_subdirectory(serving)
add_subdirectory(serving/example/cpp_client)
endif()
if (NOT ENABLE_ACL)
......
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: ms_service.proto
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='ms_service.proto',
package='ms_serving',
syntax='proto3',
serialized_options=None,
serialized_pb=b'\n\x10ms_service.proto\x12\nms_serving\"2\n\x0ePredictRequest\x12 \n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x12.ms_serving.Tensor\"2\n\x0cPredictReply\x12\"\n\x06result\x18\x01 \x03(\x0b\x32\x12.ms_serving.Tensor\"\x1b\n\x0bTensorShape\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\"p\n\x06Tensor\x12-\n\x0ctensor_shape\x18\x01 \x01(\x0b\x32\x17.ms_serving.TensorShape\x12)\n\x0btensor_type\x18\x02 \x01(\x0e\x32\x14.ms_serving.DataType\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c*\xc9\x01\n\x08\x44\x61taType\x12\x0e\n\nMS_UNKNOWN\x10\x00\x12\x0b\n\x07MS_BOOL\x10\x01\x12\x0b\n\x07MS_INT8\x10\x02\x12\x0c\n\x08MS_UINT8\x10\x03\x12\x0c\n\x08MS_INT16\x10\x04\x12\r\n\tMS_UINT16\x10\x05\x12\x0c\n\x08MS_INT32\x10\x06\x12\r\n\tMS_UINT32\x10\x07\x12\x0c\n\x08MS_INT64\x10\x08\x12\r\n\tMS_UINT64\x10\t\x12\x0e\n\nMS_FLOAT16\x10\n\x12\x0e\n\nMS_FLOAT32\x10\x0b\x12\x0e\n\nMS_FLOAT64\x10\x0c\x32\x8e\x01\n\tMSService\x12\x41\n\x07Predict\x12\x1a.ms_serving.PredictRequest\x1a\x18.ms_serving.PredictReply\"\x00\x12>\n\x04Test\x12\x1a.ms_serving.PredictRequest\x1a\x18.ms_serving.PredictReply\"\x00\x62\x06proto3'
)
_DATATYPE = _descriptor.EnumDescriptor(
name='DataType',
full_name='ms_serving.DataType',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='MS_UNKNOWN', index=0, number=0,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_BOOL', index=1, number=1,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_INT8', index=2, number=2,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_UINT8', index=3, number=3,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_INT16', index=4, number=4,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_UINT16', index=5, number=5,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_INT32', index=6, number=6,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_UINT32', index=7, number=7,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_INT64', index=8, number=8,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_UINT64', index=9, number=9,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_FLOAT16', index=10, number=10,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_FLOAT32', index=11, number=11,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_FLOAT64', index=12, number=12,
serialized_options=None,
type=None),
],
containing_type=None,
serialized_options=None,
serialized_start=280,
serialized_end=481,
)
_sym_db.RegisterEnumDescriptor(_DATATYPE)
DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE)
MS_UNKNOWN = 0
MS_BOOL = 1
MS_INT8 = 2
MS_UINT8 = 3
MS_INT16 = 4
MS_UINT16 = 5
MS_INT32 = 6
MS_UINT32 = 7
MS_INT64 = 8
MS_UINT64 = 9
MS_FLOAT16 = 10
MS_FLOAT32 = 11
MS_FLOAT64 = 12
_PREDICTREQUEST = _descriptor.Descriptor(
name='PredictRequest',
full_name='ms_serving.PredictRequest',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='data', full_name='ms_serving.PredictRequest.data', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=32,
serialized_end=82,
)
_PREDICTREPLY = _descriptor.Descriptor(
name='PredictReply',
full_name='ms_serving.PredictReply',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='result', full_name='ms_serving.PredictReply.result', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=84,
serialized_end=134,
)
_TENSORSHAPE = _descriptor.Descriptor(
name='TensorShape',
full_name='ms_serving.TensorShape',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='dims', full_name='ms_serving.TensorShape.dims', index=0,
number=1, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=136,
serialized_end=163,
)
_TENSOR = _descriptor.Descriptor(
name='Tensor',
full_name='ms_serving.Tensor',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='tensor_shape', full_name='ms_serving.Tensor.tensor_shape', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='tensor_type', full_name='ms_serving.Tensor.tensor_type', index=1,
number=2, type=14, cpp_type=8, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='data', full_name='ms_serving.Tensor.data', index=2,
number=3, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=165,
serialized_end=277,
)
_PREDICTREQUEST.fields_by_name['data'].message_type = _TENSOR
_PREDICTREPLY.fields_by_name['result'].message_type = _TENSOR
_TENSOR.fields_by_name['tensor_shape'].message_type = _TENSORSHAPE
_TENSOR.fields_by_name['tensor_type'].enum_type = _DATATYPE
DESCRIPTOR.message_types_by_name['PredictRequest'] = _PREDICTREQUEST
DESCRIPTOR.message_types_by_name['PredictReply'] = _PREDICTREPLY
DESCRIPTOR.message_types_by_name['TensorShape'] = _TENSORSHAPE
DESCRIPTOR.message_types_by_name['Tensor'] = _TENSOR
DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), {
'DESCRIPTOR' : _PREDICTREQUEST,
'__module__' : 'ms_service_pb2'
# @@protoc_insertion_point(class_scope:ms_serving.PredictRequest)
})
_sym_db.RegisterMessage(PredictRequest)
PredictReply = _reflection.GeneratedProtocolMessageType('PredictReply', (_message.Message,), {
'DESCRIPTOR' : _PREDICTREPLY,
'__module__' : 'ms_service_pb2'
# @@protoc_insertion_point(class_scope:ms_serving.PredictReply)
})
_sym_db.RegisterMessage(PredictReply)
TensorShape = _reflection.GeneratedProtocolMessageType('TensorShape', (_message.Message,), {
'DESCRIPTOR' : _TENSORSHAPE,
'__module__' : 'ms_service_pb2'
# @@protoc_insertion_point(class_scope:ms_serving.TensorShape)
})
_sym_db.RegisterMessage(TensorShape)
Tensor = _reflection.GeneratedProtocolMessageType('Tensor', (_message.Message,), {
'DESCRIPTOR' : _TENSOR,
'__module__' : 'ms_service_pb2'
# @@protoc_insertion_point(class_scope:ms_serving.Tensor)
})
_sym_db.RegisterMessage(Tensor)
_MSSERVICE = _descriptor.ServiceDescriptor(
name='MSService',
full_name='ms_serving.MSService',
file=DESCRIPTOR,
index=0,
serialized_options=None,
serialized_start=484,
serialized_end=626,
methods=[
_descriptor.MethodDescriptor(
name='Predict',
full_name='ms_serving.MSService.Predict',
index=0,
containing_service=None,
input_type=_PREDICTREQUEST,
output_type=_PREDICTREPLY,
serialized_options=None,
),
_descriptor.MethodDescriptor(
name='Test',
full_name='ms_serving.MSService.Test',
index=1,
containing_service=None,
input_type=_PREDICTREQUEST,
output_type=_PREDICTREPLY,
serialized_options=None,
),
])
_sym_db.RegisterServiceDescriptor(_MSSERVICE)
DESCRIPTOR.services_by_name['MSService'] = _MSSERVICE
# @@protoc_insertion_point(module_scope)
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
import grpc
import ms_service_pb2 as ms__service__pb2
class MSServiceStub(object):
"""Missing associated documentation comment in .proto file"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Predict = channel.unary_unary(
'/ms_serving.MSService/Predict',
request_serializer=ms__service__pb2.PredictRequest.SerializeToString,
response_deserializer=ms__service__pb2.PredictReply.FromString,
)
self.Test = channel.unary_unary(
'/ms_serving.MSService/Test',
request_serializer=ms__service__pb2.PredictRequest.SerializeToString,
response_deserializer=ms__service__pb2.PredictReply.FromString,
)
class MSServiceServicer(object):
"""Missing associated documentation comment in .proto file"""
def Predict(self, request, context):
"""Missing associated documentation comment in .proto file"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Test(self, request, context):
"""Missing associated documentation comment in .proto file"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_MSServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'Predict': grpc.unary_unary_rpc_method_handler(
servicer.Predict,
request_deserializer=ms__service__pb2.PredictRequest.FromString,
response_serializer=ms__service__pb2.PredictReply.SerializeToString,
),
'Test': grpc.unary_unary_rpc_method_handler(
servicer.Test,
request_deserializer=ms__service__pb2.PredictRequest.FromString,
response_serializer=ms__service__pb2.PredictReply.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'ms_serving.MSService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class MSService(object):
"""Missing associated documentation comment in .proto file"""
@staticmethod
def Predict(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/ms_serving.MSService/Predict',
ms__service__pb2.PredictRequest.SerializeToString,
ms__service__pb2.PredictReply.FromString,
options, channel_credentials,
call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Test(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/ms_serving.MSService/Test',
ms__service__pb2.PredictRequest.SerializeToString,
ms__service__pb2.PredictReply.FromString,
options, channel_credentials,
call_credentials, compression, wait_for_ready, timeout, metadata)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <iostream>
#include "./ms_service.grpc.pb.h"
using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
using grpc::Status;
using ms_serving::MSService;
using ms_serving::PredictReply;
using ms_serving::PredictRequest;
// Logic and data behind the server's behavior.
class MSServiceImpl final : public MSService::Service {
Status Predict(ServerContext *context, const PredictRequest *request, PredictReply *reply) override {
std::cout << "server eval" << std::endl;
return Status::OK;
}
};
void RunServer() {
std::string server_address("0.0.0.0:50051");
MSServiceImpl service;
grpc::EnableDefaultHealthCheckService(true);
grpc::reflection::InitProtoReflectionServerBuilderPlugin();
auto option = grpc::MakeChannelArgumentOption(GRPC_ARG_ALLOW_REUSEPORT, 0);
ServerBuilder builder;
builder.SetOption(std::move(option));
// Listen on the given address without any authentication mechanism.
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
// Register "service" as the instance through which we'll communicate with
// clients. In this case it corresponds to an *synchronous* service.
builder.RegisterService(&service);
// Finally assemble the server.
std::unique_ptr<Server> server(builder.BuildAndStart());
std::cout << "Server listening on " << server_address << std::endl;
// Wait for the server to shutdown. Note that some other thread must be
// responsible for shutting down the server for this call to ever return.
server->Wait();
}
int main(int argc, char **argv) {
RunServer();
return 0;
}
cmake_minimum_required(VERSION 3.5.1)
project(HelloWorld C CXX)
project(MSClient C CXX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
......@@ -12,17 +12,33 @@ find_package(Threads REQUIRED)
# Find Protobuf installation
# Looks for protobuf-config.cmake file installed by Protobuf's cmake installation.
set(protobuf_MODULE_COMPATIBLE TRUE)
find_package(Protobuf CONFIG REQUIRED)
message(STATUS "Using protobuf ${protobuf_VERSION}")
option(GRPC_PATH "set grpc path")
if(GRPC_PATH)
set(CMAKE_PREFIX_PATH ${GRPC_PATH})
set(protobuf_MODULE_COMPATIBLE TRUE)
find_package(Protobuf CONFIG REQUIRED)
message(STATUS "Using protobuf ${protobuf_VERSION}, CMAKE_PREFIX_PATH : ${CMAKE_PREFIX_PATH}")
elseif(NOT GRPC_PATH)
if (EXISTS ${grpc_ROOT}/lib64)
set(gRPC_DIR "${grpc_ROOT}/lib64/cmake/grpc")
elseif(EXISTS ${grpc_ROOT}/lib)
set(gRPC_DIR "${grpc_ROOT}/lib/cmake/grpc")
endif()
add_library(protobuf::libprotobuf ALIAS protobuf::protobuf)
add_executable(protobuf::libprotoc ALIAS protobuf::protoc)
message(STATUS "serving using grpc_DIR : " ${gRPC_DIR})
elseif(NOT gRPC_DIR AND NOT GRPC_PATH)
message("please check gRPC. If the client is compiled separately,you can use the command: cmake -D GRPC_PATH=xxx")
message("XXX is the gRPC installation path")
endif()
set(_PROTOBUF_LIBPROTOBUF protobuf::libprotobuf)
set(_REFLECTION gRPC::grpc++_reflection)
if (CMAKE_CROSSCOMPILING)
find_program(_PROTOBUF_PROTOC protoc)
else ()
set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
endif ()
if(CMAKE_CROSSCOMPILING)
find_program(_PROTOBUF_PROTOC protoc)
else()
set(_PROTOBUF_PROTOC $<TARGET_FILE:protobuf::protoc>)
endif()
# Find gRPC installation
# Looks for gRPCConfig.cmake file installed by gRPC's cmake installation.
......@@ -30,14 +46,14 @@ find_package(gRPC CONFIG REQUIRED)
message(STATUS "Using gRPC ${gRPC_VERSION}")
set(_GRPC_GRPCPP gRPC::grpc++)
if (CMAKE_CROSSCOMPILING)
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
else ()
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
endif ()
if(CMAKE_CROSSCOMPILING)
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
else()
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
endif()
# Proto file
get_filename_component(hw_proto "../ms_service.proto" ABSOLUTE)
get_filename_component(hw_proto "../../ms_service.proto" ABSOLUTE)
get_filename_component(hw_proto_path "${hw_proto}" PATH)
# Generated sources
......@@ -59,13 +75,13 @@ add_custom_command(
include_directories("${CMAKE_CURRENT_BINARY_DIR}")
# Targets greeter_[async_](client|server)
foreach (_target
ms_client ms_server)
add_executable(${_target} "${_target}.cc"
${hw_proto_srcs}
${hw_grpc_srcs})
target_link_libraries(${_target}
${_REFLECTION}
${_GRPC_GRPCPP}
${_PROTOBUF_LIBPROTOBUF})
endforeach ()
foreach(_target
ms_client)
add_executable(${_target} "${_target}.cc"
${hw_proto_srcs}
${hw_grpc_srcs})
target_link_libraries(${_target}
${_REFLECTION}
${_GRPC_GRPCPP}
${_PROTOBUF_LIBPROTOBUF})
endforeach()
......@@ -211,77 +211,12 @@ PredictRequest ReadBertInput() {
return request;
}
PredictRequest ReadLenetInput() {
size_t size;
auto buf = ReadFile("lenet_img.bin", &size);
if (buf == nullptr) {
std::cout << "read file failed" << std::endl;
return PredictRequest();
}
PredictRequest request;
auto cur = buf;
if (size > 0) {
Tensor data;
TensorShape shape;
// set type
data.set_tensor_type(ms_serving::MS_FLOAT32);
// set shape
shape.add_dims(size / sizeof(float));
*data.mutable_tensor_shape() = shape;
// set data
data.set_data(cur, size);
*request.add_data() = data;
}
std::cout << "get input data size " << size << std::endl;
return request;
}
PredictRequest ReadOtherInput(const std::string &data_file) {
size_t size;
auto buf = ReadFile(data_file.c_str(), &size);
if (buf == nullptr) {
std::cout << "read file failed" << std::endl;
return PredictRequest();
}
PredictRequest request;
auto cur = buf;
if (size > 0) {
Tensor data;
TensorShape shape;
// set type
data.set_tensor_type(ms_serving::MS_FLOAT32);
// set shape
shape.add_dims(size / sizeof(float));
*data.mutable_tensor_shape() = shape;
// set data
data.set_data(cur, size);
*request.add_data() = data;
}
std::cout << "get input data size " << size << std::endl;
return request;
}
template <class DT>
void print_array_item(const DT *data, size_t size) {
for (size_t i = 0; i < size && i < 100; i++) {
std::cout << data[i] << '\t';
if ((i + 1) % 10 == 0) {
std::cout << std::endl;
}
}
std::cout << std::endl;
}
class MSClient {
public:
explicit MSClient(std::shared_ptr<Channel> channel) : stub_(MSService::NewStub(channel)) {}
~MSClient() = default;
std::string Predict(const std::string &type, const std::string &data_file) {
std::string Predict(const std::string &type) {
// Data we are sending to the server.
PredictRequest request;
if (type == "add") {
......@@ -299,10 +234,6 @@ class MSClient {
*request.add_data() = data;
} else if (type == "bert") {
request = ReadBertInput();
} else if (type == "lenet") {
request = ReadLenetInput();
} else if (type == "other") {
request = ReadOtherInput(data_file);
} else {
std::cout << "type only support bert or add, but input is " << type << std::endl;
}
......@@ -325,20 +256,6 @@ class MSClient {
// Act upon its status.
if (status.ok()) {
for (size_t i = 0; i < reply.result_size(); i++) {
auto result = reply.result(i);
if (result.tensor_type() == ms_serving::DataType::MS_FLOAT32) {
print_array_item(reinterpret_cast<const float *>(result.data().data()), result.data().size() / sizeof(float));
} else if (result.tensor_type() == ms_serving::DataType::MS_INT32) {
print_array_item(reinterpret_cast<const int32_t *>(result.data().data()),
result.data().size() / sizeof(int32_t));
} else if (result.tensor_type() == ms_serving::DataType::MS_UINT32) {
print_array_item(reinterpret_cast<const uint32_t *>(result.data().data()),
result.data().size() / sizeof(uint32_t));
} else {
std::cout << "output datatype " << result.tensor_type() << std::endl;
}
}
return "RPC OK";
} else {
std::cout << status.error_code() << ": " << status.error_message() << std::endl;
......@@ -360,8 +277,6 @@ int main(int argc, char **argv) {
std::string arg_target_str("--target");
std::string type;
std::string arg_type_str("--type");
std::string arg_data_str("--data");
std::string data = "default_data.bin";
if (argc > 2) {
{
// parse target
......@@ -389,33 +304,19 @@ int main(int argc, char **argv) {
if (arg_val2[start_pos] == '=') {
type = arg_val2.substr(start_pos + 1);
} else {
std::cout << "The only correct argument syntax is --type=" << std::endl;
std::cout << "The only correct argument syntax is --target=" << std::endl;
return 0;
}
} else {
type = "add";
}
}
if (argc > 3) {
// parse type
std::string arg_val3 = argv[3];
size_t start_pos = arg_val3.find(arg_data_str);
if (start_pos != std::string::npos) {
start_pos += arg_data_str.size();
if (arg_val3[start_pos] == '=') {
data = arg_val3.substr(start_pos + 1);
} else {
std::cout << "The only correct argument syntax is --data=" << std::endl;
return 0;
}
}
}
} else {
target_str = "localhost:5500";
type = "add";
}
MSClient client(grpc::CreateChannel(target_str, grpc::InsecureChannelCredentials()));
std::string reply = client.Predict(type, data);
std::string reply = client.Predict(type);
std::cout << "client received: " << reply << std::endl;
return 0;
......
......@@ -12,44 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from concurrent import futures
import time
import grpc
import numpy as np
import ms_service_pb2
import ms_service_pb2_grpc
import test_cpu_lenet
import mindspore.context as context
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.train.serialization import export
class MSService(ms_service_pb2_grpc.MSServiceServicer):
def Predict(self, request, context):
request_data = request.data
request_label = request.label
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
data_from_buffer = np.frombuffer(request_data.data, dtype=np.float32)
data_from_buffer = data_from_buffer.reshape(request_data.tensor_shape.dims)
data = Tensor(data_from_buffer)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.add = P.TensorAdd()
label_from_buffer = np.frombuffer(request_label.data, dtype=np.int32)
label_from_buffer = label_from_buffer.reshape(request_label.tensor_shape.dims)
label = Tensor(label_from_buffer)
def construct(self, x_, y_):
return self.add(x_, y_)
result = test_cpu_lenet.test_lenet(data, label)
result_reply = ms_service_pb2.PredictReply()
result_reply.result.tensor_shape.dims.extend(result.shape())
result_reply.result.data = result.asnumpy().tobytes()
return result_reply
x = np.ones(4).astype(np.float32)
y = np.ones(4).astype(np.float32)
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
ms_service_pb2_grpc.add_MSServiceServicer_to_server(MSService(), server)
server.add_insecure_port('[::]:50051')
server.start()
try:
while True:
time.sleep(60*60*24) # one day in seconds
except KeyboardInterrupt:
server.stop(0)
def export_net():
add = Net()
output = add(Tensor(x), Tensor(y))
export(add, Tensor(x), Tensor(y), file_name='tensor_add.pb', file_format='BINARY')
print(x)
print(y)
print(output.asnumpy())
if __name__ == '__main__':
serve()
if __name__ == "__main__":
export_net()
\ No newline at end of file
......@@ -19,28 +19,25 @@ import ms_service_pb2_grpc
def run():
channel = grpc.insecure_channel('localhost:50051')
channel = grpc.insecure_channel('localhost:5050')
stub = ms_service_pb2_grpc.MSServiceStub(channel)
# request = ms_service_pb2.EvalRequest()
# request.name = 'haha'
# response = stub.Eval(request)
# print("ms client received: " + response.message)
request = ms_service_pb2.PredictRequest()
request.data.tensor_shape.dims.extend([32, 1, 32, 32])
request.data.tensor_type = ms_service_pb2.MS_FLOAT32
request.data.data = (np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01).tobytes()
request.label.tensor_shape.dims.extend([32])
request.label.tensor_type = ms_service_pb2.MS_INT32
request.label.data = np.ones([32]).astype(np.int32).tobytes()
result = stub.Test(request)
#result_np = np.frombuffer(result.result.data, dtype=np.float32).reshape(result.result.tensor_shape.dims)
print("ms client test call received: ")
#print(result_np)
x = request.data.add()
x.tensor_shape.dims.extend([4])
x.tensor_type = ms_service_pb2.MS_FLOAT32
x.data = (np.ones([4]).astype(np.float32)).tobytes()
y = request.data.add()
y.tensor_shape.dims.extend([4])
y.tensor_type = ms_service_pb2.MS_FLOAT32
y.data = (np.ones([4]).astype(np.float32)).tobytes()
result = stub.Predict(request)
print(result)
result_np = np.frombuffer(result.result[0].data, dtype=np.float32).reshape(result.result[0].tensor_shape.dims)
print("ms client received: ")
print(result_np)
if __name__ == '__main__':
run()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import grpc
import numpy as np
import ms_service_pb2
import ms_service_pb2_grpc
def run():
channel = grpc.insecure_channel('localhost:50051')
stub = ms_service_pb2_grpc.MSServiceStub(channel)
# request = ms_service_pb2.PredictRequest()
# request.name = 'haha'
# response = stub.Eval(request)
# print("ms client received: " + response.message)
request = ms_service_pb2.PredictRequest()
request.data.tensor_shape.dims.extend([32, 1, 32, 32])
request.data.tensor_type = ms_service_pb2.MS_FLOAT32
request.data.data = (np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01).tobytes()
request.label.tensor_shape.dims.extend([32])
request.label.tensor_type = ms_service_pb2.MS_INT32
request.label.data = np.ones([32]).astype(np.int32).tobytes()
result = stub.Predict(request)
#result_np = np.frombuffer(result.result.data, dtype=np.float32).reshape(result.result.tensor_shape.dims)
print("ms client received: ")
#print(result_np)
# future_list = []
# times = 1000
# for i in range(times):
# async_future = stub.Eval.future(request)
# future_list.append(async_future)
# print("async call, future list add item " + str(i));
#
# for i in range(len(future_list)):
# async_result = future_list[i].result()
# print("ms client async get result of item " + str(i))
if __name__ == '__main__':
run()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
import ms_service_pb2
class LeNet(nn.Cell):
def __init__(self):
super(LeNet, self).__init__()
self.relu = P.ReLU()
self.batch_size = 32
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.fc1 = nn.Dense(400, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
def construct(self, input_x):
output = self.conv1(input_x)
output = self.relu(output)
output = self.pool(output)
output = self.conv2(output)
output = self.relu(output)
output = self.pool(output)
output = self.reshape(output, (self.batch_size, -1))
output = self.fc1(output)
output = self.relu(output)
output = self.fc2(output)
output = self.relu(output)
output = self.fc3(output)
return output
def train(net, data, label):
learning_rate = 0.01
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()
res = train_network(data, label)
print("+++++++++Loss+++++++++++++")
print(res)
print("+++++++++++++++++++++++++++")
assert res
return res
def test_lenet(data, label):
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
net = LeNet()
return train(net, data, label)
if __name__ == '__main__':
tensor = ms_service_pb2.Tensor()
tensor.tensor_shape.dim.extend([32, 1, 32, 32])
# tensor.tensor_shape.dim.add() = 1
# tensor.tensor_shape.dim.add() = 32
# tensor.tensor_shape.dim.add() = 32
tensor.tensor_type = ms_service_pb2.MS_FLOAT32
tensor.data = np.ones([32, 1, 32, 32]).astype(np.float32).tobytes()
data_from_buffer = np.frombuffer(tensor.data, dtype=np.float32)
print(tensor.tensor_shape.dim)
data_from_buffer = data_from_buffer.reshape(tensor.tensor_shape.dim)
print(data_from_buffer.shape)
input_data = Tensor(data_from_buffer * 0.01)
input_label = Tensor(np.ones([32]).astype(np.int32))
test_lenet(input_data, input_label)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册