提交 2a2c83b9 编写于 作者: Y Yan Chunwei 提交者: Tao Luo

feature/convert tensorrt io (#10440)

* init

* init

* add ut

* split singleton from base class

* add singleton

* ad singleton
上级 b708ec0a
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda) nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc) set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc)
add_subdirectory(convert) add_subdirectory(convert)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/io_converter.h"
#include <cuda.h>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
using platform::is_gpu_place;
using platform::is_cpu_place;
class DefaultInputConverter : public EngineInputConverter {
public:
DefaultInputConverter() {}
// NOTE out is GPU memory.
virtual void operator()(const LoDTensor& in, void* out,
size_t max_size) override {
PADDLE_ENFORCE(out != nullptr);
PADDLE_ENFORCE_LE(in.memory_size(), max_size);
const auto& place = in.place();
if (is_cpu_place(place)) {
PADDLE_ENFORCE(stream_ != nullptr);
PADDLE_ENFORCE_EQ(0,
cudaMemcpyAsync(out, in.data<float>(), in.memory_size(),
cudaMemcpyHostToDevice, *stream_));
} else if (is_gpu_place(place)) {
PADDLE_ENFORCE_EQ(0,
cudaMemcpyAsync(out, in.data<float>(), in.memory_size(),
cudaMemcpyHostToHost, *stream_));
} else {
PADDLE_THROW("Unknown device for converter");
}
cudaStreamSynchronize(*stream_);
}
};
REGISTER_TENSORRT_INPUT_CONVERTER(mul, DefaultInputConverter);
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <unordered_map>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
namespace inference {
namespace tensorrt {
using framework::LoDTensor;
/*
* Convert Input from Fluid to an Engine.
* TensorRT's ITensor follows row major, NCHW. Fluid is also row major, so in
* most cases just need to copy the data.
*/
class EngineInputConverter {
public:
EngineInputConverter() {}
virtual void operator()(const LoDTensor& in, void* out, size_t max_size) {}
void SetStream(cudaStream_t* stream) { stream_ = stream; }
static void Run(const std::string& in_op_type, const LoDTensor& in, void* out,
size_t max_size, cudaStream_t* stream) {
PADDLE_ENFORCE(stream != nullptr);
auto* converter = Registry<EngineInputConverter>::Lookup(in_op_type);
PADDLE_ENFORCE_NOT_NULL(converter);
converter->SetStream(stream);
(*converter)(in, out, max_size);
}
virtual ~EngineInputConverter() {}
protected:
cudaStream_t* stream_{nullptr};
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
#define REGISTER_TENSORRT_INPUT_CONVERTER(in_op_type__, Converter__) \
struct trt_input_##in_op_type__##_converter { \
trt_input_##in_op_type__##_converter() { \
::paddle::inference::Registry<EngineInputConverter>::Register< \
Converter__>(#in_op_type__); \
} \
}; \
trt_input_##in_op_type__##_converter trt_input_##in_op_type__##_converter__;
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/tensorrt/io_converter.h"
#include <gtest/gtest.h>
namespace paddle {
namespace inference {
namespace tensorrt {
class EngineInputConverterTester : public ::testing::Test {
public:
void SetUp() override { tensor.Resize({10, 10}); }
framework::LoDTensor tensor;
};
TEST_F(EngineInputConverterTester, DefaultCPU) {
void* buffer;
tensor.mutable_data<float>(platform::CPUPlace());
ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
cudaStream_t stream;
EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(),
&stream);
}
TEST_F(EngineInputConverterTester, DefaultGPU) {
void* buffer;
tensor.mutable_data<float>(platform::CUDAPlace());
ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
cudaStream_t stream;
EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(),
&stream);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <unordered_map>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
// NOTE not thread-safe.
template <typename T>
struct Singleton {
static T& Global() {
static T* x = new T;
return *x;
}
Singleton() = delete;
Singleton& operator=(const Singleton&) = delete;
};
/*
* An registor for any type.
* NOTE not thread-safe.
*/
template <typename ItemParent>
struct Registry {
static Registry& Global() {
static auto* x = new Registry<ItemParent>;
return *x;
}
template <typename ItemChild>
static void Register(const std::string& name) {
PADDLE_ENFORCE_EQ(items_.count(name), 0);
items_[name] = new ItemChild;
}
static ItemParent* Lookup(const std::string& name) {
auto it = items_.find(name);
if (it == items_.end()) return nullptr;
return it->second;
}
~Registry() {
for (auto& item : items_) {
delete item.second;
}
}
private:
Registry() = default;
static std::unordered_map<std::string, ItemParent*> items_;
};
template <typename ItemParent>
std::unordered_map<std::string, ItemParent*> Registry<ItemParent>::items_;
} // namespace inference
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册