Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
53a7d38b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
53a7d38b
编写于
6月 14, 2022
作者:
S
Shang Zhizhou
提交者:
GitHub
6月 14, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add jetson tool (#43486)
上级
22e75d92
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
764 addition
and
0 deletion
+764
-0
tools/infer_prune_patches/analysis_predictor.cc.patch
tools/infer_prune_patches/analysis_predictor.cc.patch
+29
-0
tools/infer_prune_patches/analyzer.cc.patch
tools/infer_prune_patches/analyzer.cc.patch
+14
-0
tools/infer_prune_patches/device_context.cc.patch
tools/infer_prune_patches/device_context.cc.patch
+46
-0
tools/infer_prune_patches/jitcode.h.patch
tools/infer_prune_patches/jitcode.h.patch
+15
-0
tools/infer_prune_patches/op_registry.h.patch
tools/infer_prune_patches/op_registry.h.patch
+215
-0
tools/infer_prune_patches/paddle_analysis_config.h.patch
tools/infer_prune_patches/paddle_analysis_config.h.patch
+21
-0
tools/infer_prune_patches/paddle_api.h.patch
tools/infer_prune_patches/paddle_api.h.patch
+12
-0
tools/infer_prune_patches/paddle_inference_api.h.patch
tools/infer_prune_patches/paddle_inference_api.h.patch
+16
-0
tools/infer_prune_patches/phi_cmake.patch
tools/infer_prune_patches/phi_cmake.patch
+13
-0
tools/infer_prune_patches/tensorrt_subgraph_pass.cc.patch
tools/infer_prune_patches/tensorrt_subgraph_pass.cc.patch
+68
-0
tools/infer_prune_patches/thread_local_allocator.cc.patch
tools/infer_prune_patches/thread_local_allocator.cc.patch
+95
-0
tools/infer_prune_patches/thread_local_allocator.h.patch
tools/infer_prune_patches/thread_local_allocator.h.patch
+30
-0
tools/prune_for_jetson.py
tools/prune_for_jetson.py
+190
-0
未找到文件。
tools/infer_prune_patches/analysis_predictor.cc.patch
0 → 100644
浏览文件 @
53a7d38b
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 14e4c3da62..dbf1c78c41 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1732,7 +1732,7 @@
USE_TRT_CONVERTER(shuffle_channel);
USE_TRT_CONVERTER(swish);
USE_TRT_CONVERTER(group_norm);
USE_TRT_CONVERTER(instance_norm);
-USE_TRT_CONVERTER(layer_norm);
+//USE_TRT_CONVERTER(layer_norm);
USE_TRT_CONVERTER(gelu);
USE_TRT_CONVERTER(multihead_matmul);
USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
@@ -1742,11 +1742,11 @@
USE_TRT_CONVERTER(scale);
USE_TRT_CONVERTER(stack);
USE_TRT_CONVERTER(clip);
USE_TRT_CONVERTER(gather);
-USE_TRT_CONVERTER(anchor_generator);
+//USE_TRT_CONVERTER(anchor_generator);
USE_TRT_CONVERTER(yolo_box);
-USE_TRT_CONVERTER(roi_align);
-USE_TRT_CONVERTER(affine_channel);
-USE_TRT_CONVERTER(multiclass_nms);
+//USE_TRT_CONVERTER(roi_align);
+//USE_TRT_CONVERTER(affine_channel);
+//USE_TRT_CONVERTER(multiclass_nms);
USE_TRT_CONVERTER(multiclass_nms3);
USE_TRT_CONVERTER(nearest_interp);
USE_TRT_CONVERTER(nearest_interp_v2);
tools/infer_prune_patches/analyzer.cc.patch
0 → 100644
浏览文件 @
53a7d38b
diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc
index be7d6ab868..498e09cb4d 100644
--- a/paddle/fluid/inference/analysis/analyzer.cc
+++ b/paddle/fluid/inference/analysis/analyzer.cc
@@ -32,6 +32,9 @@
void Analyzer::RunAnalysis(Argument *argument) {
"analsis_passes is not valid in the argument."));
const bool disable_logs = argument->disable_logs();
for (auto &pass : argument->analysis_passes()) {
+ if (pass == "ir_params_sync_among_devices_pass") {
+ continue;
+ }
if (!disable_logs) {
string::PrettyLogH1("--- Running analysis [%s]", pass);
}
tools/infer_prune_patches/device_context.cc.patch
0 → 100644
浏览文件 @
53a7d38b
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 904e4854ba..4f8c955d8c 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -466,15 +466,15 @@
CUDAContext::CUDAContext(const CUDAPlace& place,
place_ = place;
CUDADeviceGuard guard(place_.device);
stream_.reset(new stream::CUDAStream(place, priority, flag));
- InitEigenContext();
- InitCuBlasContext();
- InitCuDNNContext();
+ //InitEigenContext();
+ //InitCuBlasContext();
+ //InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
- InitCuBlasLtContext();
+ //InitCuBlasLtContext();
#endif
- InitCuSparseContext();
- InitCuSolverContext();
+ //InitCuSparseContext();
+ //InitCuSolverContext();
#endif
}
@@ -506,14 +506,14 @@
void CUDAContext::SetStream(gpuStream_t stream) {
CUDAContext::~CUDAContext() {
CUDADeviceGuard guard(place_.device);
- DestoryCuDNNContext();
- DestoryCuBlasContext();
+ //DestoryCuDNNContext();
+ //DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
- InitCuBlasLtContext();
+ //InitCuBlasLtContext();
#endif
- DestoryCuSparseContext();
- DestoryCuSolverContext();
+ //DestoryCuSparseContext();
+ //DestoryCuSolverContext();
#endif
}
tools/infer_prune_patches/jitcode.h.patch
0 → 100644
浏览文件 @
53a7d38b
diff --git a/paddle/fluid/operators/jit/gen/jitcode.h b/paddle/fluid/operators/jit/gen/jitcode.h
index 23650c8efc..24466e4327 100644
--- a/paddle/fluid/operators/jit/gen/jitcode.h
+++ b/paddle/fluid/operators/jit/gen/jitcode.h
@@ -97,8 +97,8 @@
class JitCode : public GenBase, public Xbyak::CodeGenerator {
}
ret();
}
- void L(const char* label) { Xbyak::CodeGenerator::L(label); }
- void L(Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); } // NOLINT
+ void L(const char* label) { }
+ void L(Xbyak::Label& label) { } // NOLINT
// Enhanced vector extension
Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt,
bool bcast = false) {
tools/infer_prune_patches/op_registry.h.patch
0 → 100644
浏览文件 @
53a7d38b
diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h
index a1f07f9f25..179df3b981 100644
--- a/paddle/fluid/framework/op_registry.h
+++ b/paddle/fluid/framework/op_registry.h
@@ -178,9 +178,8 @@
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
RegisterKernelClass<PlaceType, T>(
op_type, library_type, customized_type_value,
- [op_type](const framework::ExecutionContext& ctx) {
+ [](const framework::ExecutionContext& ctx) {
KERNEL_TYPE().Compute(ctx);
- CheckKernelLaunch<PlaceType>(op_type);
});
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
@@ -240,13 +239,8 @@
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
void operator()(const char* op_type, const char* library_type,
int customized_type_value) const {
- RegisterKernelClass<PlaceType, T>(
- op_type, library_type, customized_type_value,
-
- [op_type](const framework::ExecutionContext& ctx) {
- Functor()(ctx);
- CheckKernelLaunch<PlaceType>(op_type);
- });
+ RegisterKernelClass<PlaceType, T>(op_type, library_type,
+ customized_type_value, Functor());
constexpr auto size =
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
@@ -275,7 +269,7 @@
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
VarTypeInference
InferShapeBase
*/
-#define REGISTER_OPERATOR(op_type, op_class, ...) \
+#define REGISTER_OPERATOR__(op_type, op_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, \
"REGISTER_OPERATOR must be called in global namespace"); \
@@ -286,15 +280,22 @@
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
return 0; \
}
-#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, ...) \
+#define REGISTER_OPERATOR(op_type, op_class, ...)
+
+#define REGISTER_OP_WITHOUT_GRADIENT__(op_type, op_class, ...) \
REGISTER_OPERATOR(op_type, op_class, __VA_ARGS__, \
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)
+#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, ...)
/**
* Macro to register OperatorKernel.
*/
#define REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(op_type, library_type, \
+ place_class, customized_name, \
+ customized_type_value, ...)
+
+#define REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE__(op_type, library_type, \
place_class, customized_name, \
customized_type_value, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
@@ -311,18 +312,22 @@
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
return 0; \
}
-#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \
- REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( \
+#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...)
+
+#define REGISTER_OP_KERNEL__(op_type, library_type, place_class, ...) \
+ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE__( \
op_type, library_type, place_class, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#define REGISTER_OP_CUDA_KERNEL__(op_type, ...) \
+ REGISTER_OP_KERNEL__(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
+
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
-#else
-#define REGISTER_OP_CUDA_KERNEL(op_type, ...)
-#endif
+
+#define REGISTER_OP_CPU_KERNEL__(op_type, ...) \
+ REGISTER_OP_KERNEL__(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
@@ -340,6 +345,11 @@
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
REGISTER_OP_KERNEL(op_type, MLU, ::paddle::platform::MLUPlace, __VA_ARGS__)
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, \
+ customized_name, \
+ customized_type_value, \
+ ...)
+
+#define REGISTER_OP_KERNEL_EX__(op_type, library_type, place_class, \
customized_name, \
customized_type_value, \
...) \
@@ -357,8 +367,10 @@
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
return 0; \
}
-#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \
- REGISTER_OP_KERNEL_EX( \
+#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...)
+
+#define REGISTER_OP_CUDA_KERNEL_FUNCTOR__(op_type, ...) \
+ REGISTER_OP_KERNEL_EX__( \
op_type, CUDA, ::paddle::platform::CUDAPlace, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
@@ -375,12 +387,6 @@
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
-#define REGISTER_OP_NPU_KERNEL_FUNCTOR(op_type, ...) \
- REGISTER_OP_KERNEL_EX( \
- op_type, NPU, ::paddle::platform::NPUPlace, DEFAULT_TYPE, \
- ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
- __VA_ARGS__)
-
#define REGISTER_OP_MLU_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX( \
op_type, MLU, ::paddle::platform::MLUPlace, DEFAULT_TYPE, \
@@ -392,7 +398,9 @@
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
* we will use and tell the compiler to
* link them into target.
*/
-#define USE_OP_ITSELF(op_type) \
+#define USE_OP_ITSELF(op_type)
+
+#define USE_OP_ITSELF__(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_itself_##op_type, \
"USE_OP_ITSELF must be called in global namespace"); \
@@ -400,6 +408,10 @@
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
UNUSED static int use_op_itself_##op_type##_ = TouchOpRegistrar_##op_type()
#define USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, \
+ LIBRARY_TYPE, \
+ customized_name)
+
+#define USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE__(op_type, \
LIBRARY_TYPE, \
customized_name) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
@@ -410,33 +422,58 @@
struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_##customized_name##_ = /* NOLINT */ \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name()
-#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \
- USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, LIBRARY_TYPE, DEFAULT_TYPE)
+#define USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, \
+ LIBRARY_TYPE, \
+ customized_name)
+
+#define USE_OP_DEVICE_KERNEL__(op_type, LIBRARY_TYPE) \
+ USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE__(op_type, LIBRARY_TYPE, DEFAULT_TYPE)
+
+#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE)
// TODO(fengjiayi): The following macros
// seems ugly, do we have better method?
-#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
+#ifndef PADDLE_WITH_CUDA
#define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU)
+#define USE_OP_KERNEL__(op_type) USE_OP_DEVICE_KERNEL__(op_type, CPU)
#else
#define USE_OP_KERNEL(op_type) \
USE_OP_DEVICE_KERNEL(op_type, CPU); \
USE_OP_DEVICE_KERNEL(op_type, CUDA)
+
+#define USE_OP_KERNEL__(op_type) \
+ USE_OP_DEVICE_KERNEL__(op_type, CPU); \
+ USE_OP_DEVICE_KERNEL__(op_type, CUDA)
#endif
#define USE_NO_KERNEL_OP(op_type) USE_OP_ITSELF(op_type);
+#define USE_NO_KERNEL_OP__(op_type) USE_OP_ITSELF__(op_type);
+
#define USE_CPU_ONLY_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CPU);
+#define USE_CPU_ONLY_OP__(op_type) \
+ USE_OP_ITSELF__(op_type); \
+ USE_OP_DEVICE_KERNEL__(op_type, CPU);
+
#define USE_CUDA_ONLY_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CUDA)
+#define USE_CUDA_ONLY_OP__(op_type) \
+ USE_OP_ITSELF__(op_type); \
+ USE_OP_DEVICE_KERNEL__(op_type, CUDA)
+
#define USE_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type)
+
+#define USE_OP__(op_type) \
+ USE_OP_ITSELF__(op_type); \
+ USE_OP_KERNEL__(op_type)
// clang-format on
} // namespace framework
tools/infer_prune_patches/paddle_analysis_config.h.patch
0 → 100644
浏览文件 @
53a7d38b
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index d6a0b643c2..511844b482 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -46,6 +46,7 @@
namespace paddle {
class AnalysisPredictor;
+class TensorRTPredictor;
struct MkldnnQuantizerConfig;
struct LiteNNAdapterConfig {
@@ -700,6 +701,8 @@
struct PD_INFER_DECL AnalysisConfig {
friend class ::paddle::AnalysisPredictor;
+ friend class ::paddle::TensorRTPredictor;
+
///
/// \brief Get a pass builder for customize the passes in IR analysis phase.
/// NOTE: Just for developer, not an official API, easy to be broken.
tools/infer_prune_patches/paddle_api.h.patch
0 → 100644
浏览文件 @
53a7d38b
diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h
index 0f8f9e0a97..a7efd9975a 100644
--- a/paddle/fluid/inference/api/paddle_api.h
+++ b/paddle/fluid/inference/api/paddle_api.h
@@ -193,6 +193,7 @@
class PD_INFER_DECL ZeroCopyTensor : public paddle_infer::Tensor {
private:
friend class AnalysisPredictor;
+ friend class TensorRTPredictor;
friend class ONNXRuntimePredictor;
explicit ZeroCopyTensor(void* scope) : paddle_infer::Tensor{scope} {}
};
tools/infer_prune_patches/paddle_inference_api.h.patch
0 → 100644
浏览文件 @
53a7d38b
diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h
index 35b90bfa54..ba8220d06a 100644
--- a/paddle/fluid/inference/api/paddle_inference_api.h
+++ b/paddle/fluid/inference/api/paddle_inference_api.h
@@ -41,6 +41,11 @@
limitations under the License. */
/// \since 2.0.0-beta
///
+namespace paddle {
+std::unique_ptr<PaddlePredictor> CreateTensorRTPredictor(
+ const AnalysisConfig& config);
+}
+
namespace paddle_infer {
using PrecisionType = paddle::AnalysisConfig::Precision;
tools/infer_prune_patches/phi_cmake.patch
0 → 100644
浏览文件 @
53a7d38b
diff --git a/paddle/phi/CMakeLists.txt b/paddle/phi/CMakeLists.txt
index 58ad42ddd1..8ffdafcf0d 100644
--- a/paddle/phi/CMakeLists.txt
+++ b/paddle/phi/CMakeLists.txt
@@ -18,7 +18,7 @@
add_subdirectory(infermeta)
# phi operator definitions
add_subdirectory(ops)
# phi tools
-add_subdirectory(tools)
+#add_subdirectory(tools)
# phi tests
add_subdirectory(tests)
tools/infer_prune_patches/tensorrt_subgraph_pass.cc.patch
0 → 100644
浏览文件 @
53a7d38b
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index e4fc52b6fa..24b6f73949 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -384,6 +384,7 @@
void TensorRtSubgraphPass::CreateTensorRTOp(
(graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass)));
+ std::unordered_set<std::string> param_set(params.begin(), params.end());
if (use_static_engine) {
trt_engine_serialized_data = GetTrtEngineSerializedData(
Get<std::string>("model_opt_cache_dir"), engine_key);
@@ -393,6 +394,19 @@
void TensorRtSubgraphPass::CreateTensorRTOp(
LOG(INFO) << "Load TRT Optimized Info from "
<< GetTrtEngineSerializedPath(
Get<std::string>("model_opt_cache_dir"), engine_key);
+ const auto* root_scope{param_scope()};
+ for (;root_scope->parent();) {
+ root_scope = root_scope->parent();
+ }
+ for (const auto& name: param_set) {
+ LOG(INFO) << " ===== Clear param: " << name;
+ root_scope->FindLocalVar(name)->Clear();
+ }
+ for (int dev_id = 0; dev_id < paddle::platform::GetGPUDeviceCount();
+ ++dev_id) {
+ memory::Release(platform::CUDAPlace(dev_id));
+ }
+ memory::Release(platform::CPUPlace());
return;
}
}
@@ -405,12 +419,25 @@
void TensorRtSubgraphPass::CreateTensorRTOp(
auto *scope = param_scope();
framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto());
- std::unordered_set<std::string> param_set(params.begin(), params.end());
inference::Singleton<inference::tensorrt::OpConverter>::Global()
.ConvertBlockToTRTEngine(
&block_desc_temp, *scope,
std::vector<std::string>(input_names.begin(), input_names.end()),
param_set, output_mapping, trt_engine);
+ const auto* root_scope{scope};
+ for (;root_scope->parent();) {
+ root_scope = root_scope->parent();
+ }
+ VLOG(4) << "root_scope->LocalVarNames().size: " << root_scope->LocalVarNames().size();
+ for (const auto& name: param_set) {
+ VLOG(4) << " ===== Clear param: " << name;
+ root_scope->FindLocalVar(name)->Clear();
+ }
+ for (int dev_id = 0; dev_id < paddle::platform::GetGPUDeviceCount();
+ ++dev_id) {
+ memory::Release(platform::CUDAPlace(dev_id));
+ }
+ memory::Release(platform::CPUPlace());
if (use_static_engine) {
nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize();
@@ -425,6 +452,8 @@
void TensorRtSubgraphPass::CreateTensorRTOp(
<< GetTrtEngineSerializedPath(
Get<std::string>("model_opt_cache_dir"), engine_key);
}
+ trt_engine_serialized_data.clear();
+ trt_engine_serialized_data.shrink_to_fit();
}
} // namespace analysis
tools/infer_prune_patches/thread_local_allocator.cc.patch
0 → 100644
浏览文件 @
53a7d38b
diff --git a/paddle/fluid/memory/allocation/thread_local_allocator.cc b/paddle/fluid/memory/allocation/thread_local_allocator.cc
index f125670a59..f858a30301 100644
--- a/paddle/fluid/memory/allocation/thread_local_allocator.cc
+++ b/paddle/fluid/memory/allocation/thread_local_allocator.cc
@@ -13,18 +13,62 @@
// limitations under the License.
#include "paddle/fluid/memory/allocation/thread_local_allocator.h"
+#include "paddle/fluid/platform/cuda_device_guard.h"
namespace paddle {
namespace memory {
namespace allocation {
+const int MALLOC_ALIGN = 64;
+
+#define CUDA_CALL(func) \
+ { \
+ auto e = (func); \
+ CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
+ << "CUDA: " << cudaGetErrorString(e); \
+ }
+
+void* DirectAllocator::Alloc(size_t unaligned_size) {
+ if (platform::is_cpu_place(place_)) {
+ size_t offset = sizeof(void*) + MALLOC_ALIGN - 1;
+ char* p = static_cast<char*>(std::malloc(offset + unaligned_size));
+ // Memory checking
+ CHECK(p) << "Error occurred in malloc period: available space is not enough "
+ "for mallocing "
+ << unaligned_size << " bytes.";
+ // Byte alignment
+ void* r = reinterpret_cast<void*>(reinterpret_cast<size_t>(p + offset) &
+ (~(MALLOC_ALIGN - 1)));
+ static_cast<void**>(r)[-1] = p;
+ return r;
+ } else if (platform::is_gpu_place(place_)) {
+ int dev_id = place_.GetDeviceId();
+ platform::CUDADeviceGuard guard(dev_id);
+ void* ptr{};
+ CUDA_CALL(cudaMalloc(&ptr, unaligned_size));
+ return ptr;
+ }
+ return nullptr;
+}
+
+void DirectAllocator::Free(void* ptr) {
+ if (platform::is_cpu_place(place_)) {
+ if (ptr) {
+ std::free(static_cast<void**>(ptr)[-1]);
+ }
+ } else if (platform::is_gpu_place(place_)) {
+ int dev_id = place_.GetDeviceId();
+ platform::CUDADeviceGuard guard(dev_id);
+ CUDA_CALL(cudaFree(ptr));
+ }
+}
+
+
+
ThreadLocalAllocatorImpl::ThreadLocalAllocatorImpl(const platform::Place& p)
: place_(p) {
if (platform::is_gpu_place(place_)) {
- buddy_allocator_.reset(new memory::detail::BuddyAllocator(
- std::unique_ptr<memory::detail::SystemAllocator>(
- new memory::detail::GPUAllocator(place_.device)),
- platform::GpuMinChunkSize(), platform::GpuMaxChunkSize()));
+ direct_allocator_.reset(new DirectAllocator{place_});
} else {
PADDLE_THROW(platform::errors::Unavailable(
"Thread local allocator only supports CUDAPlace now."));
@@ -59,7 +103,7 @@
ThreadLocalCUDAAllocatorPool::ThreadLocalCUDAAllocatorPool()
ThreadLocalAllocation* ThreadLocalAllocatorImpl::AllocateImpl(size_t size) {
VLOG(10) << "ThreadLocalAllocatorImpl::AllocateImpl " << size;
- void* ptr = buddy_allocator_->Alloc(size);
+ void* ptr = direct_allocator_->Alloc(size);
auto* tl_allocation = new ThreadLocalAllocation(ptr, size, place_);
tl_allocation->SetThreadLocalAllocatorImpl(shared_from_this());
return tl_allocation;
@@ -67,12 +111,12 @@
ThreadLocalAllocation* ThreadLocalAllocatorImpl::AllocateImpl(size_t size) {
void ThreadLocalAllocatorImpl::FreeImpl(ThreadLocalAllocation* allocation) {
VLOG(10) << "ThreadLocalAllocatorImpl::FreeImpl " << allocation;
- buddy_allocator_->Free(allocation->ptr());
+ direct_allocator_->Free(allocation->ptr());
delete allocation;
}
uint64_t ThreadLocalAllocatorImpl::ReleaseImpl() {
- return buddy_allocator_->Release();
+ return direct_allocator_->Release();
}
} // namespace allocation
tools/infer_prune_patches/thread_local_allocator.h.patch
0 → 100644
浏览文件 @
53a7d38b
diff --git a/paddle/fluid/memory/allocation/thread_local_allocator.h b/paddle/fluid/memory/allocation/thread_local_allocator.h
index 654fb3fe7b..44c5dbf87f 100644
--- a/paddle/fluid/memory/allocation/thread_local_allocator.h
+++ b/paddle/fluid/memory/allocation/thread_local_allocator.h
@@ -26,6 +26,16 @@
namespace paddle {
namespace memory {
namespace allocation {
+class DirectAllocator {
+public:
+ DirectAllocator(const platform::Place& place) : place_{place} {}
+ void* Alloc(size_t unaligned_size);
+ void Free(void* ptr);
+ uint64_t Release() { return 0;}
+private:
+ platform::Place place_;
+};
+
class ThreadLocalAllocatorImpl;
class ThreadLocalAllocation : public Allocation {
@@ -55,7 +65,7 @@
class ThreadLocalAllocatorImpl
uint64_t ReleaseImpl();
private:
- std::unique_ptr<memory::detail::BuddyAllocator> buddy_allocator_;
+ std::unique_ptr<DirectAllocator> direct_allocator_;
platform::Place place_;
};
tools/prune_for_jetson.py
0 → 100644
浏览文件 @
53a7d38b
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script simply removes all grad ops and kernels. You should use this script
when cmake ON_INFER=ON, which can greatly reduce the volume of the prediction library.
"""
import
os
import
sys
import
re
import
glob
import
io
def
find_type_files
(
cur_dir
,
file_type
,
file_list
=
[]):
next_level_dirs
=
os
.
listdir
(
cur_dir
)
for
next_level_name
in
next_level_dirs
:
next_level_dir
=
os
.
path
.
join
(
cur_dir
,
next_level_name
)
if
os
.
path
.
isfile
(
next_level_dir
):
if
os
.
path
.
splitext
(
next_level_dir
)[
1
]
==
file_type
:
file_list
.
append
(
next_level_dir
)
elif
os
.
path
.
isdir
(
next_level_dir
):
find_type_files
(
next_level_dir
,
file_type
,
file_list
)
return
file_list
def
find_kernel
(
content
,
pattern
):
res
=
re
.
findall
(
pattern
,
content
,
flags
=
re
.
DOTALL
)
ret
=
[]
for
p
in
res
:
left
,
right
=
0
,
0
for
c
in
p
:
if
c
==
'{'
:
left
+=
1
elif
c
==
'}'
:
right
+=
1
if
left
==
right
:
ret
.
append
(
p
)
return
ret
,
len
(
ret
)
def
prune_phi_kernels
():
tool_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
if
sys
.
version_info
[
0
]
==
3
:
all_op
=
glob
.
glob
(
os
.
path
.
join
(
tool_dir
,
'../paddle/phi/kernels/**/*.cc'
),
recursive
=
True
)
all_op
+=
glob
.
glob
(
os
.
path
.
join
(
tool_dir
,
'../paddle/phi/kernels/**/*.cu'
),
recursive
=
True
)
elif
sys
.
version_info
[
0
]
==
2
:
all_op
=
find_type_files
(
os
.
path
.
join
(
tool_dir
,
'../paddle/phi/kernels/'
),
'.cc'
)
all_op
=
find_type_files
(
os
.
path
.
join
(
tool_dir
,
'../paddle/phi/kernels/'
),
'.cu'
,
all_op
)
register_op_count
=
0
for
op_file
in
all_op
:
need_continue
=
False
file_blacklist
=
[
"kernels/empty_kernel.cc"
,
"/cast_kernel.c"
,
"/batch_norm_kernel.c"
]
for
bname
in
file_blacklist
:
if
op_file
.
find
(
bname
)
>=
0
:
need_continue
=
True
break
if
need_continue
:
print
(
"continue:"
,
op_file
)
continue
op_name
=
os
.
path
.
split
(
op_file
)[
1
]
all_matches
=
[]
with
io
.
open
(
op_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
content
=
''
.
join
(
f
.
readlines
())
op_pattern
=
'PD_REGISTER_KERNEL\(.*?\).*?\{.*?\}'
op
,
op_count
=
find_kernel
(
content
,
op_pattern
)
register_op_count
+=
op_count
all_matches
.
extend
(
op
)
for
p
in
all_matches
:
content
=
content
.
replace
(
p
,
''
)
with
io
.
open
(
op_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
f
.
write
(
u
'{}'
.
format
(
content
))
print
(
'We erase all grad op and kernel for Paddle-Inference lib.'
)
print
(
'%50s%10s'
%
(
'type'
,
'count'
))
print
(
'%50s%10s'
%
(
'REGISTER_OPERATOR'
,
register_op_count
))
return
True
def
apply_patches
():
work_path
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
+
"/../"
ret
=
os
.
system
(
"cd %s && rm -f paddle/fluid/inference/api/tensorrt_predictor.* "
" && rm -f paddle/fluid/inference/api/paddle_tensorrt_predictor.h "
" && git apply tools/infer_prune_patches/*.patch && cd -"
%
work_path
)
return
ret
==
0
def
append_fluid_kernels
():
op_white_list
=
[
"load"
,
"load_combine"
]
#1. add to makefile
file_name
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
\
+
"/../paddle/fluid/inference/tensorrt/CMakeLists.txt"
append_str
=
"
\n
file(APPEND ${pybind_file}
\"
USE_NO_KERNEL_OP__(tensorrt_engine);
\\
n
\"
)
\n
"
for
op
in
op_white_list
:
append_str
=
append_str
+
"file(APPEND ${pybind_file}
\"
USE_OP__(%s);
\\
n
\"
)
\n
"
%
op
with
io
.
open
(
file_name
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
content
=
''
.
join
(
f
.
readlines
())
location_str
=
"nv_library(tensorrt_op_teller SRCS op_teller.cc DEPS framework_proto device_context boost)"
new_content
=
content
.
replace
(
location_str
,
location_str
+
append_str
)
if
new_content
==
content
:
print
(
"ERROR: can not find
\"
%s
\"
in file
\"
%s
\"
"
%
(
location_str
,
file_name
))
return
False
with
io
.
open
(
file_name
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
f
.
write
(
u
'{}'
.
format
(
new_content
))
#2. add op and kernel register
op_white_list
.
append
(
"tensorrt_engine"
)
tool_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
if
sys
.
version_info
[
0
]
==
3
:
all_op
=
glob
.
glob
(
os
.
path
.
join
(
tool_dir
,
'../paddle/fluid/operators/**/*.cc'
),
recursive
=
True
)
all_op
+=
glob
.
glob
(
os
.
path
.
join
(
tool_dir
,
'../paddle/fluid/operators/**/*.cu'
),
recursive
=
True
)
elif
sys
.
version_info
[
0
]
==
2
:
all_op
=
find_type_files
(
os
.
path
.
join
(
tool_dir
,
'../paddle/fluid/operators/'
),
'.cc'
)
all_op
=
find_type_files
(
os
.
path
.
join
(
tool_dir
,
'../paddle/fluid/operators/'
),
'.cu'
,
all_op
)
for
op_file
in
all_op
:
with
io
.
open
(
op_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
content
=
''
.
join
(
f
.
readlines
())
for
op
in
op_white_list
:
patterns
=
{
"REGISTER_OPERATOR"
:
"REGISTER_OPERATOR\(\s*%s\s*,"
%
op
,
"REGISTER_OP_CPU_KERNEL"
:
"REGISTER_OP_CPU_KERNEL\(\s*%s\s*,"
%
op
,
"REGISTER_OP_CUDA_KERNEL"
:
"REGISTER_OP_CUDA_KERNEL\(\s*%s\s*,"
%
op
}
for
k
,
p
in
patterns
.
items
():
matches
=
re
.
findall
(
p
,
content
,
flags
=
re
.
DOTALL
)
if
len
(
matches
)
>
0
:
content
=
content
.
replace
(
matches
[
0
],
matches
[
0
].
replace
(
k
,
k
+
"__"
))
with
io
.
open
(
op_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
f
.
write
(
u
'{}'
.
format
(
content
))
return
True
if
__name__
==
'__main__'
:
print
(
"================ step 1: apply patches ======================="
)
assert
(
apply_patches
())
print
(
"==============================================================
\n
"
)
print
(
"================ step 2: append fluid op/kernels=============="
)
assert
(
append_fluid_kernels
())
print
(
"==============================================================
\n
"
)
print
(
"================ step 3:prune phi kernels ===================="
)
assert
(
prune_phi_kernels
())
print
(
"==============================================================
\n
"
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录