Unverified commit d74d1838, authored by Shang Zhizhou, committed by GitHub

Add a library-size pruning tool for Jetson inference (#43453)

* test=document_fix

* test=document_fix; add patch file

* test=document_fix;update style

* test=document_fix;update patch file

* test=document_fix;remove useless patch file
Parent e0a01461
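For orientation before the diffs: the commit ships a set of source patches together with a Python pruning script (reproduced at the end of this page) whose __main__ runs three stages: apply the patches, rewrite a small whitelist of fluid op registrations to the renamed double-underscore macros, and strip phi kernel registrations. A minimal sketch of driving those stages directly; the path tools/prune_for_jetson.py is an assumption, since the new file's name is not shown on this page:

# Minimal driver mirroring the __main__ section of the pruning script at the end of
# this commit. The path tools/prune_for_jetson.py is an assumption (the new file's
# name is not shown here); run from the Paddle repository root before an ON_INFER build.
import importlib.util

spec = importlib.util.spec_from_file_location("prune_tool", "tools/prune_for_jetson.py")
prune_tool = importlib.util.module_from_spec(spec)
spec.loader.exec_module(prune_tool)

assert prune_tool.apply_patches()         # step 1: git-apply tools/infer_prune_patches/*.patch
assert prune_tool.append_fluid_kernels()  # step 2: switch whitelisted ops to the *__ macros
assert prune_tool.prune_phi_kernels()     # step 3: erase PD_REGISTER_KERNEL blocks in phi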
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 0645af611b..6b05a7fffb 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1923,7 +1923,7 @@ USE_TRT_CONVERTER(shuffle_channel);
USE_TRT_CONVERTER(swish);
USE_TRT_CONVERTER(group_norm);
USE_TRT_CONVERTER(instance_norm);
-USE_TRT_CONVERTER(layer_norm);
+//USE_TRT_CONVERTER(layer_norm);
USE_TRT_CONVERTER(gelu);
USE_TRT_CONVERTER(multihead_matmul);
USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
@@ -1933,13 +1933,13 @@ USE_TRT_CONVERTER(scale);
USE_TRT_CONVERTER(stack);
USE_TRT_CONVERTER(clip);
USE_TRT_CONVERTER(gather);
-USE_TRT_CONVERTER(anchor_generator);
+//USE_TRT_CONVERTER(anchor_generator);
USE_TRT_CONVERTER(yolo_box);
USE_TRT_CONVERTER(yolo_box_head);
USE_TRT_CONVERTER(arg_max);
-USE_TRT_CONVERTER(roi_align);
-USE_TRT_CONVERTER(affine_channel);
-USE_TRT_CONVERTER(multiclass_nms);
+//USE_TRT_CONVERTER(roi_align);
+//USE_TRT_CONVERTER(affine_channel);
+//USE_TRT_CONVERTER(multiclass_nms);
USE_TRT_CONVERTER(multiclass_nms3);
USE_TRT_CONVERTER(nearest_interp);
USE_TRT_CONVERTER(nearest_interp_v2);
diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc
index be7d6ab868..498e09cb4d 100644
--- a/paddle/fluid/inference/analysis/analyzer.cc
+++ b/paddle/fluid/inference/analysis/analyzer.cc
@@ -32,6 +32,9 @@ void Analyzer::RunAnalysis(Argument *argument) {
"analsis_passes is not valid in the argument."));
const bool disable_logs = argument->disable_logs();
for (auto &pass : argument->analysis_passes()) {
+ if (pass == "ir_params_sync_among_devices_pass") {
+ continue;
+ }
if (!disable_logs) {
string::PrettyLogH1("--- Running analysis [%s]", pass);
}
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 904e4854ba..4f8c955d8c 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -466,15 +466,15 @@ CUDAContext::CUDAContext(const CUDAPlace& place,
place_ = place;
CUDADeviceGuard guard(place_.device);
stream_.reset(new stream::CUDAStream(place, priority, flag));
- InitEigenContext();
- InitCuBlasContext();
- InitCuDNNContext();
+ //InitEigenContext();
+ //InitCuBlasContext();
+ //InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
- InitCuBlasLtContext();
+ //InitCuBlasLtContext();
#endif
- InitCuSparseContext();
- InitCuSolverContext();
+ //InitCuSparseContext();
+ //InitCuSolverContext();
#endif
}
@@ -506,14 +506,14 @@ void CUDAContext::SetStream(gpuStream_t stream) {
CUDAContext::~CUDAContext() {
CUDADeviceGuard guard(place_.device);
- DestoryCuDNNContext();
- DestoryCuBlasContext();
+ //DestoryCuDNNContext();
+ //DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
#if CUDA_VERSION >= 11060
- InitCuBlasLtContext();
+ //InitCuBlasLtContext();
#endif
- DestoryCuSparseContext();
- DestoryCuSolverContext();
+ //DestoryCuSparseContext();
+ //DestoryCuSolverContext();
#endif
}
diff --git a/paddle/fluid/operators/jit/gen/jitcode.h b/paddle/fluid/operators/jit/gen/jitcode.h
index 23650c8efc..24466e4327 100644
--- a/paddle/fluid/operators/jit/gen/jitcode.h
+++ b/paddle/fluid/operators/jit/gen/jitcode.h
@@ -97,8 +97,8 @@ class JitCode : public GenBase, public Xbyak::CodeGenerator {
}
ret();
}
- void L(const char* label) { Xbyak::CodeGenerator::L(label); }
- void L(Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); } // NOLINT
+ void L(const char* label) { }
+ void L(Xbyak::Label& label) { } // NOLINT
// Enhanced vector extension
Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt,
bool bcast = false) {
diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h
index a1f07f9f25..179df3b981 100644
--- a/paddle/fluid/framework/op_registry.h
+++ b/paddle/fluid/framework/op_registry.h
@@ -178,9 +178,8 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
RegisterKernelClass<PlaceType, T>(
op_type, library_type, customized_type_value,
- [op_type](const framework::ExecutionContext& ctx) {
+ [](const framework::ExecutionContext& ctx) {
KERNEL_TYPE().Compute(ctx);
- CheckKernelLaunch<PlaceType>(op_type);
});
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
@@ -240,13 +239,8 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
void operator()(const char* op_type, const char* library_type,
int customized_type_value) const {
- RegisterKernelClass<PlaceType, T>(
- op_type, library_type, customized_type_value,
-
- [op_type](const framework::ExecutionContext& ctx) {
- Functor()(ctx);
- CheckKernelLaunch<PlaceType>(op_type);
- });
+ RegisterKernelClass<PlaceType, T>(op_type, library_type,
+ customized_type_value, Functor());
constexpr auto size =
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
@@ -275,7 +269,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
VarTypeInference
InferShapeBase
*/
-#define REGISTER_OPERATOR(op_type, op_class, ...) \
+#define REGISTER_OPERATOR__(op_type, op_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, \
"REGISTER_OPERATOR must be called in global namespace"); \
@@ -286,15 +280,22 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
return 0; \
}
-#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, ...) \
+#define REGISTER_OPERATOR(op_type, op_class, ...)
+
+#define REGISTER_OP_WITHOUT_GRADIENT__(op_type, op_class, ...) \
REGISTER_OPERATOR(op_type, op_class, __VA_ARGS__, \
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)
+#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, ...)
/**
* Macro to register OperatorKernel.
*/
#define REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(op_type, library_type, \
+ place_class, customized_name, \
+ customized_type_value, ...)
+
+#define REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE__(op_type, library_type, \
place_class, customized_name, \
customized_type_value, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
@@ -311,18 +312,22 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
return 0; \
}
-#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \
- REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( \
+#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...)
+
+#define REGISTER_OP_KERNEL__(op_type, library_type, place_class, ...) \
+ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE__( \
op_type, library_type, place_class, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#define REGISTER_OP_CUDA_KERNEL__(op_type, ...) \
+ REGISTER_OP_KERNEL__(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
+
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
-#else
-#define REGISTER_OP_CUDA_KERNEL(op_type, ...)
-#endif
+
+#define REGISTER_OP_CPU_KERNEL__(op_type, ...) \
+ REGISTER_OP_KERNEL__(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
@@ -340,6 +345,11 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
REGISTER_OP_KERNEL(op_type, MLU, ::paddle::platform::MLUPlace, __VA_ARGS__)
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, \
+ customized_name, \
+ customized_type_value, \
+ ...)
+
+#define REGISTER_OP_KERNEL_EX__(op_type, library_type, place_class, \
customized_name, \
customized_type_value, \
...) \
@@ -357,8 +367,10 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
return 0; \
}
-#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \
- REGISTER_OP_KERNEL_EX( \
+#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...)
+
+#define REGISTER_OP_CUDA_KERNEL_FUNCTOR__(op_type, ...) \
+ REGISTER_OP_KERNEL_EX__( \
op_type, CUDA, ::paddle::platform::CUDAPlace, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
@@ -375,12 +387,6 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
-#define REGISTER_OP_NPU_KERNEL_FUNCTOR(op_type, ...) \
- REGISTER_OP_KERNEL_EX( \
- op_type, NPU, ::paddle::platform::NPUPlace, DEFAULT_TYPE, \
- ::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
- __VA_ARGS__)
-
#define REGISTER_OP_MLU_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX( \
op_type, MLU, ::paddle::platform::MLUPlace, DEFAULT_TYPE, \
@@ -392,7 +398,9 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
* we will use and tell the compiler to
* link them into target.
*/
-#define USE_OP_ITSELF(op_type) \
+#define USE_OP_ITSELF(op_type)
+
+#define USE_OP_ITSELF__(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_itself_##op_type, \
"USE_OP_ITSELF must be called in global namespace"); \
@@ -400,6 +408,10 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
UNUSED static int use_op_itself_##op_type##_ = TouchOpRegistrar_##op_type()
#define USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, \
+ LIBRARY_TYPE, \
+ customized_name)
+
+#define USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE__(op_type, \
LIBRARY_TYPE, \
customized_name) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
@@ -410,33 +422,58 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_##customized_name##_ = /* NOLINT */ \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name()
-#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \
- USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, LIBRARY_TYPE, DEFAULT_TYPE)
+#define USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, \
+ LIBRARY_TYPE, \
+ customized_name)
+
+#define USE_OP_DEVICE_KERNEL__(op_type, LIBRARY_TYPE) \
+ USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE__(op_type, LIBRARY_TYPE, DEFAULT_TYPE)
+
+#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE)
// TODO(fengjiayi): The following macros
// seems ugly, do we have better method?
-#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
+#ifndef PADDLE_WITH_CUDA
#define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU)
+#define USE_OP_KERNEL__(op_type) USE_OP_DEVICE_KERNEL__(op_type, CPU)
#else
#define USE_OP_KERNEL(op_type) \
USE_OP_DEVICE_KERNEL(op_type, CPU); \
USE_OP_DEVICE_KERNEL(op_type, CUDA)
+
+#define USE_OP_KERNEL__(op_type) \
+ USE_OP_DEVICE_KERNEL__(op_type, CPU); \
+ USE_OP_DEVICE_KERNEL__(op_type, CUDA)
#endif
#define USE_NO_KERNEL_OP(op_type) USE_OP_ITSELF(op_type);
+#define USE_NO_KERNEL_OP__(op_type) USE_OP_ITSELF__(op_type);
+
#define USE_CPU_ONLY_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CPU);
+#define USE_CPU_ONLY_OP__(op_type) \
+ USE_OP_ITSELF__(op_type); \
+ USE_OP_DEVICE_KERNEL__(op_type, CPU);
+
#define USE_CUDA_ONLY_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CUDA)
+#define USE_CUDA_ONLY_OP__(op_type) \
+ USE_OP_ITSELF__(op_type); \
+ USE_OP_DEVICE_KERNEL__(op_type, CUDA)
+
#define USE_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type)
+
+#define USE_OP__(op_type) \
+ USE_OP_ITSELF__(op_type); \
+ USE_OP_KERNEL__(op_type)
// clang-format on
} // namespace framework
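These renamed macros work together with the pruning script at the end of this commit: the plain registration macros now expand to nothing, and the script's append_fluid_kernels step rewrites only whitelisted call sites (load, load_combine, tensorrt_engine) to the double-underscore forms so those ops stay registered. A minimal sketch of that rewrite, using the same regex shape as the script; the sample registration line is invented for illustration:

# Sketch of the REGISTER_* -> REGISTER_*__ rewrite applied to whitelisted ops
# ("load" is on the script's whitelist; the registration line itself is made up).
import re

content = "REGISTER_OPERATOR(load, ops::LoadOp, ops::LoadOpProtoMaker);"
pattern = r"REGISTER_OPERATOR\(\s*load\s*,"          # same shape as in append_fluid_kernels
match = re.findall(pattern, content, flags=re.DOTALL)[0]
content = content.replace(match, match.replace("REGISTER_OPERATOR", "REGISTER_OPERATOR__"))
print(content)  # -> REGISTER_OPERATOR__(load, ops::LoadOp, ops::LoadOpProtoMaker);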
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index d6a0b643c2..511844b482 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -46,6 +46,7 @@
namespace paddle {
class AnalysisPredictor;
+class TensorRTPredictor;
struct MkldnnQuantizerConfig;
struct LiteNNAdapterConfig {
@@ -700,6 +701,8 @@ struct PD_INFER_DECL AnalysisConfig {
friend class ::paddle::AnalysisPredictor;
+ friend class ::paddle::TensorRTPredictor;
+
///
/// \brief Get a pass builder for customize the passes in IR analysis phase.
/// NOTE: Just for developer, not an official API, easy to be broken.
diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h
index b28370fb82..aaf20a28b0 100644
--- a/paddle/fluid/inference/api/paddle_api.h
+++ b/paddle/fluid/inference/api/paddle_api.h
@@ -194,6 +194,7 @@ class PD_INFER_DECL ZeroCopyTensor : public paddle_infer::Tensor {
private:
friend class AnalysisPredictor;
+ friend class TensorRTPredictor;
friend class ONNXRuntimePredictor;
explicit ZeroCopyTensor(void* scope, const void* device_contexts)
: paddle_infer::Tensor{scope, device_contexts} {}
diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h
index 35b90bfa54..ba8220d06a 100644
--- a/paddle/fluid/inference/api/paddle_inference_api.h
+++ b/paddle/fluid/inference/api/paddle_inference_api.h
@@ -41,6 +41,11 @@ limitations under the License. */
/// \since 2.0.0-beta
///
+namespace paddle {
+std::unique_ptr<PaddlePredictor> CreateTensorRTPredictor(
+ const AnalysisConfig& config);
+}
+
namespace paddle_infer {
using PrecisionType = paddle::AnalysisConfig::Precision;
diff --git a/paddle/phi/CMakeLists.txt b/paddle/phi/CMakeLists.txt
index 58ad42ddd1..8ffdafcf0d 100644
--- a/paddle/phi/CMakeLists.txt
+++ b/paddle/phi/CMakeLists.txt
@@ -18,7 +18,7 @@ add_subdirectory(infermeta)
# phi operator definitions
add_subdirectory(ops)
# phi tools
-add_subdirectory(tools)
+#add_subdirectory(tools)
# phi tests
add_subdirectory(tests)
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index 394ce7799e..8edbef50be 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -390,6 +390,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass));
+ std::unordered_set<std::string> param_set(params.begin(), params.end());
if (use_static_engine) {
trt_engine_serialized_data = GetTrtEngineSerializedData(
Get<std::string>("model_opt_cache_dir"), engine_key);
@@ -399,6 +400,19 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
LOG(INFO) << "Load TRT Optimized Info from "
<< GetTrtEngineSerializedPath(
Get<std::string>("model_opt_cache_dir"), engine_key);
+ const auto* root_scope{param_scope()};
+ for (;root_scope->parent();) {
+ root_scope = root_scope->parent();
+ }
+ for (const auto& name: param_set) {
+ LOG(INFO) << " ===== Clear param: " << name;
+ root_scope->FindLocalVar(name)->Clear();
+ }
+ for (int dev_id = 0; dev_id < paddle::platform::GetGPUDeviceCount();
+ ++dev_id) {
+ memory::Release(platform::CUDAPlace(dev_id));
+ }
+ memory::Release(platform::CPUPlace());
return;
}
}
@@ -411,12 +425,25 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
auto *scope = param_scope();
framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto());
- std::unordered_set<std::string> param_set(params.begin(), params.end());
inference::Singleton<inference::tensorrt::OpConverter>::Global()
.ConvertBlockToTRTEngine(
&block_desc_temp, *scope,
std::vector<std::string>(input_names.begin(), input_names.end()),
param_set, output_mapping, trt_engine);
+ const auto* root_scope{scope};
+ for (;root_scope->parent();) {
+ root_scope = root_scope->parent();
+ }
+ VLOG(4) << "root_scope->LocalVarNames().size: " << root_scope->LocalVarNames().size();
+ for (const auto& name: param_set) {
+ VLOG(4) << " ===== Clear param: " << name;
+ root_scope->FindLocalVar(name)->Clear();
+ }
+ for (int dev_id = 0; dev_id < paddle::platform::GetGPUDeviceCount();
+ ++dev_id) {
+ memory::Release(platform::CUDAPlace(dev_id));
+ }
+ memory::Release(platform::CPUPlace());
if (use_static_engine) {
nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize();
@@ -431,6 +458,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
<< GetTrtEngineSerializedPath(
Get<std::string>("model_opt_cache_dir"), engine_key);
}
+ trt_engine_serialized_data.clear();
+ trt_engine_serialized_data.shrink_to_fit();
}
} // namespace analysis
diff --git a/paddle/fluid/memory/allocation/thread_local_allocator.cc b/paddle/fluid/memory/allocation/thread_local_allocator.cc
index f125670a59..f858a30301 100644
--- a/paddle/fluid/memory/allocation/thread_local_allocator.cc
+++ b/paddle/fluid/memory/allocation/thread_local_allocator.cc
@@ -13,18 +13,62 @@
// limitations under the License.
#include "paddle/fluid/memory/allocation/thread_local_allocator.h"
+#include "paddle/fluid/platform/cuda_device_guard.h"
namespace paddle {
namespace memory {
namespace allocation {
+const int MALLOC_ALIGN = 64;
+
+#define CUDA_CALL(func) \
+ { \
+ auto e = (func); \
+ CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
+ << "CUDA: " << cudaGetErrorString(e); \
+ }
+
+void* DirectAllocator::Alloc(size_t unaligned_size) {
+ if (platform::is_cpu_place(place_)) {
+ size_t offset = sizeof(void*) + MALLOC_ALIGN - 1;
+ char* p = static_cast<char*>(std::malloc(offset + unaligned_size));
+ // Memory checking
+ CHECK(p) << "Error occurred in malloc period: available space is not enough "
+ "for mallocing "
+ << unaligned_size << " bytes.";
+ // Byte alignment
+ void* r = reinterpret_cast<void*>(reinterpret_cast<size_t>(p + offset) &
+ (~(MALLOC_ALIGN - 1)));
+ static_cast<void**>(r)[-1] = p;
+ return r;
+ } else if (platform::is_gpu_place(place_)) {
+ int dev_id = place_.GetDeviceId();
+ platform::CUDADeviceGuard guard(dev_id);
+ void* ptr{};
+ CUDA_CALL(cudaMalloc(&ptr, unaligned_size));
+ return ptr;
+ }
+ return nullptr;
+}
+
+void DirectAllocator::Free(void* ptr) {
+ if (platform::is_cpu_place(place_)) {
+ if (ptr) {
+ std::free(static_cast<void**>(ptr)[-1]);
+ }
+ } else if (platform::is_gpu_place(place_)) {
+ int dev_id = place_.GetDeviceId();
+ platform::CUDADeviceGuard guard(dev_id);
+ CUDA_CALL(cudaFree(ptr));
+ }
+}
+
+
+
ThreadLocalAllocatorImpl::ThreadLocalAllocatorImpl(const platform::Place& p)
: place_(p) {
if (platform::is_gpu_place(place_)) {
- buddy_allocator_.reset(new memory::detail::BuddyAllocator(
- std::unique_ptr<memory::detail::SystemAllocator>(
- new memory::detail::GPUAllocator(place_.device)),
- platform::GpuMinChunkSize(), platform::GpuMaxChunkSize()));
+ direct_allocator_.reset(new DirectAllocator{place_});
} else {
PADDLE_THROW(platform::errors::Unavailable(
"Thread local allocator only supports CUDAPlace now."));
@@ -59,7 +103,7 @@ ThreadLocalCUDAAllocatorPool::ThreadLocalCUDAAllocatorPool()
ThreadLocalAllocation* ThreadLocalAllocatorImpl::AllocateImpl(size_t size) {
VLOG(10) << "ThreadLocalAllocatorImpl::AllocateImpl " << size;
- void* ptr = buddy_allocator_->Alloc(size);
+ void* ptr = direct_allocator_->Alloc(size);
auto* tl_allocation = new ThreadLocalAllocation(ptr, size, place_);
tl_allocation->SetThreadLocalAllocatorImpl(shared_from_this());
return tl_allocation;
@@ -67,12 +111,12 @@ ThreadLocalAllocation* ThreadLocalAllocatorImpl::AllocateImpl(size_t size) {
void ThreadLocalAllocatorImpl::FreeImpl(ThreadLocalAllocation* allocation) {
VLOG(10) << "ThreadLocalAllocatorImpl::FreeImpl " << allocation;
- buddy_allocator_->Free(allocation->ptr());
+ direct_allocator_->Free(allocation->ptr());
delete allocation;
}
uint64_t ThreadLocalAllocatorImpl::ReleaseImpl() {
- return buddy_allocator_->Release();
+ return direct_allocator_->Release();
}
} // namespace allocation
diff --git a/paddle/fluid/memory/allocation/thread_local_allocator.h b/paddle/fluid/memory/allocation/thread_local_allocator.h
index 654fb3fe7b..44c5dbf87f 100644
--- a/paddle/fluid/memory/allocation/thread_local_allocator.h
+++ b/paddle/fluid/memory/allocation/thread_local_allocator.h
@@ -26,6 +26,16 @@ namespace paddle {
namespace memory {
namespace allocation {
+class DirectAllocator {
+public:
+ DirectAllocator(const platform::Place& place) : place_{place} {}
+ void* Alloc(size_t unaligned_size);
+ void Free(void* ptr);
+ uint64_t Release() { return 0;}
+private:
+ platform::Place place_;
+};
+
class ThreadLocalAllocatorImpl;
class ThreadLocalAllocation : public Allocation {
@@ -55,7 +65,7 @@ class ThreadLocalAllocatorImpl
uint64_t ReleaseImpl();
private:
- std::unique_ptr<memory::detail::BuddyAllocator> buddy_allocator_;
+ std::unique_ptr<DirectAllocator> direct_allocator_;
platform::Place place_;
};
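The CPU branch of DirectAllocator::Alloc above over-allocates by sizeof(void*) + MALLOC_ALIGN - 1 bytes, rounds the bumped address down to a 64-byte boundary, and stores the original malloc pointer just before the returned address so Free can recover it. A quick numeric check of that arithmetic; the raw address and the 8-byte sizeof(void*) are assumptions for illustration:

# Numeric check of the 64-byte alignment arithmetic in DirectAllocator::Alloc (CPU path).
# The raw address is hypothetical; SIZEOF_VOID_P = 8 assumes a 64-bit build.
MALLOC_ALIGN = 64
SIZEOF_VOID_P = 8
offset = SIZEOF_VOID_P + MALLOC_ALIGN - 1       # sizeof(void*) + MALLOC_ALIGN - 1

raw = 0x7F3A00001235                            # pretend this came from std::malloc
aligned = (raw + offset) & ~(MALLOC_ALIGN - 1)  # same bit mask as the C++ code

assert aligned % MALLOC_ALIGN == 0              # result is 64-byte aligned
assert aligned - raw >= SIZEOF_VOID_P           # room to stash the original pointer
print(hex(raw), "->", hex(aligned))             # 0x7f3a00001235 -> 0x7f3a00001240

The remainder of the commit is the new pruning script itself, reproduced below (its own diff header is not shown).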
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script removes all grad ops and kernels. Use it when building with cmake
ON_INFER=ON; this can greatly reduce the size of the inference library.
"""
import os
import sys
import re
import glob
import io


def find_type_files(cur_dir, file_type, file_list=[]):
    next_level_dirs = os.listdir(cur_dir)
    for next_level_name in next_level_dirs:
        next_level_dir = os.path.join(cur_dir, next_level_name)
        if os.path.isfile(next_level_dir):
            if os.path.splitext(next_level_dir)[1] == file_type:
                file_list.append(next_level_dir)
        elif os.path.isdir(next_level_dir):
            find_type_files(next_level_dir, file_type, file_list)
    return file_list


def find_kernel(content, pattern):
    res = re.findall(pattern, content, flags=re.DOTALL)
    ret = []
    for p in res:
        left, right = 0, 0
        for c in p:
            if c == '{':
                left += 1
            elif c == '}':
                right += 1
        if left == right:
            ret.append(p)
    return ret, len(ret)


def prune_phi_kernels():
    tool_dir = os.path.dirname(os.path.abspath(__file__))
    if sys.version_info[0] == 3:
        all_op = glob.glob(os.path.join(tool_dir, '../paddle/phi/kernels/**/*.cc'),
                           recursive=True)
        all_op += glob.glob(os.path.join(tool_dir, '../paddle/phi/kernels/**/*.cu'),
                            recursive=True)
    elif sys.version_info[0] == 2:
        all_op = find_type_files(
            os.path.join(tool_dir, '../paddle/phi/kernels/'), '.cc')
        all_op = find_type_files(
            os.path.join(tool_dir, '../paddle/phi/kernels/'), '.cu', all_op)

    register_op_count = 0
    for op_file in all_op:
        need_continue = False
        file_blacklist = [
            "kernels/empty_kernel.cc", "/cast_kernel.c", "/batch_norm_kernel.c"
        ]
        for bname in file_blacklist:
            if op_file.find(bname) >= 0:
                need_continue = True
                break
        if need_continue:
            print("continue:", op_file)
            continue

        op_name = os.path.split(op_file)[1]
        all_matches = []
        with io.open(op_file, 'r', encoding='utf-8') as f:
            content = ''.join(f.readlines())
            op_pattern = 'PD_REGISTER_KERNEL\(.*?\).*?\{.*?\}'
            op, op_count = find_kernel(content, op_pattern)
            register_op_count += op_count
            all_matches.extend(op)

        for p in all_matches:
            content = content.replace(p, '')

        with io.open(op_file, 'w', encoding='utf-8') as f:
            f.write(u'{}'.format(content))

    print('We erased all phi kernel registrations for the Paddle-Inference lib.')
    print('%50s%10s' % ('type', 'count'))
    print('%50s%10s' % ('PD_REGISTER_KERNEL', register_op_count))
    return True


def apply_patches():
    work_path = os.path.dirname(os.path.abspath(__file__)) + "/../"
    ret = os.system(
        "cd %s && rm -f paddle/fluid/inference/api/tensorrt_predictor.* "
        " && rm -f paddle/fluid/inference/api/paddle_tensorrt_predictor.h "
        " && git apply tools/infer_prune_patches/*.patch && cd -" % work_path)
    return ret == 0


def append_fluid_kernels():
    op_white_list = ["load", "load_combine"]

    #1. add to makefile
    file_name = os.path.dirname(os.path.abspath(__file__)) \
        + "/../paddle/fluid/inference/tensorrt/CMakeLists.txt"
    append_str = "\nfile(APPEND ${pybind_file} \"USE_NO_KERNEL_OP__(tensorrt_engine);\\n\")\n"
    for op in op_white_list:
        append_str = append_str + "file(APPEND ${pybind_file} \"USE_OP__(%s);\\n\")\n" % op

    with io.open(file_name, 'r', encoding='utf-8') as f:
        content = ''.join(f.readlines())

    location_str = "nv_library(\n tensorrt_op_teller\n SRCS op_teller.cc\n DEPS framework_proto device_context boost)"
    new_content = content.replace(location_str, location_str + append_str)

    if new_content == content:
        print("ERROR: can not find \"%s\" in file \"%s\"" %
              (location_str, file_name))
        return False

    with io.open(file_name, 'w', encoding='utf-8') as f:
        f.write(u'{}'.format(new_content))

    #2. add op and kernel register
    op_white_list.append("tensorrt_engine")
    tool_dir = os.path.dirname(os.path.abspath(__file__))
    if sys.version_info[0] == 3:
        all_op = glob.glob(os.path.join(tool_dir, '../paddle/fluid/operators/**/*.cc'),
                           recursive=True)
        all_op += glob.glob(os.path.join(tool_dir, '../paddle/fluid/operators/**/*.cu'),
                            recursive=True)
    elif sys.version_info[0] == 2:
        all_op = find_type_files(
            os.path.join(tool_dir, '../paddle/fluid/operators/'), '.cc')
        all_op = find_type_files(
            os.path.join(tool_dir, '../paddle/fluid/operators/'), '.cu', all_op)

    for op_file in all_op:
        with io.open(op_file, 'r', encoding='utf-8') as f:
            content = ''.join(f.readlines())

        for op in op_white_list:
            patterns = {
                "REGISTER_OPERATOR": "REGISTER_OPERATOR\(\s*%s\s*," % op,
                "REGISTER_OP_CPU_KERNEL":
                "REGISTER_OP_CPU_KERNEL\(\s*%s\s*," % op,
                "REGISTER_OP_CUDA_KERNEL":
                "REGISTER_OP_CUDA_KERNEL\(\s*%s\s*," % op
            }
            for k, p in patterns.items():
                matches = re.findall(p, content, flags=re.DOTALL)
                if len(matches) > 0:
                    content = content.replace(matches[0],
                                              matches[0].replace(k, k + "__"))
        with io.open(op_file, 'w', encoding='utf-8') as f:
            f.write(u'{}'.format(content))

    return True


if __name__ == '__main__':
    print("================ step 1: apply patches =======================")
    assert (apply_patches())
    print("==============================================================\n")

    print("================ step 2: append fluid op/kernels==============")
    assert (append_fluid_kernels())
    print("==============================================================\n")

    print("================ step 3:prune phi kernels ====================")
    assert (prune_phi_kernels())
    print("==============================================================\n")