未验证 提交 0701160a 编写于 作者: W Wilber 提交者: GitHub

infrt-trt run resnet50 (#41442)

* add rewrite pattern form paddle op tp trt op

* infrt-trt run resnet50.
Co-authored-by: 圣颖君's avatarweishengying <1343838695@qq.com>
上级 c9e0e10e
...@@ -115,9 +115,6 @@ if (INFRT_WITH_PHI) ...@@ -115,9 +115,6 @@ if (INFRT_WITH_PHI)
endif() endif()
cc_library(infrt SHARED SRCS ${infrt_src} DEPS glog boost ${mlir_libs} ${phi_libs} paddle_framework_proto infrt_naive) cc_library(infrt SHARED SRCS ${infrt_src} DEPS glog boost ${mlir_libs} ${phi_libs} paddle_framework_proto infrt_naive)
if (INFRT_WITH_TRT)
target_link_libraries(infrt infrt_trt)
endif()
cc_library(infrt_static SRCS ${infrt_src} DEPS glog boost ${mlir_libs} ${phi_libs} paddle_framework_proto) cc_library(infrt_static SRCS ${infrt_src} DEPS glog boost ${mlir_libs} ${phi_libs} paddle_framework_proto)
add_dependencies(infrt ${infrt_mlir_incs} mlir-headers) add_dependencies(infrt ${infrt_mlir_incs} mlir-headers)
......
cc_library(infrt_trt SRCS trt_engine.cc DEPS glog phi_dynload_cuda phi) add_subdirectory(plugin)
cc_test_tiny(test_infrt_trt SRCS test_trt_engine.cc DEPS infrt_trt phi_dynload_cuda tensorrt_converter) core_gather_headers()
gather_srcs(infrt_src SRCS trt_engine.cc)
cc_test_tiny(test_infrt_trt SRCS test_trt_engine.cc DEPS infrt phi_dynload_cuda tensorrt_converter)
gather_srcs(infrt_src SRCS pool_op_plugin.cu)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <cassert>
#include <cstring>
#include <string>
#include <type_traits>
#include <vector>
#include "paddle/phi/backends/dynload/tensorrt.h"
namespace infrt {
namespace backends {
namespace tensorrt {
namespace plugin {
template <typename T>
inline void SerializeValue(void** buffer, T const& value);
template <typename T>
inline void DeserializeValue(void const** buffer,
size_t* buffer_size,
T* value);
namespace details {
template <typename T, class Enable = void>
struct Serializer {};
template <typename T>
struct Serializer<T,
typename std::enable_if<std::is_arithmetic<T>::value ||
std::is_enum<T>::value ||
std::is_pod<T>::value>::type> {
static size_t SerializedSize(T const& value) { return sizeof(T); }
static void Serialize(void** buffer, T const& value) {
std::memcpy(*buffer, &value, sizeof(T));
reinterpret_cast<char*&>(*buffer) += sizeof(T);
}
static void Deserialize(void const** buffer, size_t* buffer_size, T* value) {
assert(*buffer_size >= sizeof(T));
std::memcpy(value, *buffer, sizeof(T));
reinterpret_cast<char const*&>(*buffer) += sizeof(T);
*buffer_size -= sizeof(T);
}
};
template <>
struct Serializer<const char*> {
static size_t SerializedSize(const char* value) { return strlen(value) + 1; }
static void Serialize(void** buffer, const char* value) {
std::strcpy(static_cast<char*>(*buffer), value); // NOLINT
reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
}
static void Deserialize(void const** buffer,
size_t* buffer_size,
const char** value) {
*value = static_cast<char const*>(*buffer);
size_t data_size = strnlen(*value, *buffer_size) + 1;
assert(*buffer_size >= data_size);
reinterpret_cast<char const*&>(*buffer) += data_size;
*buffer_size -= data_size;
}
};
template <typename T>
struct Serializer<std::vector<T>,
typename std::enable_if<std::is_arithmetic<T>::value ||
std::is_enum<T>::value ||
std::is_pod<T>::value>::type> {
static size_t SerializedSize(std::vector<T> const& value) {
return sizeof(value.size()) + value.size() * sizeof(T);
}
static void Serialize(void** buffer, std::vector<T> const& value) {
SerializeValue(buffer, value.size());
size_t nbyte = value.size() * sizeof(T);
std::memcpy(*buffer, value.data(), nbyte);
reinterpret_cast<char*&>(*buffer) += nbyte;
}
static void Deserialize(void const** buffer,
size_t* buffer_size,
std::vector<T>* value) {
size_t size;
DeserializeValue(buffer, buffer_size, &size);
value->resize(size);
size_t nbyte = value->size() * sizeof(T);
CHECK_GE(*buffer_size, nbyte);
std::memcpy(value->data(), *buffer, nbyte);
reinterpret_cast<char const*&>(*buffer) += nbyte;
*buffer_size -= nbyte;
}
};
} // namespace details
template <typename T>
inline size_t SerializedSize(T const& value) {
return details::Serializer<T>::SerializedSize(value);
}
template <typename T>
inline void SerializeValue(void** buffer, T const& value) {
return details::Serializer<T>::Serialize(buffer, value);
}
template <typename T>
inline void DeserializeValue(void const** buffer,
size_t* buffer_size,
T* value) {
return details::Serializer<T>::Deserialize(buffer, buffer_size, value);
}
template <typename T>
class TrtPluginRegistrar {
public:
TrtPluginRegistrar() {
static auto func_ptr = static_cast<nvinfer1::IPluginRegistry*>(
::phi::dynload::getPluginRegistry());
func_ptr->registerCreator(instance, "");
}
private:
//! Plugin instance.
T instance{};
};
#define REGISTER_TRT_PLUGIN(name) \
static TrtPluginRegistrar<name> pluginRegistrar##name {}
} // namespace plugin
} // namespace tensorrt
} // namespace backends
} // namespace infrt
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "paddle/infrt/backends/tensorrt/plugin/plugin_utils.h"
#include "paddle/infrt/backends/tensorrt/plugin/pool_op_plugin.h"
#include "paddle/phi/kernels/funcs/pooling.h"
namespace infrt {
namespace backends {
namespace tensorrt {
namespace plugin {
PoolPlugin::PoolPlugin(bool ceil_mode,
PoolType pool_type,
bool adaptive,
bool exclusive,
std::vector<int> ksize,
std::vector<int> strides,
std::vector<int> paddings,
std::vector<int> input_shape,
std::vector<int> real_paddings)
: ceil_mode_(ceil_mode),
pool_type_(pool_type),
adaptive_(adaptive),
exclusive_(exclusive),
ksize_(ksize),
strides_(strides),
paddings_(paddings),
real_paddings_(real_paddings),
input_shape_(input_shape) {
output_shape_ = input_shape_;
std::vector<int> output_shape =
CalcOutputSize({input_shape_[1], input_shape_[2]},
ceil_mode_,
adaptive_,
ksize_,
strides_,
real_paddings_);
output_shape_[1] = output_shape[0];
output_shape_[2] = output_shape[1];
}
PoolPlugin::PoolPlugin(void const* serialData, size_t serialLength) {
// deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &ceil_mode_);
DeserializeValue(&serialData, &serialLength, &pool_type_);
DeserializeValue(&serialData, &serialLength, &adaptive_);
DeserializeValue(&serialData, &serialLength, &exclusive_);
DeserializeValue(&serialData, &serialLength, &ksize_);
DeserializeValue(&serialData, &serialLength, &strides_);
DeserializeValue(&serialData, &serialLength, &paddings_);
DeserializeValue(&serialData, &serialLength, &real_paddings_);
DeserializeValue(&serialData, &serialLength, &input_shape_);
DeserializeValue(&serialData, &serialLength, &output_shape_);
}
const char* PoolPlugin::getPluginType() const noexcept { return "pool_plugin"; }
const char* PoolPlugin::getPluginVersion() const noexcept { return "1"; }
int PoolPlugin::getNbOutputs() const noexcept { return 1; }
nvinfer1::Dims PoolPlugin::getOutputDimensions(int outputIndex,
const nvinfer1::Dims* inputs,
int nbInputs) noexcept {
assert(nbInputs == 1);
assert(index == 0);
assert(inputs[0].nbDims == 3);
nvinfer1::Dims const& input_dims = inputs[0];
nvinfer1::Dims output_dims = input_dims;
output_dims.d[1] = output_shape_[1];
output_dims.d[2] = output_shape_[2];
return output_dims;
}
int32_t PoolPlugin::initialize() noexcept { return 0; }
void PoolPlugin::terminate() noexcept {}
size_t PoolPlugin::getWorkspaceSize(int32_t maxBatchSize) const noexcept {
return 0;
}
#if IS_TRT_VERSION_LT(8000)
int PoolPlugin::enqueue(int batch_size,
const void* const* inputs,
void** outputs,
#else
int PoolPlugin::enqueue(int batch_size,
const void* const* inputs,
void* const* outputs,
#endif
void* workspace,
cudaStream_t stream) noexcept {
// TODO(wilber)
int input_size = 0;
float const* idata = reinterpret_cast<float const*>(inputs[0]);
float* const* odatas = reinterpret_cast<float* const*>(outputs);
std::vector<int> input_shape = input_shape_;
std::vector<int> output_shape = output_shape_;
input_shape.insert(input_shape.begin(), batch_size);
output_shape.insert(output_shape.begin(), batch_size);
if (pool_type_ == PoolType::max) {
::phi::funcs::MaxPool<float> pool_process;
::phi::funcs::Pool2dDirectCUDAFunctor<phi::funcs::MaxPool<float>, float>
pool2d_forward;
pool2d_forward(idata,
input_shape,
output_shape,
ksize_,
strides_,
paddings_,
true,
false,
odatas[0],
stream,
pool_process);
} else if (pool_type_ == PoolType::avg) {
::phi::funcs::AvgPool<float> pool_process;
::phi::funcs::Pool2dDirectCUDAFunctor<phi::funcs::AvgPool<float>, float>
pool2d_forward;
pool2d_forward(idata,
input_shape,
output_shape,
ksize_,
strides_,
paddings_,
exclusive_,
adaptive_,
odatas[0],
stream,
pool_process);
}
return cudaGetLastError() != cudaSuccess;
}
// TODO(wilber): serialize base info?
size_t PoolPlugin::getSerializationSize() const noexcept {
return SerializedSize(ceil_mode_) + SerializedSize(pool_type_) +
SerializedSize(adaptive_) + SerializedSize(exclusive_) +
SerializedSize(ksize_) + SerializedSize(strides_) +
SerializedSize(paddings_) + SerializedSize(real_paddings_) +
SerializedSize(input_shape_) + SerializedSize(output_shape_);
}
// TODO(wilber): serialize base info?
void PoolPlugin::serialize(void* buffer) const noexcept {
// serializeBase(buffer);
SerializeValue(&buffer, ceil_mode_);
SerializeValue(&buffer, pool_type_);
SerializeValue(&buffer, adaptive_);
SerializeValue(&buffer, exclusive_);
SerializeValue(&buffer, ksize_);
SerializeValue(&buffer, strides_);
SerializeValue(&buffer, paddings_);
SerializeValue(&buffer, real_paddings_);
SerializeValue(&buffer, input_shape_);
SerializeValue(&buffer, output_shape_);
}
void PoolPlugin::destroy() noexcept { delete this; }
void PoolPlugin::setPluginNamespace(char const* plugin_namespace) noexcept {
namespace_ = plugin_namespace;
}
char const* PoolPlugin::getPluginNamespace() const noexcept {
return namespace_.c_str();
}
nvinfer1::DataType PoolPlugin::getOutputDataType(
int32_t index,
nvinfer1::DataType const* input_types,
int32_t nbInputs) const noexcept {
CHECK_EQ(index, 0);
CHECK_EQ((input_types[0] == nvinfer1::DataType::kFLOAT), true);
return input_types[0];
}
bool PoolPlugin::isOutputBroadcastAcrossBatch(int32_t outputIndex,
bool const* inputIsBroadcasted,
int32_t nbInputs) const noexcept {
return false;
}
bool PoolPlugin::canBroadcastInputAcrossBatch(int32_t inputIndex) const
noexcept {
return false;
}
nvinfer1::IPluginV2Ext* PoolPlugin::clone() const noexcept {
auto* plugin = new PoolPlugin(ceil_mode_,
pool_type_,
adaptive_,
exclusive_,
ksize_,
strides_,
paddings_,
input_shape_,
real_paddings_);
plugin->setPluginNamespace(namespace_.c_str());
return plugin;
}
void PoolPlugin::configurePlugin(nvinfer1::PluginTensorDesc const* in,
int32_t nb_input,
nvinfer1::PluginTensorDesc const* out,
int32_t nb_output) noexcept {
CHECK_EQ(nb_input, 1);
CHECK_EQ(nb_output, 1);
input_dims_ = in[0].dims;
data_format_ = in[0].format;
data_type_ = in[0].type;
}
bool PoolPlugin::supportsFormatCombination(
int32_t pos,
nvinfer1::PluginTensorDesc const* in_out,
int32_t nb_inputs,
int32_t nb_outputs) const noexcept {
CHECK_LT(pos, nb_inputs + nb_outputs);
CHECK_NOTNULL(in_out);
return ((in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
in_out[pos].format == nvinfer1::PluginFormat::kLINEAR);
}
nvinfer1::IPluginV2* PoolPluginCreator::createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept {
// auto* plugin = new UffPoolPluginV2(*fc);
field_collection_ = *fc;
plugin_name_ = name;
const nvinfer1::PluginField* fields = fc->fields;
bool ceil_mode;
PoolPlugin::PoolType pool_type;
bool adaptive;
bool exclusive;
std::vector<int> ksize;
std::vector<int> strides;
std::vector<int> paddings;
std::vector<int> real_paddings;
std::vector<int> input_shape;
std::vector<int> output_shape;
// TODO(wilber): add implement.
CHECK(false) << "not implement";
// for (int i = 0; i < fc->nbFields; ++i) {
// const char* attr_name = fields[i].name;
// if (!strcmp(attr_name, "ceil_mode")) {
// CHECK_EQ(fields[i].type == nvinfer1::PluginFieldType::kINT8, true);
// ceil_mode = *static_cast<const bool*>(fields[i].data);
// // mParam.numOutputBoxesPerClass =
// // *(static_cast<const int*>(fields[i].data));
// }
// }
return nullptr;
}
nvinfer1::IPluginV2* PoolPluginCreator::deserializePlugin(
const char* name, const void* serialData, size_t serialLength) noexcept {
auto* plugin = new PoolPlugin(serialData, serialLength);
plugin_name_ = name;
return plugin;
}
} // namespace plugin
} // namespace tensorrt
} // namespace backends
} // namespace infrt
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <NvInferRuntime.h>
#include <NvInferRuntimeCommon.h>
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/infrt/backends/tensorrt/plugin/plugin_utils.h"
#include "paddle/infrt/backends/tensorrt/trt_utils.h"
namespace infrt {
namespace backends {
namespace tensorrt {
namespace plugin {
static std::vector<int> CalcOutputSize(const std::vector<int>& input_shape,
const bool& ceil_mode,
const bool& adaptive,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& real_paddings) {
std::vector<int> output_shape = input_shape;
if (adaptive) {
output_shape[0] = ksize[0];
output_shape[1] = ksize[1];
} else {
int output_h = 0, output_w = 0;
if (ceil_mode) {
output_h = (input_shape[0] - ksize[0] + real_paddings[0] +
real_paddings[1] + strides[0] - 1) /
strides[0] +
1;
output_w = (input_shape[1] - ksize[1] + real_paddings[2] +
real_paddings[3] + strides[1] - 1) /
strides[1] +
1;
}
// TRT will use native layer when ceil_model=false
/*
else{
output_h = (input_shape[0] - ksize[0] + real_paddings[0] +
real_paddings[1]) / strides[0] + 1;
output_w = (input_shape[1] - ksize[1] + real_paddings[2] +
real_paddings[3]) / strides[1] + 1;
}
*/
output_shape[0] = output_h;
output_shape[1] = output_w;
}
return output_shape;
}
class PoolPlugin : public nvinfer1::IPluginV2IOExt {
public:
enum class PoolType {
max = 0,
avg,
};
PoolPlugin() {}
PoolPlugin(bool ceil_mode,
PoolType pool_type,
bool adaptive,
bool exclusive,
std::vector<int> ksize,
std::vector<int> strides,
std::vector<int> paddings,
std::vector<int> input_shape,
std::vector<int> real_paddings);
PoolPlugin(void const* serialData, size_t serialLength);
// IPluginV2 methods
const char* getPluginType() const noexcept override;
const char* getPluginVersion() const noexcept override;
int getNbOutputs() const noexcept override;
nvinfer1::Dims getOutputDimensions(int outputIndex,
const nvinfer1::Dims* inputs,
int nbInputs) noexcept override;
int32_t initialize() noexcept override;
void terminate() noexcept override;
size_t getWorkspaceSize(int32_t maxBatchSize) const noexcept override;
#if IS_TRT_VERSION_LT(8000)
int enqueue(int batchSize,
const void* const* inputs,
void** outputs,
#else
int enqueue(int batchSize,
const void* const* inputs,
void* const* outputs,
#endif
void* workspace,
cudaStream_t stream) noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void* buffer) const noexcept override;
void destroy() noexcept override;
void setPluginNamespace(char const* pluginNamespace) noexcept override;
char const* getPluginNamespace() const noexcept override;
// IPluginV2Ext methods
nvinfer1::DataType getOutputDataType(int32_t index,
nvinfer1::DataType const* inputTypes,
int32_t nbInputs) const
noexcept override;
bool isOutputBroadcastAcrossBatch(int32_t outputIndex,
bool const* inputIsBroadcasted,
int32_t nbInputs) const noexcept override;
bool canBroadcastInputAcrossBatch(int32_t inputIndex) const noexcept override;
// void attachToContext(cudnnContext*,
// cublasContext*,
// IGpuAllocator*) noexcept override;
// void detachFromContext() noexcept override;
IPluginV2Ext* clone() const noexcept override;
// IPluginV2IOExt methods
void configurePlugin(nvinfer1::PluginTensorDesc const* in,
int32_t nb_input,
nvinfer1::PluginTensorDesc const* out,
int32_t nb_output) noexcept override;
bool supportsFormatCombination(int32_t pos,
nvinfer1::PluginTensorDesc const* inOut,
int32_t nb_inputs,
int32_t nb_outputs) const noexcept override;
private:
bool ceil_mode_;
PoolType pool_type_;
bool adaptive_;
bool exclusive_;
std::vector<int> ksize_;
std::vector<int> strides_;
std::vector<int> paddings_;
std::vector<int> real_paddings_;
std::vector<int> input_shape_;
std::vector<int> output_shape_;
private:
nvinfer1::Dims input_dims_;
nvinfer1::DataType data_type_;
nvinfer1::PluginFormat data_format_;
std::string namespace_;
};
class PoolPluginCreator : public nvinfer1::IPluginCreator {
public:
const char* getPluginName() const noexcept override { return "pool_plugin"; }
const char* getPluginVersion() const noexcept override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(
const char* name,
const nvinfer1::PluginFieldCollection* fc) noexcept override;
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serialData,
size_t serialLength) noexcept override;
void setPluginNamespace(const char* plugin_namespace) noexcept override {
plugin_namespace_ = plugin_namespace;
}
const char* getPluginNamespace() const noexcept override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
};
REGISTER_TRT_PLUGIN(PoolPluginCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace backends
} // namespace infrt
...@@ -320,9 +320,9 @@ inline ::llvm::SmallVector<::mlir::Value, 4> CreatePaddleTrtPoolingOp( ...@@ -320,9 +320,9 @@ inline ::llvm::SmallVector<::mlir::Value, 4> CreatePaddleTrtPoolingOp(
} }
// if global_pooling == true or adaptive == true, padding will be ignored // if global_pooling == true or adaptive == true, padding will be ignored
if (global_pooling.getValue() || adaptive.getValue()) { // if (global_pooling.getValue() || adaptive.getValue()) {
paddings_attr = builder.getI32ArrayAttr({0, 0}); // paddings_attr = builder.getI32ArrayAttr({0, 0});
} // }
// if global_pooling == true, then we should update kernel size to input dims. // if global_pooling == true, then we should update kernel size to input dims.
if (global_pooling.getValue() == true) { if (global_pooling.getValue() == true) {
......
...@@ -72,7 +72,7 @@ int main(int argc, char** argv) { ...@@ -72,7 +72,7 @@ int main(int argc, char** argv) {
#endif #endif
context->loadAllAvailableDialects(); context->loadAllAvailableDialects();
module->dump(); // module->dump();
mlir::PassManager pm(context); mlir::PassManager pm(context);
mlir::OpPassManager& trt_pass_manager = pm.nest<mlir::FuncOp>(); mlir::OpPassManager& trt_pass_manager = pm.nest<mlir::FuncOp>();
...@@ -87,7 +87,7 @@ int main(int argc, char** argv) { ...@@ -87,7 +87,7 @@ int main(int argc, char** argv) {
std::cout << "\npass failed!\n" << std::endl; std::cout << "\npass failed!\n" << std::endl;
return 4; return 4;
} }
module->dump(); // module->dump();
::infrt::host_context::TestMlir(module.get(), &registry); ::infrt::host_context::TestMlir(module.get(), &registry);
return 0; return 0;
} }
...@@ -186,7 +186,7 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern { ...@@ -186,7 +186,7 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern {
create_scale_tensor_op->getLoc(), create_scale_tensor_op->getLoc(),
create_scale_tensor_op.output().getType(), create_scale_tensor_op.output().getType(),
create_scale_tensor_op.context(), create_scale_tensor_op.context(),
create_bias_tensor_op.dims(), create_scale_tensor_op.dims(),
::infrt::LayoutAttr::get(rewriter.getContext(), ::infrt::LayoutAttr::get(rewriter.getContext(),
::infrt::LayoutType::NCHW), ::infrt::LayoutType::NCHW),
create_scale_tensor_op.lod(), create_scale_tensor_op.lod(),
...@@ -206,7 +206,6 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern { ...@@ -206,7 +206,6 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern {
rewriter.getF32ArrayAttr(combile_bias_data)); rewriter.getF32ArrayAttr(combile_bias_data));
rewriter.replaceOp(create_bias_tensor_op, new_bias_op->getResults()); rewriter.replaceOp(create_bias_tensor_op, new_bias_op->getResults());
rewriter.setInsertionPoint(op);
trt::ScaleNdOp scaleNd_op; trt::ScaleNdOp scaleNd_op;
// resultTypes // resultTypes
::mlir::SmallVector<::mlir::Type, 4> resultTypes; ::mlir::SmallVector<::mlir::Type, 4> resultTypes;
...@@ -215,6 +214,7 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern { ...@@ -215,6 +214,7 @@ struct PD2TRT_Batch_Norm_Lower : public ::mlir::RewritePattern {
} }
// attributes // attributes
rewriter.setInsertionPoint(op);
::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes; ::mlir::SmallVector<::mlir::NamedAttribute, 8> attributes;
auto result = rewriter auto result = rewriter
.create<trt::ScaleNdOp>( .create<trt::ScaleNdOp>(
......
...@@ -52,6 +52,18 @@ static nvinfer1::Dims ArrayAttrToNvDims(const mlir::ArrayAttr& int_array_attr) { ...@@ -52,6 +52,18 @@ static nvinfer1::Dims ArrayAttrToNvDims(const mlir::ArrayAttr& int_array_attr) {
return dims; return dims;
} }
template <typename T>
static std::vector<T> ArrayAttrToVec(const mlir::ArrayAttr& int_array_attr) {
std::vector<T> ret;
ret.resize(int_array_attr.size());
CHECK(!int_array_attr.empty());
CHECK(int_array_attr[0].getType().isIntOrIndex());
for (size_t i = 0; i < int_array_attr.size(); ++i) {
ret[i] = int_array_attr[i].cast<mlir::IntegerAttr>().getInt();
}
return ret;
}
static nvinfer1::Weights TensorToWeights(::phi::DenseTensor* tensor) { static nvinfer1::Weights TensorToWeights(::phi::DenseTensor* tensor) {
CHECK_NOTNULL(tensor); CHECK_NOTNULL(tensor);
nvinfer1::Weights ret; nvinfer1::Weights ret;
......
...@@ -147,6 +147,10 @@ namespace tensorrt { ...@@ -147,6 +147,10 @@ namespace tensorrt {
} else if (trt::ScaleNdOp op = llvm::dyn_cast<trt::ScaleNdOp>(operation)) { } else if (trt::ScaleNdOp op = llvm::dyn_cast<trt::ScaleNdOp>(operation)) {
ScaleNdFunc( ScaleNdFunc(
op, network.get(), value_to_trt_tensor_map, value_to_tensor_map); op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
} else if (trt::ElementWiseOp op =
llvm::dyn_cast<trt::ElementWiseOp>(operation)) {
EltwiseFunc(
op, network.get(), value_to_trt_tensor_map, value_to_tensor_map);
} else { } else {
CHECK(false) << "not supported operation."; CHECK(false) << "not supported operation.";
} }
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <string> #include <string>
#include "paddle/infrt/backends/tensorrt/plugin/pool_op_plugin.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "paddle/infrt/kernel/tensorrt/trt_helper.h" #include "paddle/infrt/kernel/tensorrt/trt_helper.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
...@@ -78,6 +79,9 @@ inline void ConvFunc(trt::ConvolutionOp& op, // NOLINT ...@@ -78,6 +79,9 @@ inline void ConvFunc(trt::ConvolutionOp& op, // NOLINT
dims, dims,
kernel_weights, kernel_weights,
bias_weights); bias_weights);
layer->setPaddingNd(ArrayAttrToNvDims(op.paddings()));
layer->setStrideNd(ArrayAttrToNvDims(op.strides()));
CHECK_NOTNULL(layer); CHECK_NOTNULL(layer);
mlir::Value out_repr = op.output_tensor(); mlir::Value out_repr = op.output_tensor();
nvinfer1::ITensor* out_tensor = layer->getOutput(0); nvinfer1::ITensor* out_tensor = layer->getOutput(0);
...@@ -90,8 +94,8 @@ inline void PoolFunc(trt::PoolingOp& op, // NOLINT ...@@ -90,8 +94,8 @@ inline void PoolFunc(trt::PoolingOp& op, // NOLINT
ValueToTensorMap& value_to_tensor_map) { // NOLINT ValueToTensorMap& value_to_tensor_map) { // NOLINT
mlir::Value input_tensor_repr = op.input_tensor(); mlir::Value input_tensor_repr = op.input_tensor();
nvinfer1::ITensor* input_itensor = value_to_trt_tensor_map[input_tensor_repr]; nvinfer1::ITensor* input_itensor = value_to_trt_tensor_map[input_tensor_repr];
// nvinfer1::Dims input_shape = input_itensor->getDimensions(); nvinfer1::Dims input_shape = input_itensor->getDimensions();
// int input_dims = input_shape.nbDims; int input_dims = input_shape.nbDims;
auto padding_mode = op.padding_mode(); auto padding_mode = op.padding_mode();
auto pool_type = op.pool_type(); auto pool_type = op.pool_type();
...@@ -109,7 +113,35 @@ inline void PoolFunc(trt::PoolingOp& op, // NOLINT ...@@ -109,7 +113,35 @@ inline void PoolFunc(trt::PoolingOp& op, // NOLINT
if (adaptive) { if (adaptive) {
// TODO(Inference) // TODO(Inference)
CHECK(false) << "Not supported adaptive pool"; // CHECK(false) << "Not supported adaptive pool";
std::vector<int> input_shape_v;
for (int i = 0; i < input_dims; i++) {
input_shape_v.push_back(input_shape.d[i]);
}
auto paddings_val = ArrayAttrToVec<int>(paddings);
std::vector<int> real_paddings = paddings_val;
for (int i = 0; i < 2; ++i) {
int copy_pad = *(paddings_val.begin() + i);
real_paddings.insert(real_paddings.begin() + 2 * i + 1, copy_pad);
}
auto* plugin = new backends::tensorrt::plugin::PoolPlugin(
false,
backends::tensorrt::plugin::PoolPlugin::PoolType::avg,
adaptive,
exclusive,
ArrayAttrToVec<int>(ksize),
ArrayAttrToVec<int>(strides),
paddings_val,
input_shape_v,
real_paddings);
auto* layer = network->addPluginV2(&input_itensor, 1, *plugin);
mlir::Value out_repr = op.output_tensor();
nvinfer1::ITensor* out_tensor = layer->getOutput(0);
value_to_trt_tensor_map[out_repr] = out_tensor;
return;
} }
nvinfer1::Dims window_size = ArrayAttrToNvDims(ksize); nvinfer1::Dims window_size = ArrayAttrToNvDims(ksize);
...@@ -136,19 +168,41 @@ inline void FcFunc(trt::FullyConnectedOp& op, // NOLINT ...@@ -136,19 +168,41 @@ inline void FcFunc(trt::FullyConnectedOp& op, // NOLINT
mlir::Value input_tensor_repr = op.input_tensor(); mlir::Value input_tensor_repr = op.input_tensor();
CHECK(value_to_trt_tensor_map.count(input_tensor_repr)); CHECK(value_to_trt_tensor_map.count(input_tensor_repr));
nvinfer1::ITensor* input_itensor = value_to_trt_tensor_map[input_tensor_repr];
nvinfer1::Dims input_shape = input_itensor->getDimensions();
int input_dims = input_shape.nbDims;
CHECK_EQ(input_dims, 1) << "Now we only support 2-d input.";
// TODO(wilber): We should place the logic to ir. Now only support 2-d input
// and we reshape to 4-d.
nvinfer1::Dims reshape_before_fc_dim;
reshape_before_fc_dim.nbDims = input_dims + 2;
// padding shape "* x q x 1 x 1"
for (int i = 0; i < reshape_before_fc_dim.nbDims; i++) {
reshape_before_fc_dim.d[i] = 1;
}
reshape_before_fc_dim.d[0] = input_shape.d[0];
auto* reshape_before_fc_layer = network->addShuffle(*input_itensor);
reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim);
auto* reshape_itensor = reshape_before_fc_layer->getOutput(0);
auto kernel_weights = auto kernel_weights =
TensorToWeights(value_to_tensor_map[op.kernel_weights()]); TensorToWeights(value_to_tensor_map[op.kernel_weights()]);
auto bias_weights = TensorToWeights(value_to_tensor_map[op.bias_weights()]); auto bias_weights = TensorToWeights(value_to_tensor_map[op.bias_weights()]);
int out_channel_num = op.out_channel_num(); int out_channel_num = op.out_channel_num();
auto* layer = auto* layer = network->addFullyConnected(
network->addFullyConnected(*value_to_trt_tensor_map[input_tensor_repr], *reshape_itensor, out_channel_num, kernel_weights, bias_weights);
out_channel_num,
kernel_weights, // TODO(wilber): fix.
bias_weights); nvinfer1::Dims reshape_after_fc_dim;
reshape_after_fc_dim.nbDims = 1;
reshape_after_fc_dim.d[0] = layer->getOutput(0)->getDimensions().d[0];
auto* reshape_after_fc_layer = network->addShuffle(*layer->getOutput(0));
reshape_after_fc_layer->setReshapeDimensions(reshape_after_fc_dim);
mlir::Value out_repr = op.output_tensor(); mlir::Value out_repr = op.output_tensor();
nvinfer1::ITensor* out_tensor = layer->getOutput(0); nvinfer1::ITensor* out_tensor = reshape_after_fc_layer->getOutput(0);
value_to_trt_tensor_map[out_repr] = out_tensor; value_to_trt_tensor_map[out_repr] = out_tensor;
} }
...@@ -159,14 +213,12 @@ inline void ShuffleFunc(trt::ShuffleOp& op, // NOLINT ...@@ -159,14 +213,12 @@ inline void ShuffleFunc(trt::ShuffleOp& op, // NOLINT
mlir::Value input_tensor_repr = op.input_tensor(); mlir::Value input_tensor_repr = op.input_tensor();
nvinfer1::ITensor* input = value_to_trt_tensor_map[input_tensor_repr]; nvinfer1::ITensor* input = value_to_trt_tensor_map[input_tensor_repr];
int dims = input->getDimensions().nbDims; int dims = input->getDimensions().nbDims;
int start_axis = op.start_axis();
int start_axis = op.start_axisAttr().getInt(); int stop_axis = op.stop_axis();
int stop_axis = op.start_axisAttr().getInt();
nvinfer1::IShuffleLayer* layer = nullptr; nvinfer1::IShuffleLayer* layer = nullptr;
if (start_axis < 0) start_axis += dims + 1; if (start_axis < 0) start_axis += dims + 1;
if (stop_axis < 0) stop_axis += dims + 1; if (stop_axis < 0) stop_axis += dims + 1;
int dim_prod = 1; int dim_prod = 1;
nvinfer1::Dims flatten_dim; nvinfer1::Dims flatten_dim;
flatten_dim.nbDims = dims - (stop_axis - start_axis); flatten_dim.nbDims = dims - (stop_axis - start_axis);
...@@ -185,7 +237,6 @@ inline void ShuffleFunc(trt::ShuffleOp& op, // NOLINT ...@@ -185,7 +237,6 @@ inline void ShuffleFunc(trt::ShuffleOp& op, // NOLINT
layer = network->addShuffle(*value_to_trt_tensor_map[input_tensor_repr]); layer = network->addShuffle(*value_to_trt_tensor_map[input_tensor_repr]);
CHECK_NOTNULL(layer); CHECK_NOTNULL(layer);
layer->setReshapeDimensions(flatten_dim); layer->setReshapeDimensions(flatten_dim);
for (size_t i = 0; i < op->getNumResults(); ++i) { for (size_t i = 0; i < op->getNumResults(); ++i) {
nvinfer1::ITensor* out_tensor = layer->getOutput(i); nvinfer1::ITensor* out_tensor = layer->getOutput(i);
mlir::Value out_value = op->getResult(i); mlir::Value out_value = op->getResult(i);
...@@ -222,6 +273,30 @@ inline void ScaleNdFunc(trt::ScaleNdOp& op, // NOLINT ...@@ -222,6 +273,30 @@ inline void ScaleNdFunc(trt::ScaleNdOp& op, // NOLINT
value_to_trt_tensor_map[out_value] = out_tensor; value_to_trt_tensor_map[out_value] = out_tensor;
} }
} }
inline void EltwiseFunc(trt::ElementWiseOp& op, // NOLINT
nvinfer1::INetworkDefinition* network,
ValueToITensorMap& value_to_trt_tensor_map, // NOLINT
ValueToTensorMap& value_to_tensor_map) { // NOLINT
mlir::Value input1_tensor_repr = op.input1();
mlir::Value input2_tensor_repr = op.input2();
nvinfer1::ITensor* input1 = value_to_trt_tensor_map[input1_tensor_repr];
nvinfer1::ITensor* input2 = value_to_trt_tensor_map[input2_tensor_repr];
auto eltwise_operation = op.elementwise_operation();
auto* layer = network->addElementWise(
*input1,
*input2,
static_cast<nvinfer1::ElementWiseOperation>(eltwise_operation));
CHECK_NOTNULL(layer);
for (size_t i = 0; i < op->getNumResults(); ++i) {
nvinfer1::ITensor* out_tensor = layer->getOutput(i);
mlir::Value out_value = op->getResult(i);
value_to_trt_tensor_map[out_value] = out_tensor;
}
}
} // namespace tensorrt } // namespace tensorrt
} // namespace kernel } // namespace kernel
} // namespace infrt } // namespace infrt
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册