未验证 提交 3356fb3c 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #10461 from luotao1/refine_convert

refine io_convert and op_convert
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader) nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda) nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc) set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc)
add_subdirectory(convert) add_subdirectory(convert)
nv_test(test_tensorrt_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES}) nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
nv_test(test_tensorrt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc nv_test(test_trt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc
DEPS ${FLUID_CORE_MODULES} activation_op) DEPS ${FLUID_CORE_MODULES} activation_op)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/io_converter.h" #include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
#include <cuda.h> #include <cuda.h>
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -50,7 +50,7 @@ class DefaultInputConverter : public EngineInputConverter { ...@@ -50,7 +50,7 @@ class DefaultInputConverter : public EngineInputConverter {
} }
}; };
REGISTER_TENSORRT_INPUT_CONVERTER(mul, DefaultInputConverter); REGISTER_TENSORRT_INPUT_CONVERTER(default, DefaultInputConverter);
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
......
...@@ -40,7 +40,8 @@ class EngineInputConverter { ...@@ -40,7 +40,8 @@ class EngineInputConverter {
static void Run(const std::string& in_op_type, const LoDTensor& in, void* out, static void Run(const std::string& in_op_type, const LoDTensor& in, void* out,
size_t max_size, cudaStream_t* stream) { size_t max_size, cudaStream_t* stream) {
PADDLE_ENFORCE(stream != nullptr); PADDLE_ENFORCE(stream != nullptr);
auto* converter = Registry<EngineInputConverter>::Lookup(in_op_type); auto* converter = Registry<EngineInputConverter>::Lookup(
in_op_type, "default" /* default_type */);
PADDLE_ENFORCE_NOT_NULL(converter); PADDLE_ENFORCE_NOT_NULL(converter);
converter->SetStream(stream); converter->SetStream(stream);
(*converter)(in, out, max_size); (*converter)(in, out, max_size);
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -32,34 +33,23 @@ class OpConverter { ...@@ -32,34 +33,23 @@ class OpConverter {
OpConverter() {} OpConverter() {}
virtual void operator()(const framework::OpDesc& op) {} virtual void operator()(const framework::OpDesc& op) {}
void Execute(const framework::OpDesc& op, TensorRTEngine* engine) { void Run(const framework::OpDesc& op, TensorRTEngine* engine) {
std::string type = op.Type(); std::string type = op.Type();
auto it = converters_.find(type); auto* it = Registry<OpConverter>::Lookup(type);
PADDLE_ENFORCE(it != converters_.end(), "no OpConverter for optype [%s]", PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type);
type); it->SetEngine(engine);
it->second->SetEngine(engine); (*it)(op);
(*it->second)(op);
}
static OpConverter& Global() {
static auto* x = new OpConverter;
return *x;
}
template <typename T>
void Register(const std::string& key) {
converters_[key] = new T;
} }
// convert fluid op to tensorrt layer // convert fluid op to tensorrt layer
void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) { void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) {
OpConverter::Global().Execute(op, engine); OpConverter::Run(op, engine);
} }
// convert fluid block to tensorrt network // convert fluid block to tensorrt network
void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) { void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) {
for (auto op : block.AllOps()) { for (auto op : block.AllOps()) {
OpConverter::Global().Execute(*op, engine); OpConverter::Run(*op, engine);
} }
} }
...@@ -78,12 +68,12 @@ class OpConverter { ...@@ -78,12 +68,12 @@ class OpConverter {
framework::Scope* scope_{nullptr}; framework::Scope* scope_{nullptr};
}; };
#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \ #define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
struct trt_##op_type__##_converter { \ struct trt_##op_type__##_converter { \
trt_##op_type__##_converter() { \ trt_##op_type__##_converter() { \
OpConverter::Global().Register<Converter__>(#op_type__); \ Registry<OpConverter>::Register<Converter__>(#op_type__); \
} \ } \
}; \ }; \
trt_##op_type__##_converter trt_##op_type__##_converter__; trt_##op_type__##_converter trt_##op_type__##_converter__;
} // namespace tensorrt } // namespace tensorrt
......
...@@ -26,7 +26,7 @@ namespace paddle { ...@@ -26,7 +26,7 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
void compare(float input, float expect) { void Compare(float input, float expect) {
framework::Scope scope; framework::Scope scope;
platform::CUDAPlace place; platform::CUDAPlace place;
platform::CUDADeviceContext ctx(place); platform::CUDADeviceContext ctx(place);
...@@ -85,8 +85,8 @@ void compare(float input, float expect) { ...@@ -85,8 +85,8 @@ void compare(float input, float expect) {
} }
TEST(OpConverter, ConvertRelu) { TEST(OpConverter, ConvertRelu) {
compare(1, 1); // relu(1) = 1 Compare(1, 1); // relu(1) = 1
compare(-5, 0); // relu(-5) = 0 Compare(-5, 0); // relu(-5) = 0
} }
} // namespace tensorrt } // namespace tensorrt
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/tensorrt/io_converter.h" #include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
...@@ -34,7 +34,7 @@ TEST_F(EngineInputConverterTester, DefaultCPU) { ...@@ -34,7 +34,7 @@ TEST_F(EngineInputConverterTester, DefaultCPU) {
ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0); ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
cudaStream_t stream; cudaStream_t stream;
EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(), EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
&stream); &stream);
} }
...@@ -44,7 +44,7 @@ TEST_F(EngineInputConverterTester, DefaultGPU) { ...@@ -44,7 +44,7 @@ TEST_F(EngineInputConverterTester, DefaultGPU) {
ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0); ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
cudaStream_t stream; cudaStream_t stream;
EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(), EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
&stream); &stream);
} }
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
TEST(BlockConverter, ConvertBlock) { TEST(OpConverter, ConvertBlock) {
framework::ProgramDesc prog; framework::ProgramDesc prog;
auto* block = prog.MutableBlock(0); auto* block = prog.MutableBlock(0);
auto* mul_op = block->AppendOp(); auto* mul_op = block->AppendOp();
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -49,9 +50,15 @@ struct Registry { ...@@ -49,9 +50,15 @@ struct Registry {
items_[name] = new ItemChild; items_[name] = new ItemChild;
} }
static ItemParent* Lookup(const std::string& name) { static ItemParent* Lookup(const std::string& name,
const std::string& default_name = "") {
auto it = items_.find(name); auto it = items_.find(name);
if (it == items_.end()) return nullptr; if (it == items_.end()) {
if (default_name == "")
return nullptr;
else
return items_.find(default_name)->second;
}
return it->second; return it->second;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册