提交 53b401d5 编写于 作者: L Luo Tao

refine io_convert and op_convert

上级 2a2c83b9
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc)
add_subdirectory(convert)
nv_test(test_tensorrt_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
nv_test(test_tensorrt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
nv_test(test_trt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc
DEPS ${FLUID_CORE_MODULES} activation_op)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/io_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
#include <cuda.h>
#include "paddle/fluid/platform/enforce.h"
......@@ -50,7 +50,7 @@ class DefaultInputConverter : public EngineInputConverter {
}
};
REGISTER_TENSORRT_INPUT_CONVERTER(mul, DefaultInputConverter);
REGISTER_TENSORRT_INPUT_CONVERTER(default, DefaultInputConverter);
} // namespace tensorrt
} // namespace inference
......
......@@ -40,7 +40,8 @@ class EngineInputConverter {
static void Run(const std::string& in_op_type, const LoDTensor& in, void* out,
size_t max_size, cudaStream_t* stream) {
PADDLE_ENFORCE(stream != nullptr);
auto* converter = Registry<EngineInputConverter>::Lookup(in_op_type);
auto* converter = Registry<EngineInputConverter>::Lookup(
in_op_type, "default" /* default_type */);
PADDLE_ENFORCE_NOT_NULL(converter);
converter->SetStream(stream);
(*converter)(in, out, max_size);
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
namespace inference {
......@@ -32,34 +33,23 @@ class OpConverter {
OpConverter() {}
virtual void operator()(const framework::OpDesc& op) {}
void Execute(const framework::OpDesc& op, TensorRTEngine* engine) {
void Run(const framework::OpDesc& op, TensorRTEngine* engine) {
std::string type = op.Type();
auto it = converters_.find(type);
PADDLE_ENFORCE(it != converters_.end(), "no OpConverter for optype [%s]",
type);
it->second->SetEngine(engine);
(*it->second)(op);
}
static OpConverter& Global() {
static auto* x = new OpConverter;
return *x;
}
template <typename T>
void Register(const std::string& key) {
converters_[key] = new T;
auto* it = Registry<OpConverter>::Lookup(type);
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type);
it->SetEngine(engine);
(*it)(op);
}
// convert fluid op to tensorrt layer
void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) {
OpConverter::Global().Execute(op, engine);
OpConverter::Run(op, engine);
}
// convert fluid block to tensorrt network
void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) {
for (auto op : block.AllOps()) {
OpConverter::Global().Execute(*op, engine);
OpConverter::Run(*op, engine);
}
}
......@@ -81,7 +71,7 @@ class OpConverter {
#define REGISTER_TRT_OP_CONVERTER(op_type__, Converter__) \
struct trt_##op_type__##_converter { \
trt_##op_type__##_converter() { \
OpConverter::Global().Register<Converter__>(#op_type__); \
Registry<OpConverter>::Register<Converter__>(#op_type__); \
} \
}; \
trt_##op_type__##_converter trt_##op_type__##_converter__;
......
......@@ -26,7 +26,7 @@ namespace paddle {
namespace inference {
namespace tensorrt {
void compare(float input, float expect) {
void Compare(float input, float expect) {
framework::Scope scope;
platform::CUDAPlace place;
platform::CUDADeviceContext ctx(place);
......@@ -85,8 +85,8 @@ void compare(float input, float expect) {
}
TEST(OpConverter, ConvertRelu) {
compare(1, 1); // relu(1) = 1
compare(-5, 0); // relu(-5) = 0
Compare(1, 1); // relu(1) = 1
Compare(-5, 0); // relu(-5) = 0
}
} // namespace tensorrt
......
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/tensorrt/io_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
#include <gtest/gtest.h>
......@@ -34,7 +34,7 @@ TEST_F(EngineInputConverterTester, DefaultCPU) {
ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
cudaStream_t stream;
EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(),
EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
&stream);
}
......@@ -44,7 +44,7 @@ TEST_F(EngineInputConverterTester, DefaultGPU) {
ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
cudaStream_t stream;
EngineInputConverter::Run("mul", tensor, buffer, tensor.memory_size(),
EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
&stream);
}
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace inference {
namespace tensorrt {
TEST(BlockConverter, ConvertBlock) {
TEST(OpConverter, ConvertBlock) {
framework::ProgramDesc prog;
auto* block = prog.MutableBlock(0);
auto* mul_op = block->AppendOp();
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/fluid/platform/enforce.h"
......@@ -49,9 +50,15 @@ struct Registry {
items_[name] = new ItemChild;
}
static ItemParent* Lookup(const std::string& name) {
static ItemParent* Lookup(const std::string& name,
const std::string& default_name = "") {
auto it = items_.find(name);
if (it == items_.end()) return nullptr;
if (it == items_.end()) {
if (default_name == "")
return nullptr;
else
return items_.find(default_name)->second;
}
return it->second;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册