Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Serving
提交
311fddf5
S
Serving
项目概览
PaddlePaddle
/
Serving
大约 1 年 前同步成功
通知
186
Star
833
Fork
253
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
105
列表
看板
标记
里程碑
合并请求
10
Wiki
2
Wiki
分析
仓库
DevOps
项目成员
Pages
S
Serving
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
105
Issue
105
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
2
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
311fddf5
编写于
3月 28, 2019
作者:
W
wangguibao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix pdcodegen
上级
6f98ec3b
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
731 addition
and
0 deletion
+731
-0
pdcodegen/src/pdcodegen.cpp
pdcodegen/src/pdcodegen.cpp
+731
-0
未找到文件。
pdcodegen/src/pdcodegen.cpp
浏览文件 @
311fddf5
...
...
@@ -83,6 +83,735 @@ bool valid_service_method(const std::vector<const MethodDescriptor*>& methods) {
}
return
false
;
}
#ifdef BCLOUD
class
PdsCodeGenerator
:
public
CodeGenerator
{
public:
virtual
bool
Generate
(
const
FileDescriptor
*
file
,
const
string
&
parameter
,
GeneratorContext
*
context
,
std
::
string
*
error
)
const
{
const
string
header
=
strip_proto
(
file
->
name
())
+
".pb.h"
;
const
string
body
=
strip_proto
(
file
->
name
())
+
".pb.cc"
;
bool
include_inserted
=
false
;
for
(
int
i
=
0
;
i
<
file
->
service_count
();
++
i
)
{
const
ServiceDescriptor
*
descriptor
=
file
->
service
(
i
);
if
(
!
descriptor
)
{
*
error
=
"get descriptor failed"
;
return
false
;
}
pds
::
PaddleServiceOption
options
=
descriptor
->
options
().
GetExtension
(
pds
::
options
);
bool
generate_impl
=
options
.
generate_impl
();
bool
generate_stub
=
options
.
generate_stub
();
if
(
!
generate_impl
&&
!
generate_stub
)
{
return
true
;
}
if
(
!
include_inserted
)
{
boost
::
scoped_ptr
<
google
::
protobuf
::
io
::
ZeroCopyOutputStream
>
output
(
context
->
OpenForInsert
(
header
,
"includes"
));
google
::
protobuf
::
io
::
Printer
printer
(
output
.
get
(),
'$'
);
if
(
generate_impl
)
{
printer
.
Print
(
"#include
\"
predictor/common/inner_common.h
\"\n
"
);
printer
.
Print
(
"#include
\"
predictor/framework/service.h
\"\n
"
);
printer
.
Print
(
"#include
\"
predictor/framework/manager.h
\"\n
"
);
printer
.
Print
(
"#include
\"
predictor/framework/service_manager.h
\"\n
"
);
}
if
(
generate_stub
)
{
printer
.
Print
(
"#include <baidu/rpc/parallel_channel.h>
\n
"
);
printer
.
Print
(
"#include
\"
sdk-cpp/include/factory.h
\"\n
"
);
printer
.
Print
(
"#include
\"
sdk-cpp/include/stub.h
\"\n
"
);
printer
.
Print
(
"#include
\"
sdk-cpp/include/stub_impl.h
\"\n
"
);
}
include_inserted
=
true
;
}
const
std
::
string
&
class_name
=
descriptor
->
name
();
const
std
::
string
&
service_name
=
descriptor
->
name
();
// xxx.ph.h
{
if
(
generate_impl
)
{
// service scope
// namespace scope
boost
::
scoped_ptr
<
google
::
protobuf
::
io
::
ZeroCopyOutputStream
>
output
(
context
->
OpenForInsert
(
header
,
"namespace_scope"
));
google
::
protobuf
::
io
::
Printer
printer
(
output
.
get
(),
'$'
);
if
(
!
generate_paddle_serving_head
(
&
printer
,
descriptor
,
error
,
service_name
,
class_name
))
{
return
false
;
}
}
if
(
generate_stub
)
{
// service class scope
// namespace scope
{
boost
::
scoped_ptr
<
google
::
protobuf
::
io
::
ZeroCopyOutputStream
>
output
(
context
->
OpenForInsert
(
header
,
"namespace_scope"
));
google
::
protobuf
::
io
::
Printer
printer
(
output
.
get
(),
'$'
);
if
(
!
generate_paddle_serving_stub_head
(
&
printer
,
descriptor
,
error
,
service_name
,
class_name
))
{
return
false
;
}
}
}
}
// xxx.pb.cc
{
if
(
generate_impl
)
{
// service scope
// namespace scope
boost
::
scoped_ptr
<
google
::
protobuf
::
io
::
ZeroCopyOutputStream
>
output
(
context
->
OpenForInsert
(
body
,
"namespace_scope"
));
google
::
protobuf
::
io
::
Printer
printer
(
output
.
get
(),
'$'
);
if
(
!
generate_paddle_serving_body
(
&
printer
,
descriptor
,
error
,
service_name
,
class_name
))
{
return
false
;
}
}
if
(
generate_stub
)
{
// service class scope
{}
// namespace scope
{
boost
::
scoped_ptr
<
google
::
protobuf
::
io
::
ZeroCopyOutputStream
>
output
(
context
->
OpenForInsert
(
body
,
"namespace_scope"
));
google
::
protobuf
::
io
::
Printer
printer
(
output
.
get
(),
'$'
);
if
(
!
generate_paddle_serving_stub_body
(
&
printer
,
descriptor
,
error
,
service_name
,
class_name
))
{
return
false
;
}
}
}
}
}
return
true
;
}
private:
bool
generate_paddle_serving_head
(
google
::
protobuf
::
io
::
Printer
*
printer
,
const
ServiceDescriptor
*
descriptor
,
string
*
error
,
const
std
::
string
&
service_name
,
const
std
::
string
&
class_name
)
const
{
std
::
vector
<
const
MethodDescriptor
*>
methods
;
for
(
int
i
=
0
;
i
<
descriptor
->
method_count
();
++
i
)
{
methods
.
push_back
(
descriptor
->
method
(
i
));
}
if
(
!
valid_service_method
(
methods
))
{
*
error
=
"Service can only contains two methods: inferend, debug"
;
return
false
;
}
std
::
string
variable_name
=
class_name
;
string_format
(
variable_name
);
printer
->
Print
(
"class $name$Impl : public $name$ {
\n
"
"public:
\n
"
" virtual ~$name$Impl() {}
\n
"
" static $name$Impl& instance() {
\n
"
" return _s_$variable_name$_impl;
\n
"
" }
\n\n
"
" $name$Impl(const std::string& service_name) {
\n
"
" REGIST_FORMAT_SERVICE(
\n
"
" service_name, &$name$Impl::instance());
\n
"
" }
\n\n
"
,
"name"
,
class_name
,
"variable_name"
,
variable_name
);
for
(
int
i
=
0
;
i
<
methods
.
size
();
i
++
)
{
const
MethodDescriptor
*
m
=
methods
[
i
];
printer
->
Print
(
" virtual void $name$(google::protobuf::RpcController* cntl_base,
\n
"
" const $input_name$* request,
\n
"
" $output_name$* response,
\n
"
" google::protobuf::Closure* done);
\n\n
"
,
"name"
,
m
->
name
(),
"input_name"
,
google
::
protobuf
::
dots_to_colons
(
m
->
input_type
()
->
full_name
()),
"output_name"
,
google
::
protobuf
::
dots_to_colons
(
m
->
output_type
()
->
full_name
()));
}
printer
->
Print
(
" static $name$Impl _s_$variable_name$_impl;
\n
"
"};"
,
"name"
,
class_name
,
"variable_name"
,
variable_name
);
return
true
;
}
bool
generate_paddle_serving_body
(
google
::
protobuf
::
io
::
Printer
*
printer
,
const
ServiceDescriptor
*
descriptor
,
string
*
error
,
const
std
::
string
&
service_name
,
const
std
::
string
&
class_name
)
const
{
std
::
vector
<
const
MethodDescriptor
*>
methods
;
for
(
int
i
=
0
;
i
<
descriptor
->
method_count
();
++
i
)
{
methods
.
push_back
(
descriptor
->
method
(
i
));
}
if
(
!
valid_service_method
(
methods
))
{
*
error
=
"Service can only contains two methods: inferend, debug"
;
return
false
;
}
std
::
string
variable_name
=
class_name
;
string_format
(
variable_name
);
for
(
int
i
=
0
;
i
<
methods
.
size
();
i
++
)
{
const
MethodDescriptor
*
m
=
methods
[
i
];
printer
->
Print
(
"void $name$Impl::$method$(
\n
"
,
"name"
,
class_name
,
"method"
,
m
->
name
());
printer
->
Print
(
" google::protobuf::RpcController* cntl_base,
\n
"
" const $input_name$* request,
\n
"
" $output_name$* response,
\n
"
" google::protobuf::Closure* done) {
\n
"
" struct timeval tv;
\n
"
" gettimeofday(&tv, NULL);"
" long start = tv.tv_sec * 1000000 + tv.tv_usec;"
,
"input_name"
,
google
::
protobuf
::
dots_to_colons
(
m
->
input_type
()
->
full_name
()),
"output_name"
,
google
::
protobuf
::
dots_to_colons
(
m
->
output_type
()
->
full_name
()));
if
(
m
->
name
()
==
"inference"
)
{
printer
->
Print
(
" baidu::rpc::ClosureGuard done_guard(done);
\n
"
" baidu::rpc::Controller* cntl =
\n
"
" static_cast<baidu::rpc::Controller*>(cntl_base);
\n
"
" ::baidu::paddle_serving::predictor::InferService* svr =
\n
"
" "
"::baidu::paddle_serving::predictor::InferServiceManager::instance("
").item(
\"
$service$
\"
);
\n
"
" if (svr == NULL) {
\n
"
" LOG(ERROR) <<
\"
Not found service: $service$
\"
;
\n
"
" cntl->SetFailed(404,
\"
Not found service: $service$
\"
);
\n
"
" return ;
\n
"
" }
\n
"
" LOG(INFO) <<
\"
remote_side=\[
\"
<< cntl->remote_side() << "
// NOLINT
"
\"
\]
\"
;
\n
"
" LOG(INFO) <<
\"
local_side=\[
\"
<< cntl->local_side() << "
// NOLINT
"
\"
\]
\"
;
\n
"
" LOG(INFO) <<
\"
service_name=\[
\"
<<
\"
$name$
\"
<<
\"
\]
\"
;
\n
"
// NOLINT
" LOG(INFO) <<
\"
log_id=\[
\"
<< cntl->log_id() <<
\"
\]
\"
;
\n
"
// NOLINT
" int err_code = svr->inference(request, response);
\n
"
" if (err_code != 0) {
\n
"
" LOG(WARNING)
\n
"
" <<
\"
Failed call inferservice[$name$], name[$service$]
\"\n
"
" <<
\"
, error_code:
\"
<< err_code;
\n
"
" cntl->SetFailed(err_code,
\"
InferService inference "
"failed!
\"
);
\n
"
" }
\n
"
" gettimeofday(&tv, NULL);
\n
"
" long end = tv.tv_sec * 1000000 + tv.tv_usec;
\n
"
" // flush notice log
\n
"
" LOG(INFO) <<
\"
tc=\[
\"
<< (end - start) <<
\"
\]
\"
;
\n
"
,
// NOLINT
"name"
,
class_name
,
"service"
,
service_name
);
}
if
(
m
->
name
()
==
"debug"
)
{
printer
->
Print
(
" baidu::rpc::ClosureGuard done_guard(done);
\n
"
" baidu::rpc::Controller* cntl =
\n
"
" static_cast<baidu::rpc::Controller*>(cntl_base);
\n
"
" ::baidu::paddle_serving::predictor::InferService* svr =
\n
"
" "
"::baidu::paddle_serving::predictor::InferServiceManager::instance("
").item(
\"
$service$
\"
);
\n
"
" if (svr == NULL) {
\n
"
" LOG(ERROR) <<
\"
Not found service: $service$
\"
;
\n
"
" cntl->SetFailed(404,
\"
Not found service: $service$
\"
);
\n
"
" return ;
\n
"
" }
\n
"
" LOG(INFO) <<
\"
remote_side=\[
\"
<< cntl->remote_side() << "
// NOLINT
"
\"
\]
\"
;
\n
"
" LOG(INFO) <<
\"
local_side=\[
\"
<< cntl->local_side() << "
// NOLINT
"
\"
\]
\"
;
\n
"
" LOG(INFO) <<
\"
service_name=\[
\"
<<
\"
$name$
\"
<<
\"
\]
\"
;
\n
"
// NOLINT
" LOG(INFO) <<
\"
log_id=\[
\"
<< cntl->log_id() <<
\"
\]
\"
;
\n
"
// NOLINT
" butil::IOBufBuilder debug_os;
\n
"
" int err_code = svr->inference(request, response, &debug_os);
\n
"
" if (err_code != 0) {
\n
"
" LOG(WARNING)
\n
"
" <<
\"
Failed call inferservice[$name$], name[$service$]
\"\n
"
" <<
\"
, error_code:
\"
<< err_code;
\n
"
" cntl->SetFailed(err_code,
\"
InferService inference "
"failed!
\"
);
\n
"
" }
\n
"
" debug_os.move_to(cntl->response_attachment());
\n
"
" gettimeofday(&tv, NULL);
\n
"
" long end = tv.tv_sec * 1000000 + tv.tv_usec;
\n
"
" // flush notice log
\n
"
" LOG(INFO) <<
\"
tc=\[
\"
<< (end - start) <<
\"
\]
\"
;
\n
"
// NOLINT
" LOG(INFO)
\n
"
" <<
\"
TC=[
\"
<< (end - start) <<
\"
] Received debug "
"request[log_id=
\"
<< cntl->log_id()
\n
"
" <<
\"
] from
\"
<< cntl->remote_side()
\n
"
" <<
\"
to
\"
<< cntl->local_side();
\n
"
,
"name"
,
class_name
,
"service"
,
service_name
);
}
printer
->
Print
(
"}
\n
"
);
}
printer
->
Print
(
"$name$Impl $name$Impl::_s_$variable_name$_impl(
\"
$service$
\"
);
\n
"
,
"name"
,
class_name
,
"variable_name"
,
variable_name
,
"service"
,
service_name
);
return
true
;
}
bool
generate_paddle_serving_stub_head
(
google
::
protobuf
::
io
::
Printer
*
printer
,
const
ServiceDescriptor
*
descriptor
,
string
*
error
,
const
std
::
string
&
service_name
,
const
std
::
string
&
class_name
)
const
{
printer
->
Print
(
"class $name$_StubCallMapper : public baidu::rpc::CallMapper {
\n
"
"private:
\n
"
" uint32_t _package_size;
\n
"
" baidu::paddle_serving::sdk_cpp::Stub* _stub_handler;
\n
"
"public:
\n
"
,
"name"
,
class_name
);
printer
->
Indent
();
printer
->
Print
(
"$name$_StubCallMapper(uint32_t package_size, "
"baidu::paddle_serving::sdk_cpp::Stub* stub) {
\n
"
" _package_size = package_size;
\n
"
" _stub_handler = stub;
\n
"
"}
\n
"
,
"name"
,
class_name
);
printer
->
Print
(
"baidu::rpc::SubCall default_map(
\n
"
" int channel_index,
\n
"
" const google::protobuf::MethodDescriptor* method,
\n
"
" const google::protobuf::Message* request,
\n
"
" google::protobuf::Message* response) {
\n
"
" baidu::paddle_serving::sdk_cpp::TracePackScope "
"scope(
\"
default_map
\"
, channel_index);"
,
"name"
,
class_name
);
printer
->
Indent
();
if
(
!
generate_paddle_serving_stub_default_map
(
printer
,
descriptor
,
error
,
service_name
,
class_name
))
{
return
false
;
}
printer
->
Outdent
();
printer
->
Print
(
"}
\n
"
);
printer
->
Print
(
"baidu::rpc::SubCall sub_package_map(
\n
"
" int channel_index,
\n
"
" const google::protobuf::MethodDescriptor* method,
\n
"
" const google::protobuf::Message* request,
\n
"
" google::protobuf::Message* response) {
\n
"
" baidu::paddle_serving::sdk_cpp::TracePackScope scope(
\"
sub_map
\"
, "
"channel_index);"
,
"name"
,
class_name
);
printer
->
Indent
();
std
::
vector
<
const
FieldDescriptor
*>
in_shared_fields
;
std
::
vector
<
const
FieldDescriptor
*>
in_item_fields
;
const
MethodDescriptor
*
md
=
descriptor
->
FindMethodByName
(
"inference"
);
if
(
!
md
)
{
*
error
=
"not found inference method!"
;
return
false
;
}
for
(
int
i
=
0
;
i
<
md
->
input_type
()
->
field_count
();
++
i
)
{
const
FieldDescriptor
*
fd
=
md
->
input_type
()
->
field
(
i
);
if
(
!
fd
)
{
*
error
=
"invalid fd at: "
+
i
;
return
false
;
}
bool
pack_on
=
fd
->
options
().
GetExtension
(
pds
::
pack_on
);
if
(
pack_on
&&
!
fd
->
is_repeated
())
{
*
error
=
"Pack fields must be repeated, field: "
+
fd
->
name
();
return
false
;
}
if
(
pack_on
)
{
in_item_fields
.
push_back
(
fd
);
}
else
{
in_shared_fields
.
push_back
(
fd
);
}
}
if
(
!
generate_paddle_serving_stub_package_map
(
printer
,
descriptor
,
error
,
service_name
,
class_name
,
in_shared_fields
,
in_item_fields
))
{
return
false
;
}
printer
->
Outdent
();
printer
->
Print
(
"}
\n
"
);
printer
->
Print
(
"baidu::rpc::SubCall Map(
\n
"
" int channel_index,
\n
"
" const google::protobuf::MethodDescriptor* method,
\n
"
" const google::protobuf::Message* request,
\n
"
" google::protobuf::Message* response) {
\n
"
,
"name"
,
class_name
);
printer
->
Indent
();
if
(
in_item_fields
.
size
()
<=
0
)
{
printer
->
Print
(
"// No packed items found in proto file, use default map method
\n
"
"return default_map(channel_index, method, request, response);
\n
"
);
}
else
{
printer
->
Print
(
"butil::Timer tt(butil::Timer::STARTED);
\n
"
"baidu::rpc::SubCall ret;
\n
"
"if (_package_size == 0) {
\n
"
" ret = default_map(channel_index, method, request, response);
\n
"
"} else {
\n
"
" ret = sub_package_map(channel_index, method, request, "
"response);
\n
"
"}
\n
"
"tt.stop();
\n
"
"if (ret.flags != baidu::rpc::SKIP_SUB_CHANNEL && ret.method != NULL) {
\n
"
" _stub_handler->update_latency(tt.u_elapsed(),
\"
pack_map
\"
);
\n
"
"}
\n
"
"return ret;
\n
"
);
}
printer
->
Outdent
();
printer
->
Print
(
"}
\n
"
);
printer
->
Outdent
();
printer
->
Print
(
"};
\n
"
);
////////////////////////////////////////////////////////////////
printer
->
Print
(
"class $name$_StubResponseMerger : public baidu::rpc::ResponseMerger {
\n
"
"private:
\n
"
" uint32_t _package_size;
\n
"
" baidu::paddle_serving::sdk_cpp::Stub* _stub_handler;
\n
"
"public:
\n
"
,
"name"
,
class_name
);
printer
->
Indent
();
printer
->
Print
(
"$name$_StubResponseMerger(uint32_t package_size, "
"baidu::paddle_serving::sdk_cpp::Stub* stub) {
\n
"
" _package_size = package_size;
\n
"
" _stub_handler = stub;
\n
"
"}
\n
"
,
"name"
,
class_name
);
printer
->
Print
(
"baidu::rpc::ResponseMerger::Result default_merge(
\n
"
" google::protobuf::Message* response,
\n
"
" const google::protobuf::Message* sub_response) {
\n
"
" baidu::paddle_serving::sdk_cpp::TracePackScope "
"scope(
\"
default_merge
\"
);"
,
"name"
,
class_name
);
printer
->
Indent
();
if
(
!
generate_paddle_serving_stub_default_merger
(
printer
,
descriptor
,
error
,
service_name
,
class_name
))
{
return
false
;
}
printer
->
Outdent
();
printer
->
Print
(
"}
\n
"
);
printer
->
Print
(
"baidu::rpc::ResponseMerger::Result sub_package_merge(
\n
"
" google::protobuf::Message* response,
\n
"
" const google::protobuf::Message* sub_response) {
\n
"
" baidu::paddle_serving::sdk_cpp::TracePackScope "
"scope(
\"
sub_merge
\"
);"
,
"name"
,
class_name
);
printer
->
Indent
();
if
(
!
generate_paddle_serving_stub_package_merger
(
printer
,
descriptor
,
error
,
service_name
,
class_name
))
{
return
false
;
}
printer
->
Outdent
();
printer
->
Print
(
"}
\n
"
);
printer
->
Print
(
"baidu::rpc::ResponseMerger::Result Merge(
\n
"
" google::protobuf::Message* response,
\n
"
" const google::protobuf::Message* sub_response) {
\n
"
,
"name"
,
class_name
);
printer
->
Indent
();
printer
->
Print
(
"butil::Timer tt(butil::Timer::STARTED);
\n
"
"baidu::rpc::ResponseMerger::Result ret;"
"if (_package_size <= 0) {
\n
"
" ret = default_merge(response, sub_response);
\n
"
"} else {
\n
"
" ret = sub_package_merge(response, sub_response);
\n
"
"}
\n
"
"tt.stop();
\n
"
"if (ret != baidu::rpc::ResponseMerger::FAIL) {
\n
"
" _stub_handler->update_latency(tt.u_elapsed(),
\"
pack_merge
\"
);
\n
"
"}
\n
"
"return ret;
\n
"
);
printer
->
Outdent
();
printer
->
Print
(
"}
\n
"
);
printer
->
Outdent
();
printer
->
Print
(
"};
\n
"
);
return
true
;
}
bool
generate_paddle_serving_stub_default_map
(
google
::
protobuf
::
io
::
Printer
*
printer
,
const
ServiceDescriptor
*
descriptor
,
string
*
error
,
const
std
::
string
&
service_name
,
const
std
::
string
&
class_name
)
const
{
printer
->
Print
(
"if (channel_index > 0) {
\n
"
" return baidu::rpc::SubCall::Skip();
\n
"
"}
\n
"
);
printer
->
Print
(
"google::protobuf::Message* cur_res = "
"_stub_handler->fetch_response();
\n
"
"if (cur_res == NULL) {
\n
"
" LOG(INFO) <<
\"
Failed fetch response from stub handler, new it
\"
;
\n
"
" cur_res = response->New();
\n
"
" if (cur_res == NULL) {
\n
"
" LOG(ERROR) <<
\"
Failed new response item!
\"
;
\n
"
" _stub_handler->update_average(1,
\"
pack_fail
\"
);
\n
"
" return baidu::rpc::SubCall::Bad();
\n
"
" }
\n
"
" return baidu::rpc::SubCall(method, request, cur_res, "
"baidu::rpc::DELETE_RESPONSE);
\n
"
"}
\n
"
);
"LOG(INFO)
\n
"
" <<
\"
[default] Succ map, channel_index:
\"
<< channel_index;
\n
"
;
printer
->
Print
(
"return baidu::rpc::SubCall(method, request, cur_res, 0);
\n
"
);
return
true
;
}
bool
generate_paddle_serving_stub_default_merger
(
google
::
protobuf
::
io
::
Printer
*
printer
,
const
ServiceDescriptor
*
descriptor
,
string
*
error
,
const
std
::
string
&
service_name
,
const
std
::
string
&
class_name
)
const
{
printer
->
Print
(
"try {
\n
"
" response->MergeFrom(*sub_response);
\n
"
" return baidu::rpc::ResponseMerger::MERGED;
\n
"
"} catch (const std::exception& e) {
\n
"
" LOG(ERROR) <<
\"
Merge failed.
\"
;
\n
"
" _stub_handler->update_average(1,
\"
pack_fail
\"
);
\n
"
" return baidu::rpc::ResponseMerger::FAIL;
\n
"
"}
\n
"
);
return
true
;
}
bool
generate_paddle_serving_stub_package_map
(
google
::
protobuf
::
io
::
Printer
*
printer
,
const
ServiceDescriptor
*
descriptor
,
string
*
error
,
const
std
::
string
&
service_name
,
const
std
::
string
&
class_name
,
std
::
vector
<
const
FieldDescriptor
*>&
in_shared_fields
,
// NOLINT
std
::
vector
<
const
FieldDescriptor
*>&
in_item_fields
)
const
{
// NOLINT
const
MethodDescriptor
*
md
=
descriptor
->
FindMethodByName
(
"inference"
);
if
(
!
md
)
{
*
error
=
"not found inference method!"
;
return
false
;
}
printer
->
Print
(
"const $req_type$* req
\n
"
" = dynamic_cast<const $req_type$*>(request);
\n
"
"$req_type$* sub_req = NULL;"
,
"req_type"
,
google
::
protobuf
::
dots_to_colons
(
md
->
input_type
()
->
full_name
()));
// 1. pack fields 逐字段计算index范围,并从req copy值sub_req
printer
->
Print
(
"
\n
// 1. 样本字段(必须为repeated类型)按指定下标复制
\n
"
);
for
(
uint32_t
ii
=
0
;
ii
<
in_item_fields
.
size
();
ii
++
)
{
const
FieldDescriptor
*
fd
=
in_item_fields
[
ii
];
std
::
string
field_name
=
fd
->
name
();
printer
->
Print
(
"
\n
/////$field_name$
\n
"
,
"field_name"
,
field_name
);
if
(
ii
==
0
)
{
printer
->
Print
(
"uint32_t total_size = req->$field_name$_size();
\n
"
"if (channel_index == 0) {
\n
"
" _stub_handler->update_average(total_size,
\"
item_size
\"
);
\n
"
"}
\n
"
,
"field_name"
,
field_name
);
printer
->
Print
(
"int start = _package_size * channel_index;
\n
"
"if (start >= total_size) {
\n
"
" return baidu::rpc::SubCall::Skip();
\n
"
"}
\n
"
"int end = _package_size * (channel_index + 1);
\n
"
"if (end > total_size) {
\n
"
" end = total_size;
\n
"
"}
\n
"
);
printer
->
Print
(
"sub_req = "
"dynamic_cast<$req_type$*>(_stub_handler->fetch_request());
\n
"
"if (sub_req == NULL) {
\n
"
" LOG(ERROR) <<
\"
failed fetch sub_req from stub.
\"
;
\n
"
" _stub_handler->update_average(1,
\"
pack_fail
\"
);
\n
"
" return baidu::rpc::SubCall::Bad();
\n
"
"}
\n
"
,
"name"
,
class_name
,
"req_type"
,
google
::
protobuf
::
dots_to_colons
(
md
->
input_type
()
->
full_name
()));
}
else
{
printer
->
Print
(
"if (req->$field_name$_size() != total_size) {
\n
"
" LOG(ERROR) <<
\"
pack field size not consistency:
\"\n
"
" << total_size <<
\"
!=
\"
<< "
"req->$field_name$_size()
\n
"
" <<
\"
, field: $field_name$.
\"
;
\n
"
" _stub_handler->update_average(1,
\"
pack_fail
\"
);
\n
"
" return baidu::rpc::SubCall::Bad();
\n
"
"}
\n
"
,
"field_name"
,
field_name
);
}
printer
->
Print
(
"for (uint32_t i = start; i < end; ++i) {
\n
"
);
printer
->
Indent
();
if
(
fd
->
cpp_type
()
==
google
::
protobuf
::
FieldDescriptor
::
CPPTYPE_MESSAGE
)
{
printer
->
Print
(
"sub_req->add_$field_name$()->CopyFrom(req->$field_name$(i));
\n
"
,
"field_name"
,
field_name
);
}
else
{
printer
->
Print
(
"sub_req->add_$field_name$(req->$field_name$(i));
\n
"
,
"field_name"
,
field_name
);
}
printer
->
Outdent
();
printer
->
Print
(
"}
\n
"
);
}
// 2. shared fields逐字段从req copy至sub_req
printer
->
Print
(
"
\n
// 2. 共享字段,从req逐个复制到sub_req
\n
"
);
if
(
in_item_fields
.
size
()
==
0
)
{
printer
->
Print
(
"if (sub_req == NULL) { // no packed items
\n
"
" sub_req = "
"dynamic_cast<$req_type$*>(_stub_handler->fetch_request());
\n
"
" if (!sub_req) {
\n
"
" LOG(ERROR) <<
\"
failed fetch sub_req from stub handler.
\"
;
\n
"
" _stub_handler->update_average(1,
\"
pack_fail
\"
);
\n
"
" return baidu::rpc::SubCall::Bad();
\n
"
" }
\n
"
"}
\n
"
,
"req_type"
,
google
::
protobuf
::
dots_to_colons
(
md
->
input_type
()
->
full_name
()));
}
for
(
uint32_t
si
=
0
;
si
<
in_shared_fields
.
size
();
si
++
)
{
const
FieldDescriptor
*
fd
=
in_shared_fields
[
si
];
std
::
string
field_name
=
fd
->
name
();
printer
->
Print
(
"
\n
/////$field_name$
\n
"
,
"field_name"
,
field_name
);
if
(
fd
->
is_optional
())
{
printer
->
Print
(
"if (req->has_$field_name$()) {
\n
"
,
"field_name"
,
field_name
);
printer
->
Indent
();
}
if
(
fd
->
cpp_type
()
==
google
::
protobuf
::
FieldDescriptor
::
CPPTYPE_MESSAGE
||
fd
->
is_repeated
())
{
printer
->
Print
(
"sub_req->mutable_$field_name$()->CopyFrom(req->$field_name$());
\n
"
,
"field_name"
,
field_name
);
}
else
{
printer
->
Print
(
"sub_req->set_$field_name$(req->$field_name$());
\n
"
,
"field_name"
,
field_name
);
}
if
(
fd
->
is_optional
())
{
printer
->
Outdent
();
printer
->
Print
(
"}
\n
"
);
}
}
printer
->
Print
(
"LOG(INFO)
\n
"
" <<
\"
[pack] Succ map req at:
\"\n
"
" << channel_index;
\n
"
);
printer
->
Print
(
"google::protobuf::Message* sub_res = "
"_stub_handler->fetch_response();
\n
"
"if (sub_res == NULL) {
\n
"
" LOG(ERROR) <<
\"
failed create sub_res from res.
\"
;
\n
"
" _stub_handler->update_average(1,
\"
pack_fail
\"
);
\n
"
" return baidu::rpc::SubCall::Bad();
\n
"
"}
\n
"
"return baidu::rpc::SubCall(method, sub_req, sub_res, 0);
\n
"
);
return
true
;
}
bool
generate_paddle_serving_stub_package_merger
(
google
::
protobuf
::
io
::
Printer
*
printer
,
const
ServiceDescriptor
*
descriptor
,
string
*
error
,
const
std
::
string
&
service_name
,
const
std
::
string
&
class_name
)
const
{
return
generate_paddle_serving_stub_default_merger
(
printer
,
descriptor
,
error
,
service_name
,
class_name
);
}
bool
generate_paddle_serving_stub_body
(
google
::
protobuf
::
io
::
Printer
*
printer
,
const
ServiceDescriptor
*
descriptor
,
string
*
error
,
const
std
::
string
&
service_name
,
const
std
::
string
&
class_name
)
const
{
std
::
vector
<
const
MethodDescriptor
*>
methods
;
for
(
int
i
=
0
;
i
<
descriptor
->
method_count
();
++
i
)
{
methods
.
push_back
(
descriptor
->
method
(
i
));
}
if
(
!
valid_service_method
(
methods
))
{
*
error
=
"Service can only contains two methods: inferend, debug"
;
return
false
;
}
const
MethodDescriptor
*
md
=
methods
[
0
];
std
::
map
<
string
,
string
>
variables
;
variables
[
"name"
]
=
class_name
;
variables
[
"req_type"
]
=
google
::
protobuf
::
dots_to_colons
(
md
->
input_type
()
->
full_name
());
variables
[
"res_type"
]
=
google
::
protobuf
::
dots_to_colons
(
md
->
output_type
()
->
full_name
());
variables
[
"fullname"
]
=
descriptor
->
full_name
();
printer
->
Print
(
variables
,
"REGIST_STUB_OBJECT_WITH_TAG(
\n
"
" $name$_Stub,
\n
"
" $name$_StubCallMapper,
\n
"
" $name$_StubResponseMerger,
\n
"
" $req_type$,
\n
"
" $res_type$,
\n
"
"
\"
$fullname$
\"
);
\n
"
);
variables
.
clear
();
return
true
;
}
};
#else // #ifdef BCLOUD
class
PdsCodeGenerator
:
public
CodeGenerator
{
public:
virtual
bool
Generate
(
const
FileDescriptor
*
file
,
...
...
@@ -809,6 +1538,8 @@ class PdsCodeGenerator : public CodeGenerator {
return
true
;
}
};
#endif // #ifdef BCLOUD
int
main
(
int
argc
,
char
**
argv
)
{
PdsCodeGenerator
generator
;
return
google
::
protobuf
::
compiler
::
PluginMain
(
argc
,
argv
,
&
generator
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录