Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
a3396f8a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a3396f8a
编写于
6月 23, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 23, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2479 add bert inference example in serving
Merge pull request !2479 from dinghao/master
上级
c8f26f79
c1a518ce
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
255 addition
and
38 deletion
+255
-38
mindspore/ccsrc/CMakeLists.txt
mindspore/ccsrc/CMakeLists.txt
+1
-1
serving/cpp_example/CMakeLists.txt
serving/cpp_example/CMakeLists.txt
+2
-3
serving/cpp_example/ms_client.cc
serving/cpp_example/ms_client.cc
+250
-32
serving/cpp_example/ms_server.cc
serving/cpp_example/ms_server.cc
+2
-2
未找到文件。
mindspore/ccsrc/CMakeLists.txt
浏览文件 @
a3396f8a
...
...
@@ -247,7 +247,7 @@ add_library(inference SHARED
${
CMAKE_CURRENT_SOURCE_DIR
}
/session/session.cc
${
LOAD_ONNX_SRC
}
)
target_link_libraries
(
inference PRIVATE
${
PYTHON_LIBRAR
Y
}
${
SECUREC_LIBRARY
}
target_link_libraries
(
inference PRIVATE
${
PYTHON_LIBRAR
IES
}
${
SECUREC_LIBRARY
}
-Wl,--whole-archive mindspore -Wl,--no-whole-archive mindspore_gvar mindspore::protobuf
)
if
(
ENABLE_CPU
)
...
...
serving/cpp_example/CMakeLists.txt
浏览文件 @
a3396f8a
...
...
@@ -2,9 +2,8 @@ cmake_minimum_required(VERSION 3.5.1)
project
(
HelloWorld C CXX
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-std=c++11"
)
add_compile_definitions
(
_GLIBCXX_USE_CXX11_ABI=0
)
find_package
(
Threads REQUIRED
)
...
...
@@ -69,4 +68,4 @@ foreach(_target
${
_REFLECTION
}
${
_GRPC_GRPCPP
}
${
_PROTOBUF_LIBPROTOBUF
}
)
endforeach
()
\ No newline at end of file
endforeach
()
serving/cpp_example/ms_client.cc
浏览文件 @
a3396f8a
...
...
@@ -15,8 +15,10 @@
*/
#include <grpcpp/grpcpp.h>
#include <iostream>
#include "serving/ms_service.grpc.pb.h"
#include <vector>
#include <string>
#include <fstream>
#include "./ms_service.grpc.pb.h"
using
grpc
::
Channel
;
using
grpc
::
ClientContext
;
...
...
@@ -27,26 +29,214 @@ using ms_serving::PredictRequest;
using
ms_serving
::
Tensor
;
using
ms_serving
::
TensorShape
;
class
MSClient
{
public:
explicit
MSClient
(
std
::
shared_ptr
<
Channel
>
channel
)
:
stub_
(
MSService
::
NewStub
(
channel
))
{}
enum
TypeId
:
int
{
kTypeUnknown
=
0
,
kMetaTypeBegin
=
kTypeUnknown
,
kMetaTypeType
,
// Type
kMetaTypeAnything
,
kMetaTypeObject
,
kMetaTypeTypeType
,
// TypeType
kMetaTypeProblem
,
kMetaTypeExternal
,
kMetaTypeNone
,
kMetaTypeNull
,
kMetaTypeEllipsis
,
kMetaTypeEnd
,
//
// Object types
//
kObjectTypeBegin
=
kMetaTypeEnd
,
kObjectTypeNumber
,
kObjectTypeString
,
kObjectTypeList
,
kObjectTypeTuple
,
kObjectTypeSlice
,
kObjectTypeKeyword
,
kObjectTypeTensorType
,
kObjectTypeClass
,
kObjectTypeDictionary
,
kObjectTypeFunction
,
kObjectTypeJTagged
,
kObjectTypeSymbolicKeyType
,
kObjectTypeEnvType
,
kObjectTypeRefKey
,
kObjectTypeRef
,
kObjectTypeEnd
,
//
// Number Types
//
kNumberTypeBegin
=
kObjectTypeEnd
,
kNumberTypeBool
,
kNumberTypeInt
,
kNumberTypeInt8
,
kNumberTypeInt16
,
kNumberTypeInt32
,
kNumberTypeInt64
,
kNumberTypeUInt
,
kNumberTypeUInt8
,
kNumberTypeUInt16
,
kNumberTypeUInt32
,
kNumberTypeUInt64
,
kNumberTypeFloat
,
kNumberTypeFloat16
,
kNumberTypeFloat32
,
kNumberTypeFloat64
,
kNumberTypeEnd
};
std
::
string
Predict
(
const
std
::
string
&
user
)
{
// Data we are sending to the server.
PredictRequest
request
;
std
::
string
RealPath
(
const
char
*
path
)
{
if
(
path
==
nullptr
)
{
std
::
cout
<<
"path is nullptr"
;
return
""
;
}
if
((
strlen
(
path
))
>=
PATH_MAX
)
{
std
::
cout
<<
"path is too long"
;
return
""
;
}
std
::
shared_ptr
<
char
>
resolvedPath
(
new
(
std
::
nothrow
)
char
[
PATH_MAX
]{
0
});
if
(
resolvedPath
==
nullptr
)
{
std
::
cout
<<
"new resolvedPath failed"
;
return
""
;
}
auto
ret
=
realpath
(
path
,
resolvedPath
.
get
());
if
(
ret
==
nullptr
)
{
std
::
cout
<<
"realpath failed"
;
return
""
;
}
return
resolvedPath
.
get
();
}
char
*
ReadFile
(
const
char
*
file
,
size_t
*
size
)
{
if
(
file
==
nullptr
)
{
std
::
cout
<<
"file is nullptr"
<<
std
::
endl
;
return
nullptr
;
}
if
(
size
==
nullptr
)
{
std
::
cout
<<
"size should not be nullptr"
<<
std
::
endl
;
return
nullptr
;
}
std
::
ifstream
ifs
(
RealPath
(
file
));
if
(
!
ifs
.
good
())
{
std
::
cout
<<
"file: "
<<
file
<<
"is not exist"
;
return
nullptr
;
}
if
(
!
ifs
.
is_open
())
{
std
::
cout
<<
"file: "
<<
file
<<
"open failed"
;
return
nullptr
;
}
ifs
.
seekg
(
0
,
std
::
ios
::
end
);
*
size
=
ifs
.
tellg
();
std
::
unique_ptr
<
char
>
buf
(
new
(
std
::
nothrow
)
char
[
*
size
]);
if
(
buf
==
nullptr
)
{
std
::
cout
<<
"malloc buf failed, file: "
<<
file
;
ifs
.
close
();
return
nullptr
;
}
ifs
.
seekg
(
0
,
std
::
ios
::
beg
);
ifs
.
read
(
buf
.
get
(),
*
size
);
ifs
.
close
();
return
buf
.
release
();
}
const
std
::
map
<
TypeId
,
ms_serving
::
DataType
>
id2type_map
{
{
TypeId
::
kNumberTypeBegin
,
ms_serving
::
MS_UNKNOWN
},
{
TypeId
::
kNumberTypeBool
,
ms_serving
::
MS_BOOL
},
{
TypeId
::
kNumberTypeInt8
,
ms_serving
::
MS_INT8
},
{
TypeId
::
kNumberTypeUInt8
,
ms_serving
::
MS_UINT8
},
{
TypeId
::
kNumberTypeInt16
,
ms_serving
::
MS_INT16
},
{
TypeId
::
kNumberTypeUInt16
,
ms_serving
::
MS_UINT16
},
{
TypeId
::
kNumberTypeInt32
,
ms_serving
::
MS_INT32
},
{
TypeId
::
kNumberTypeUInt32
,
ms_serving
::
MS_UINT32
},
{
TypeId
::
kNumberTypeInt64
,
ms_serving
::
MS_INT64
},
{
TypeId
::
kNumberTypeUInt64
,
ms_serving
::
MS_UINT64
},
{
TypeId
::
kNumberTypeFloat16
,
ms_serving
::
MS_FLOAT16
},
{
TypeId
::
kNumberTypeFloat32
,
ms_serving
::
MS_FLOAT32
},
{
TypeId
::
kNumberTypeFloat64
,
ms_serving
::
MS_FLOAT64
},
};
int
WriteFile
(
const
void
*
buf
,
size_t
size
)
{
auto
fd
=
fopen
(
"output.json"
,
"a+"
);
if
(
fd
==
NULL
)
{
std
::
cout
<<
"fd is null and open file fail"
<<
std
::
endl
;
return
0
;
}
fwrite
(
buf
,
size
,
1
,
fd
);
fclose
(
fd
);
return
0
;
}
PredictRequest
ReadBertInput
()
{
size_t
size
;
auto
buf
=
ReadFile
(
"input206.json"
,
&
size
);
if
(
buf
==
nullptr
)
{
std
::
cout
<<
"read file failed"
<<
std
::
endl
;
return
PredictRequest
();
}
PredictRequest
request
;
auto
cur
=
buf
;
while
(
size
>
0
)
{
if
(
request
.
data_size
()
==
4
)
{
break
;
}
Tensor
data
;
TensorShape
shape
;
shape
.
add_dims
(
1
);
shape
.
add_dims
(
1
);
shape
.
add_dims
(
2
);
shape
.
add_dims
(
2
);
// set type
int
type
=
*
(
reinterpret_cast
<
int
*>
(
cur
));
cur
=
cur
+
sizeof
(
int
);
size
=
size
-
sizeof
(
int
);
ms_serving
::
DataType
dataType
=
id2type_map
.
at
(
TypeId
(
type
));
data
.
set_tensor_type
(
dataType
);
// set shape
size_t
dims
=
*
(
reinterpret_cast
<
size_t
*>
(
cur
));
cur
=
cur
+
sizeof
(
size_t
);
size
=
size
-
sizeof
(
size_t
);
for
(
size_t
i
=
0
;
i
<
dims
;
i
++
)
{
int
dim
=
*
(
reinterpret_cast
<
int
*>
(
cur
));
shape
.
add_dims
(
dim
);
cur
=
cur
+
sizeof
(
int
);
size
=
size
-
sizeof
(
int
);
}
*
data
.
mutable_tensor_shape
()
=
shape
;
data
.
set_tensor_type
(
ms_serving
::
MS_FLOAT32
);
vector
<
float
>
input_data
{
1.1
,
2.1
,
3.1
,
4.1
};
data
.
set_data
(
input_data
.
data
(),
input_data
.
size
());
*
request
.
add_data
()
=
data
;
// set data
size_t
data_len
=
*
(
reinterpret_cast
<
size_t
*>
(
cur
));
cur
=
cur
+
sizeof
(
size_t
);
size
=
size
-
sizeof
(
size_t
);
data
.
set_data
(
cur
,
data_len
);
cur
=
cur
+
data_len
;
size
=
size
-
data_len
;
*
request
.
add_data
()
=
data
;
}
return
request
;
}
class
MSClient
{
public:
explicit
MSClient
(
std
::
shared_ptr
<
Channel
>
channel
)
:
stub_
(
MSService
::
NewStub
(
channel
))
{}
std
::
string
Predict
(
const
std
::
string
&
type
)
{
// Data we are sending to the server.
PredictRequest
request
;
if
(
type
==
"add"
)
{
Tensor
data
;
TensorShape
shape
;
shape
.
add_dims
(
1
);
shape
.
add_dims
(
1
);
shape
.
add_dims
(
2
);
shape
.
add_dims
(
2
);
*
data
.
mutable_tensor_shape
()
=
shape
;
data
.
set_tensor_type
(
ms_serving
::
MS_FLOAT32
);
std
::
vector
<
float
>
input_data
{
1.1
,
2.1
,
3.1
,
4.1
};
data
.
set_data
(
input_data
.
data
(),
input_data
.
size
());
*
request
.
add_data
()
=
data
;
*
request
.
add_data
()
=
data
;
}
else
if
(
type
==
"bert"
)
{
request
=
ReadBertInput
();
}
else
{
std
::
cout
<<
"type only support bert or add, but input is "
<<
type
<<
std
::
endl
;
}
std
::
cout
<<
"intput tensor size is "
<<
request
.
data_size
()
<<
std
::
endl
;
// Container for the data we expect from the server.
PredictReply
reply
;
...
...
@@ -57,6 +247,12 @@ class MSClient {
// The actual RPC.
Status
status
=
stub_
->
Predict
(
&
context
,
request
,
&
reply
);
for
(
int
i
=
0
;
i
<
reply
.
result_size
();
i
++
)
{
WriteFile
(
reply
.
result
(
i
).
data
().
data
(),
reply
.
result
(
i
).
data
().
size
());
}
std
::
cout
<<
"the return result size is "
<<
reply
.
result_size
()
<<
std
::
endl
;
// Act upon its status.
if
(
status
.
ok
())
{
return
"RPC OK"
;
...
...
@@ -77,28 +273,50 @@ int main(int argc, char **argv) {
// We indicate that the channel isn't authenticated (use of
// InsecureChannelCredentials()).
std
::
string
target_str
;
std
::
string
arg_str
(
"--target"
);
if
(
argc
>
1
)
{
std
::
string
arg_val
=
argv
[
1
];
size_t
start_pos
=
arg_val
.
find
(
arg_str
);
if
(
start_pos
!=
std
::
string
::
npos
)
{
start_pos
+=
arg_str
.
size
();
if
(
arg_val
[
start_pos
]
==
'='
)
{
target_str
=
arg_val
.
substr
(
start_pos
+
1
);
std
::
string
arg_target_str
(
"--target"
);
std
::
string
type
;
std
::
string
arg_type_str
(
"--type"
);
if
(
argc
>
2
)
{
{
// parse target
std
::
string
arg_val
=
argv
[
1
];
size_t
start_pos
=
arg_val
.
find
(
arg_target_str
);
if
(
start_pos
!=
std
::
string
::
npos
)
{
start_pos
+=
arg_target_str
.
size
();
if
(
arg_val
[
start_pos
]
==
'='
)
{
target_str
=
arg_val
.
substr
(
start_pos
+
1
);
}
else
{
std
::
cout
<<
"The only correct argument syntax is --target="
<<
std
::
endl
;
return
0
;
}
}
else
{
std
::
cout
<<
"The only correct argument syntax is --target="
<<
std
::
endl
;
return
0
;
target_str
=
"localhost:5500"
;
}
}
else
{
std
::
cout
<<
"The only acceptable argument is --target="
<<
std
::
endl
;
return
0
;
}
{
// parse type
std
::
string
arg_val2
=
argv
[
2
];
size_t
start_pos
=
arg_val2
.
find
(
arg_type_str
);
if
(
start_pos
!=
std
::
string
::
npos
)
{
start_pos
+=
arg_type_str
.
size
();
if
(
arg_val2
[
start_pos
]
==
'='
)
{
type
=
arg_val2
.
substr
(
start_pos
+
1
);
}
else
{
std
::
cout
<<
"The only correct argument syntax is --target="
<<
std
::
endl
;
return
0
;
}
}
else
{
type
=
"add"
;
}
}
}
else
{
target_str
=
"localhost:85010"
;
target_str
=
"localhost:5500"
;
type
=
"add"
;
}
MSClient
client
(
grpc
::
CreateChannel
(
target_str
,
grpc
::
InsecureChannelCredentials
()));
string
request
;
string
reply
=
client
.
Predict
(
request
);
std
::
string
reply
=
client
.
Predict
(
type
);
std
::
cout
<<
"client received: "
<<
reply
<<
std
::
endl
;
return
0
;
...
...
serving/cpp_example/ms_server.cc
浏览文件 @
a3396f8a
...
...
@@ -18,7 +18,7 @@
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <iostream>
#include "
serving
/ms_service.grpc.pb.h"
#include "
.
/ms_service.grpc.pb.h"
using
grpc
::
Server
;
using
grpc
::
ServerBuilder
;
...
...
@@ -31,7 +31,7 @@ using ms_serving::PredictRequest;
// Logic and data behind the server's behavior.
class
MSServiceImpl
final
:
public
MSService
::
Service
{
Status
Predict
(
ServerContext
*
context
,
const
PredictRequest
*
request
,
PredictReply
*
reply
)
override
{
cout
<<
"server eval"
<<
endl
;
std
::
cout
<<
"server eval"
<<
std
::
endl
;
return
Status
::
OK
;
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录