提交 0a279b7f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2849 fix serving peformance

Merge pull request !2849 from dinghao/master
......@@ -31,14 +31,69 @@ using mindspore::tensor::TensorPy;
namespace mindspore {
namespace session {
namespace {
std::set<AnfNodePtr> weight_infos;
static TypeId GetDataType(const py::buffer_info &buf) {
if (buf.format.size() == 1) {
switch (buf.format.front()) {
case 'e':
case 'f':
case 'd':
switch (buf.itemsize) {
case 2:
return TypeId::kNumberTypeFloat16;
case 4:
return TypeId::kNumberTypeFloat32;
case 8:
return TypeId::kNumberTypeFloat64;
}
break;
case 'b':
case 'h':
case 'i':
case 'l':
case 'q':
switch (buf.itemsize) {
case 1:
return TypeId::kNumberTypeInt8;
case 2:
return TypeId::kNumberTypeInt16;
case 4:
return TypeId::kNumberTypeInt32;
case 8:
return TypeId::kNumberTypeInt64;
}
break;
case 'B':
case 'H':
case 'I':
case 'L':
case 'Q':
switch (buf.itemsize) {
case 1:
return TypeId::kNumberTypeUInt8;
case 2:
return TypeId::kNumberTypeUInt16;
case 4:
return TypeId::kNumberTypeUInt32;
case 8:
return TypeId::kNumberTypeUInt64;
}
break;
case '?':
return TypeId::kNumberTypeBool;
}
}
MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize;
return TypeId::kTypeUnknown;
}
} // namespace
void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const {
MS_EXCEPTION_IF_NULL(kernel_graph);
std::vector<tensor::TensorPtr> inputs(inputs_const);
auto input_nodes = kernel_graph->inputs();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
size_t no_weight_input = 0;
for (size_t i = 0; i < input_nodes.size(); ++i) {
tensor::TensorPtr tensor = nullptr;
......@@ -48,45 +103,32 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k
}
auto pk_node = input_nodes[i]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(pk_node);
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
MS_EXCEPTION_IF_NULL(device_address);
if (AnfAlgo::IsParameterWeight(pk_node)) {
if (weight_infos.count(pk_node) != 0) {
continue;
}
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(pk_node->default_param());
MS_EXCEPTION_IF_NULL(param_value);
auto py_param = param_value->value();
MS_EXCEPTION_IF_NULL(py_param);
py::array py_array = py_param.cast<py::array>();
tensor = TensorPy::MakeTensor(py_array);
py::buffer_info buf = py_array.request();
auto buf_type = GetDataType(buf);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(buf.size * buf.itemsize), buf_type, buf.ptr)) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
weight_infos.insert(pk_node);
} else {
tensor = inputs[no_weight_input++];
}
MS_EXCEPTION_IF_NULL(tensor);
if (AnfAlgo::OutputAddrExist(pk_node, 0)) {
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
bool need_sync = false;
if (ms_context->enable_pynative_infer()) {
if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
need_sync = true;
}
} else {
if (tensor->is_dirty()) {
need_sync = true;
} else if (tensor->device_address() != device_address) {
(void)tensor->data_sync();
need_sync = true;
}
}
if (need_sync) {
if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) {
tensor->set_device_address(device_address);
}
MS_EXCEPTION_IF_NULL(device_address);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
}
tensor->set_dirty(false);
}
}
} // namespace session
......
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: ms_service.proto
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='ms_service.proto',
package='ms_serving',
syntax='proto3',
serialized_options=None,
serialized_pb=b'\n\x10ms_service.proto\x12\nms_serving\"2\n\x0ePredictRequest\x12 \n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32\x12.ms_serving.Tensor\"2\n\x0cPredictReply\x12\"\n\x06result\x18\x01 \x03(\x0b\x32\x12.ms_serving.Tensor\"\x1b\n\x0bTensorShape\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\"p\n\x06Tensor\x12-\n\x0ctensor_shape\x18\x01 \x01(\x0b\x32\x17.ms_serving.TensorShape\x12)\n\x0btensor_type\x18\x02 \x01(\x0e\x32\x14.ms_serving.DataType\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c*\xc9\x01\n\x08\x44\x61taType\x12\x0e\n\nMS_UNKNOWN\x10\x00\x12\x0b\n\x07MS_BOOL\x10\x01\x12\x0b\n\x07MS_INT8\x10\x02\x12\x0c\n\x08MS_UINT8\x10\x03\x12\x0c\n\x08MS_INT16\x10\x04\x12\r\n\tMS_UINT16\x10\x05\x12\x0c\n\x08MS_INT32\x10\x06\x12\r\n\tMS_UINT32\x10\x07\x12\x0c\n\x08MS_INT64\x10\x08\x12\r\n\tMS_UINT64\x10\t\x12\x0e\n\nMS_FLOAT16\x10\n\x12\x0e\n\nMS_FLOAT32\x10\x0b\x12\x0e\n\nMS_FLOAT64\x10\x0c\x32\x8e\x01\n\tMSService\x12\x41\n\x07Predict\x12\x1a.ms_serving.PredictRequest\x1a\x18.ms_serving.PredictReply\"\x00\x12>\n\x04Test\x12\x1a.ms_serving.PredictRequest\x1a\x18.ms_serving.PredictReply\"\x00\x62\x06proto3'
)
_DATATYPE = _descriptor.EnumDescriptor(
name='DataType',
full_name='ms_serving.DataType',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='MS_UNKNOWN', index=0, number=0,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_BOOL', index=1, number=1,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_INT8', index=2, number=2,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_UINT8', index=3, number=3,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_INT16', index=4, number=4,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_UINT16', index=5, number=5,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_INT32', index=6, number=6,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_UINT32', index=7, number=7,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_INT64', index=8, number=8,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_UINT64', index=9, number=9,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_FLOAT16', index=10, number=10,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_FLOAT32', index=11, number=11,
serialized_options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MS_FLOAT64', index=12, number=12,
serialized_options=None,
type=None),
],
containing_type=None,
serialized_options=None,
serialized_start=280,
serialized_end=481,
)
_sym_db.RegisterEnumDescriptor(_DATATYPE)
DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE)
MS_UNKNOWN = 0
MS_BOOL = 1
MS_INT8 = 2
MS_UINT8 = 3
MS_INT16 = 4
MS_UINT16 = 5
MS_INT32 = 6
MS_UINT32 = 7
MS_INT64 = 8
MS_UINT64 = 9
MS_FLOAT16 = 10
MS_FLOAT32 = 11
MS_FLOAT64 = 12
_PREDICTREQUEST = _descriptor.Descriptor(
name='PredictRequest',
full_name='ms_serving.PredictRequest',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='data', full_name='ms_serving.PredictRequest.data', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=32,
serialized_end=82,
)
_PREDICTREPLY = _descriptor.Descriptor(
name='PredictReply',
full_name='ms_serving.PredictReply',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='result', full_name='ms_serving.PredictReply.result', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=84,
serialized_end=134,
)
_TENSORSHAPE = _descriptor.Descriptor(
name='TensorShape',
full_name='ms_serving.TensorShape',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='dims', full_name='ms_serving.TensorShape.dims', index=0,
number=1, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=136,
serialized_end=163,
)
_TENSOR = _descriptor.Descriptor(
name='Tensor',
full_name='ms_serving.Tensor',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='tensor_shape', full_name='ms_serving.Tensor.tensor_shape', index=0,
number=1, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='tensor_type', full_name='ms_serving.Tensor.tensor_type', index=1,
number=2, type=14, cpp_type=8, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='data', full_name='ms_serving.Tensor.data', index=2,
number=3, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=165,
serialized_end=277,
)
_PREDICTREQUEST.fields_by_name['data'].message_type = _TENSOR
_PREDICTREPLY.fields_by_name['result'].message_type = _TENSOR
_TENSOR.fields_by_name['tensor_shape'].message_type = _TENSORSHAPE
_TENSOR.fields_by_name['tensor_type'].enum_type = _DATATYPE
DESCRIPTOR.message_types_by_name['PredictRequest'] = _PREDICTREQUEST
DESCRIPTOR.message_types_by_name['PredictReply'] = _PREDICTREPLY
DESCRIPTOR.message_types_by_name['TensorShape'] = _TENSORSHAPE
DESCRIPTOR.message_types_by_name['Tensor'] = _TENSOR
DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
PredictRequest = _reflection.GeneratedProtocolMessageType('PredictRequest', (_message.Message,), {
'DESCRIPTOR' : _PREDICTREQUEST,
'__module__' : 'ms_service_pb2'
# @@protoc_insertion_point(class_scope:ms_serving.PredictRequest)
})
_sym_db.RegisterMessage(PredictRequest)
PredictReply = _reflection.GeneratedProtocolMessageType('PredictReply', (_message.Message,), {
'DESCRIPTOR' : _PREDICTREPLY,
'__module__' : 'ms_service_pb2'
# @@protoc_insertion_point(class_scope:ms_serving.PredictReply)
})
_sym_db.RegisterMessage(PredictReply)
TensorShape = _reflection.GeneratedProtocolMessageType('TensorShape', (_message.Message,), {
'DESCRIPTOR' : _TENSORSHAPE,
'__module__' : 'ms_service_pb2'
# @@protoc_insertion_point(class_scope:ms_serving.TensorShape)
})
_sym_db.RegisterMessage(TensorShape)
Tensor = _reflection.GeneratedProtocolMessageType('Tensor', (_message.Message,), {
'DESCRIPTOR' : _TENSOR,
'__module__' : 'ms_service_pb2'
# @@protoc_insertion_point(class_scope:ms_serving.Tensor)
})
_sym_db.RegisterMessage(Tensor)
_MSSERVICE = _descriptor.ServiceDescriptor(
name='MSService',
full_name='ms_serving.MSService',
file=DESCRIPTOR,
index=0,
serialized_options=None,
serialized_start=484,
serialized_end=626,
methods=[
_descriptor.MethodDescriptor(
name='Predict',
full_name='ms_serving.MSService.Predict',
index=0,
containing_service=None,
input_type=_PREDICTREQUEST,
output_type=_PREDICTREPLY,
serialized_options=None,
),
_descriptor.MethodDescriptor(
name='Test',
full_name='ms_serving.MSService.Test',
index=1,
containing_service=None,
input_type=_PREDICTREQUEST,
output_type=_PREDICTREPLY,
serialized_options=None,
),
])
_sym_db.RegisterServiceDescriptor(_MSSERVICE)
DESCRIPTOR.services_by_name['MSService'] = _MSSERVICE
# @@protoc_insertion_point(module_scope)
......@@ -259,7 +259,7 @@ Status Server::BuildAndStart() {
}
g_ctx = ctx;
#endif
MSServiceImpl msService;
MSServiceImpl ms_service;
grpc::EnableDefaultHealthCheckService(true);
grpc::reflection::InitProtoReflectionServerBuilderPlugin();
// Set the port is not reuseable
......@@ -268,7 +268,7 @@ Status Server::BuildAndStart() {
serverBuilder.SetOption(std::move(option));
serverBuilder.SetMaxMessageSize(uint32max);
serverBuilder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
serverBuilder.RegisterService(&msService);
serverBuilder.RegisterService(&ms_service);
std::unique_ptr<grpc::Server> server(serverBuilder.BuildAndStart());
if (server == nullptr) {
MS_LOG(ERROR) << "The serving server create failed";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册