Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2d57158e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2d57158e
编写于
4月 25, 2018
作者:
Y
Yan Chunwei
提交者:
GitHub
4月 25, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fea/init tensorrt engine (#10003)
上级
64babc9a
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
515 addition
and
10 deletion
+515
-10
paddle/fluid/inference/engine.h
paddle/fluid/inference/engine.h
+53
-0
paddle/fluid/inference/tensorrt/CMakeLists.txt
paddle/fluid/inference/tensorrt/CMakeLists.txt
+4
-1
paddle/fluid/inference/tensorrt/engine.cc
paddle/fluid/inference/tensorrt/engine.cc
+134
-0
paddle/fluid/inference/tensorrt/engine.h
paddle/fluid/inference/tensorrt/engine.h
+144
-0
paddle/fluid/inference/tensorrt/helper.h
paddle/fluid/inference/tensorrt/helper.h
+88
-0
paddle/fluid/inference/tensorrt/test_engine.cc
paddle/fluid/inference/tensorrt/test_engine.cc
+83
-0
paddle/fluid/inference/tensorrt/test_tensorrt.cc
paddle/fluid/inference/tensorrt/test_tensorrt.cc
+9
-9
未找到文件。
paddle/fluid/inference/engine.h
0 → 100644
浏览文件 @
2d57158e
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/framework.pb.h"
namespace
paddle
{
namespace
inference
{
/*
* EngineBase is the base class of all inference engines. An inference engine
* takes a paddle program as input, and outputs the result in fluid Tensor
* format. It can be used to optimize performance of computation sub-blocks, for
* example, break down the original block into sub-blocks and execute each
* sub-blocks in different engines.
*
* For example:
* When inference, the resnet50 model can put most of the model into subgraph
* and run it on a TensorRT engine.
*
* There are several engines such as TensorRT and other frameworks, so an
* EngineBase is put forward to give an unified interface for all the
* different engine implemention.
*/
class
EngineBase
{
public:
using
DescType
=
::
paddle
::
framework
::
proto
::
BlockDesc
;
// Build the model and do some preparation, for example, in TensorRT, run
// createInferBuilder, buildCudaEngine.
virtual
void
Build
(
const
DescType
&
paddle_model
)
=
0
;
// Execute the engine, that will run the inference network.
virtual
void
Execute
(
int
batch_size
)
=
0
;
virtual
~
EngineBase
()
{}
};
// class EngineBase
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/CMakeLists.txt
浏览文件 @
2d57158e
nv_test
(
test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader
)
if
(
WITH_TESTING
)
nv_test
(
test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader
)
nv_test
(
test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda
)
endif
()
paddle/fluid/inference/tensorrt/engine.cc
0 → 100644
浏览文件 @
2d57158e
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/engine.h"
#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
void
TensorRTEngine
::
Build
(
const
DescType
&
paddle_model
)
{
PADDLE_ENFORCE
(
false
,
"not implemented"
);
}
void
TensorRTEngine
::
Execute
(
int
batch_size
)
{
infer_context_
->
enqueue
(
batch_size
,
buffers_
.
data
(),
*
stream_
,
nullptr
);
cudaStreamSynchronize
(
*
stream_
);
}
TensorRTEngine
::~
TensorRTEngine
()
{
// clean buffer
for
(
auto
&
buffer
:
buffers_
)
{
if
(
buffer
!=
nullptr
)
{
PADDLE_ENFORCE_EQ
(
0
,
cudaFree
(
buffer
));
buffer
=
nullptr
;
}
}
}
void
TensorRTEngine
::
FreezeNetwork
()
{
PADDLE_ENFORCE
(
infer_builder_
!=
nullptr
,
"Call InitNetwork first to initialize network."
);
PADDLE_ENFORCE
(
infer_network_
!=
nullptr
,
"Call InitNetwork first to initialize network."
);
// build engine.
infer_builder_
->
setMaxBatchSize
(
max_batch_
);
infer_builder_
->
setMaxWorkspaceSize
(
max_workspace_
);
infer_engine_
.
reset
(
infer_builder_
->
buildCudaEngine
(
*
infer_network_
));
PADDLE_ENFORCE
(
infer_engine_
!=
nullptr
,
"build cuda engine failed!"
);
infer_context_
.
reset
(
infer_engine_
->
createExecutionContext
());
// allocate GPU buffers.
buffers_
.
resize
(
buffer_sizes_
.
size
(),
nullptr
);
for
(
auto
&
item
:
buffer_sizes_
)
{
if
(
item
.
second
==
0
)
{
auto
slot_offset
=
infer_engine_
->
getBindingIndex
(
item
.
first
.
c_str
());
item
.
second
=
kDataTypeSize
[
static_cast
<
int
>
(
infer_engine_
->
getBindingDataType
(
slot_offset
))]
*
AccumDims
(
infer_engine_
->
getBindingDimensions
(
slot_offset
));
}
PADDLE_ENFORCE_EQ
(
0
,
cudaMalloc
(
&
buffer
(
item
.
first
),
item
.
second
));
}
}
nvinfer1
::
ITensor
*
TensorRTEngine
::
DeclareInput
(
const
std
::
string
&
name
,
nvinfer1
::
DataType
dtype
,
const
nvinfer1
::
Dims
&
dim
)
{
PADDLE_ENFORCE_EQ
(
0
,
buffer_sizes_
.
count
(
name
),
"duplicate input name %s"
,
name
);
PADDLE_ENFORCE
(
infer_network_
!=
nullptr
,
"should initnetwork first"
);
auto
*
input
=
infer_network_
->
addInput
(
name
.
c_str
(),
dtype
,
dim
);
PADDLE_ENFORCE
(
input
,
"infer network add input %s failed"
,
name
);
buffer_sizes_
[
name
]
=
kDataTypeSize
[
static_cast
<
int
>
(
dtype
)]
*
AccumDims
(
dim
);
return
input
;
}
void
TensorRTEngine
::
DeclareOutput
(
const
nvinfer1
::
ILayer
*
layer
,
int
offset
,
const
std
::
string
&
name
)
{
PADDLE_ENFORCE_EQ
(
0
,
buffer_sizes_
.
count
(
name
),
"duplicate output name %s"
,
name
);
auto
*
output
=
layer
->
getOutput
(
offset
);
PADDLE_ENFORCE
(
output
!=
nullptr
);
output
->
setName
(
name
.
c_str
());
infer_network_
->
markOutput
(
*
output
);
// output buffers' size can only be decided latter, set zero here to mark this
// and will reset latter.
buffer_sizes_
[
name
]
=
0
;
}
void
*
TensorRTEngine
::
GetOutputInGPU
(
const
std
::
string
&
name
)
{
return
buffer
(
name
);
}
void
TensorRTEngine
::
GetOutputInCPU
(
const
std
::
string
&
name
,
void
*
dst
,
size_t
max_size
)
{
// determine data size
auto
it
=
buffer_sizes_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
buffer_sizes_
.
end
());
PADDLE_ENFORCE_GT
(
it
->
second
,
0
);
PADDLE_ENFORCE_GE
(
max_size
,
it
->
second
);
PADDLE_ENFORCE_EQ
(
0
,
cudaMemcpyAsync
(
dst
,
buffer
(
name
),
it
->
second
,
cudaMemcpyDeviceToHost
,
*
stream_
));
}
void
*&
TensorRTEngine
::
buffer
(
const
std
::
string
&
name
)
{
PADDLE_ENFORCE
(
infer_engine_
!=
nullptr
,
"call FreezeNetwork first."
);
auto
it
=
buffer_sizes_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
buffer_sizes_
.
end
());
auto
slot_offset
=
infer_engine_
->
getBindingIndex
(
name
.
c_str
());
return
buffers_
[
slot_offset
];
}
void
TensorRTEngine
::
SetInputFromCPU
(
const
std
::
string
&
name
,
void
*
data
,
size_t
size
)
{
void
*
buf
=
buffer
(
name
);
PADDLE_ENFORCE_EQ
(
0
,
cudaMemcpyAsync
(
buf
,
data
,
size
,
cudaMemcpyHostToDevice
,
*
stream_
));
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/engine.h
0 → 100644
浏览文件 @
2d57158e
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <NvInfer.h>
#include <memory>
#include <unordered_map>
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
/*
* TensorRT Engine.
*
* There are two alternative ways to use it, one is to build from a paddle
* protobuf model, another way is to manully construct the network.
*/
class
TensorRTEngine
:
public
EngineBase
{
public:
// Weight is model parameter.
class
Weight
{
public:
Weight
(
nvinfer1
::
DataType
dtype
,
void
*
value
,
int
num_elem
)
{
w_
.
type
=
dtype
;
w_
.
values
=
value
;
w_
.
count
=
num_elem
;
}
const
nvinfer1
::
Weights
&
get
()
{
return
w_
;
}
private:
nvinfer1
::
Weights
w_
;
};
TensorRTEngine
(
int
max_batch
,
int
max_workspace
,
cudaStream_t
*
stream
,
nvinfer1
::
ILogger
&
logger
=
NaiveLogger
::
Global
())
:
max_batch_
(
max_batch
),
max_workspace_
(
max_workspace
),
stream_
(
stream
),
logger_
(
logger
)
{}
virtual
~
TensorRTEngine
();
// TODO(Superjomn) implement it later when graph segmentation is supported.
virtual
void
Build
(
const
DescType
&
paddle_model
)
override
;
virtual
void
Execute
(
int
batch_size
)
override
;
// Initialize the inference network, so that TensorRT layers can add to this
// network.
void
InitNetwork
()
{
infer_builder_
.
reset
(
createInferBuilder
(
logger_
));
infer_network_
.
reset
(
infer_builder_
->
createNetwork
());
}
// After finishing adding ops, freeze this network and creates the executation
// environment.
void
FreezeNetwork
();
// Add an input and set its name, data type and dimention.
nvinfer1
::
ITensor
*
DeclareInput
(
const
std
::
string
&
name
,
nvinfer1
::
DataType
dtype
,
const
nvinfer1
::
Dims
&
dim
);
// Set the offset-th output from a layer as the network's output, and set its
// name.
void
DeclareOutput
(
const
nvinfer1
::
ILayer
*
layer
,
int
offset
,
const
std
::
string
&
name
);
// GPU memory address for an ITensor with specific name. One can operate on
// these memory directly for acceleration, for example, output the converted
// data directly to the buffer to save data copy overhead.
// NOTE this should be used after calling `FreezeNetwork`.
void
*&
buffer
(
const
std
::
string
&
name
);
// Fill an input from CPU memory with name and size.
void
SetInputFromCPU
(
const
std
::
string
&
name
,
void
*
data
,
size_t
size
);
// TODO(Superjomn) is this method necessary given that buffer(xxx) can be
// accessed directly. Fill an input from GPU memory with name and size.
void
SetInputFromGPU
(
const
std
::
string
&
name
,
void
*
data
,
size_t
size
);
// Get an output called name, the output of tensorrt is in GPU, so this method
// will just return the output's GPU memory address.
void
*
GetOutputInGPU
(
const
std
::
string
&
name
);
// LOW EFFICENCY! Get output to CPU, this will trigger a memory copy from GPU
// to CPU.
void
GetOutputInCPU
(
const
std
::
string
&
name
,
void
*
dst
,
size_t
max_size
);
nvinfer1
::
ICudaEngine
*
engine
()
{
return
infer_engine_
.
get
();
}
nvinfer1
::
INetworkDefinition
*
network
()
{
return
infer_network_
.
get
();
}
private:
// the max batch size
int
max_batch_
;
// the max memory size the engine uses
int
max_workspace_
;
cudaStream_t
*
stream_
;
nvinfer1
::
ILogger
&
logger_
;
std
::
vector
<
void
*>
buffers_
;
// max data size for the buffers.
std
::
unordered_map
<
std
::
string
/*name*/
,
size_t
/*max size*/
>
buffer_sizes_
;
// TensorRT related internal members
template
<
typename
T
>
struct
Destroyer
{
void
operator
()(
T
*
x
)
{
x
->
destroy
();
}
};
template
<
typename
T
>
using
infer_ptr
=
std
::
unique_ptr
<
T
,
Destroyer
<
T
>>
;
infer_ptr
<
nvinfer1
::
IBuilder
>
infer_builder_
;
infer_ptr
<
nvinfer1
::
INetworkDefinition
>
infer_network_
;
infer_ptr
<
nvinfer1
::
ICudaEngine
>
infer_engine_
;
infer_ptr
<
nvinfer1
::
IExecutionContext
>
infer_context_
;
};
// class TensorRTEngine
// Add an layer__ into engine__ with args ARGS.
// For example:
// TRT_ENGINE_ADD_LAYER(xxx, FullyConnected, input, dim, weights, bias)
//
// Reference
// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#charRNN_define_network
//
// will add a fully connected layer into the engine.
// TensorRT has too many layers, so that is not wise to add member functions for
// them, and an macro like this is more extensible when underlying TensorRT
// library add new layer supports.
#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
engine__->network()->add##layer__(ARGS);
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/helper.h
0 → 100644
浏览文件 @
2d57158e
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
#include "paddle/fluid/platform/dynload/tensorrt.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
namespace
dy
=
paddle
::
platform
::
dynload
;
static
size_t
AccumDims
(
nvinfer1
::
Dims
dims
)
{
size_t
num
=
dims
.
nbDims
==
0
?
0
:
1
;
for
(
int
i
=
0
;
i
<
dims
.
nbDims
;
i
++
)
{
PADDLE_ENFORCE_GT
(
dims
.
d
[
i
],
0
);
num
*=
dims
.
d
[
i
];
}
return
num
;
}
// TensorRT data type to size
const
int
kDataTypeSize
[]
=
{
4
,
// kFLOAT
2
,
// kHALF
1
,
// kINT8
4
// kINT32
};
// The following two API are implemented in TensorRT's header file, cannot load
// from the dynamic library. So create our own implementation and directly
// trigger the method from the dynamic library.
static
nvinfer1
::
IBuilder
*
createInferBuilder
(
nvinfer1
::
ILogger
&
logger
)
{
return
static_cast
<
nvinfer1
::
IBuilder
*>
(
dy
::
createInferBuilder_INTERNAL
(
&
logger
,
NV_TENSORRT_VERSION
));
}
static
nvinfer1
::
IRuntime
*
createInferRuntime
(
nvinfer1
::
ILogger
&
logger
)
{
return
static_cast
<
nvinfer1
::
IRuntime
*>
(
dy
::
createInferRuntime_INTERNAL
(
&
logger
,
NV_TENSORRT_VERSION
));
}
// A logger for create TensorRT infer builder.
class
NaiveLogger
:
public
nvinfer1
::
ILogger
{
public:
void
log
(
nvinfer1
::
ILogger
::
Severity
severity
,
const
char
*
msg
)
override
{
switch
(
severity
)
{
case
Severity
::
kINFO
:
LOG
(
INFO
)
<<
msg
;
break
;
case
Severity
::
kWARNING
:
LOG
(
WARNING
)
<<
msg
;
break
;
case
Severity
::
kINTERNAL_ERROR
:
case
Severity
::
kERROR
:
LOG
(
ERROR
)
<<
msg
;
break
;
default:
break
;
}
}
static
nvinfer1
::
ILogger
&
Global
()
{
static
nvinfer1
::
ILogger
*
x
=
new
NaiveLogger
;
return
*
x
;
}
virtual
~
NaiveLogger
()
override
{}
};
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/test_engine.cc
0 → 100644
浏览文件 @
2d57158e
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/engine.h"
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
inference
{
namespace
tensorrt
{
class
TensorRTEngineTest
:
public
::
testing
::
Test
{
protected:
void
SetUp
()
override
{
ASSERT_EQ
(
0
,
cudaStreamCreate
(
&
stream_
));
engine_
=
new
TensorRTEngine
(
1
,
1
<<
10
,
&
stream_
);
engine_
->
InitNetwork
();
}
void
TearDown
()
override
{
delete
engine_
;
cudaStreamDestroy
(
stream_
);
}
protected:
TensorRTEngine
*
engine_
;
cudaStream_t
stream_
;
};
TEST_F
(
TensorRTEngineTest
,
add_layer
)
{
const
int
size
=
1
;
float
raw_weight
[
size
]
=
{
2.
};
// Weight in CPU memory.
float
raw_bias
[
size
]
=
{
3.
};
LOG
(
INFO
)
<<
"create weights"
;
TensorRTEngine
::
Weight
weight
(
nvinfer1
::
DataType
::
kFLOAT
,
raw_weight
,
size
);
TensorRTEngine
::
Weight
bias
(
nvinfer1
::
DataType
::
kFLOAT
,
raw_bias
,
size
);
auto
*
x
=
engine_
->
DeclareInput
(
"x"
,
nvinfer1
::
DataType
::
kFLOAT
,
nvinfer1
::
DimsCHW
{
1
,
1
,
1
});
auto
*
fc_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
FullyConnected
,
*
x
,
size
,
weight
.
get
(),
bias
.
get
());
PADDLE_ENFORCE
(
fc_layer
!=
nullptr
);
engine_
->
DeclareOutput
(
fc_layer
,
0
,
"y"
);
LOG
(
INFO
)
<<
"freeze network"
;
engine_
->
FreezeNetwork
();
ASSERT_EQ
(
engine_
->
engine
()
->
getNbBindings
(),
2
);
// fill in real data
float
x_v
=
1234
;
engine_
->
SetInputFromCPU
(
"x"
,
(
void
*
)
&
x_v
,
1
*
sizeof
(
float
));
LOG
(
INFO
)
<<
"to execute"
;
engine_
->
Execute
(
1
);
LOG
(
INFO
)
<<
"to get output"
;
// void* y_v =
float
y_cpu
;
engine_
->
GetOutputInCPU
(
"y"
,
&
y_cpu
,
sizeof
(
float
));
LOG
(
INFO
)
<<
"to checkout output"
;
ASSERT_EQ
(
y_cpu
,
x_v
*
2
+
3
);
}
}
// namespace tensorrt
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/tensorrt/test_tensorrt.cc
浏览文件 @
2d57158e
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <glog/logging.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录