Unverified commit c5a1e49c, authored by Zhang Jun, committed by GitHub

add swish using TensorRT layer (#44561)

* update

* empty commit

* update

* update

* update
Parent: 3e170163
......@@ -482,10 +482,18 @@ class OpConverter {
   template <typename T>
   // Create and add Multi-D constant float/int32 layer
   nvinfer1::ITensor* AddConstantLayer(const T* data,
-                                      const std::vector<int32_t>& weight_dims,
-                                      const std::string& weight_name) {
+                                      nvinfer1::Dims shape,
+                                      const std::string& weight_name = "") {
+    if (!(std::is_same<T, float>::value ||
+          std::is_same<T, platform::float16>::value ||
+          std::is_same<T, int32_t>::value)) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Unsupported data type (%s) for TensorRT AddConstantLayer, only "
+          "supports float, half or int32_t."));
+    }
     int data_size = std::accumulate(
-        weight_dims.begin(), weight_dims.end(), 1, std::multiplies<int>());
+        shape.d, shape.d + shape.nbDims, 1, std::multiplies<int>());
     std::unique_ptr<framework::Tensor> tmp_tensor(new framework::Tensor());
     tmp_tensor->Resize({data_size});
     auto* tmp_data = tmp_tensor->mutable_data<T>(platform::CPUPlace());
......@@ -502,12 +510,9 @@ class OpConverter {
     TensorRTEngine::Weight weight{trt_dtype,
                                   static_cast<void*>(tmp_data),
                                   static_cast<size_t>(data_size)};
-    nvinfer1::Dims trt_dims;
-    trt_dims.nbDims = weight_dims.size();
-    for (size_t i = 0; i < weight_dims.size(); i++)
-      trt_dims.d[i] = weight_dims[i];
     auto const_layer =
-        TRT_ENGINE_ADD_LAYER(engine_, Constant, trt_dims, weight.get());
+        TRT_ENGINE_ADD_LAYER(engine_, Constant, shape, weight.get());
     return const_layer->getOutput(0);
   }
......@@ -516,6 +521,14 @@ class OpConverter {
   nvinfer1::ITensor* Add1DConstantLayer(const std::vector<T>& data,
                                         const std::string& weight_name = "",
                                         bool scalar = false) {
+    if (!(std::is_same<T, float>::value ||
+          std::is_same<T, platform::float16>::value ||
+          std::is_same<T, int32_t>::value)) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Unsupported data type (%s) for TensorRT AddConstantLayer, only "
+          "supports float, half or int32_t."));
+    }
     std::unique_ptr<framework::Tensor> tmp_tensor(new framework::Tensor());
     int data_size = data.size();
     tmp_tensor->Resize({data_size});
......@@ -549,12 +562,13 @@ class OpConverter {
     return Add1DConstantLayer(tmp_data, weight_name, scalar);
   }
 
-  nvinfer1::ITensor* Add1DConstantLayer(int32_t data,
+  template <typename T>
+  nvinfer1::ITensor* Add1DConstantLayer(T data,
                                         const std::string& weight_name = "",
                                         bool scalar = false) {
-    std::vector<int> tmp_data;
-    tmp_data.push_back(data);
-    return Add1DConstantLayer(tmp_data, weight_name, scalar);
+    std::vector<T> input_data;
+    input_data.push_back(data);
+    return Add1DConstantLayer(input_data, weight_name, scalar);
   }
 
   // For cases when input is not middle-tensor , but persistable tensor
......
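Reviewer note (not part of the patch): after this refactor, the helper derives the weight's element count directly from the passed `nvinfer1::Dims` via `std::accumulate` over `shape.d`, instead of iterating a `std::vector<int32_t>`. A minimal standalone sketch of that computation; `DimsLike` is a stand-in struct used only so the snippet compiles without the TensorRT headers, and the rank-4 all-ones shape mirrors what the swish converter below builds:

```cpp
#include <functional>
#include <iostream>
#include <numeric>

// Stand-in for nvinfer1::Dims: a rank plus an array of extents.
struct DimsLike {
  int nbDims;
  int d[8];
};

int main() {
  DimsLike shape{4, {1, 1, 1, 1}};  // rank-4, all-ones (broadcastable) shape
  // Same accumulate call the refactored AddConstantLayer uses to size the weight.
  int data_size = std::accumulate(
      shape.d, shape.d + shape.nbDims, 1, std::multiplies<int>());
  std::cout << data_size << "\n";  // prints 1: a single scalar weight
  return 0;
}
```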
......@@ -91,7 +91,7 @@ class SplitOpConverter : public OpConverter {
         start_point += output_lengths[i];
       } else {
         this_len_tensor = avg_len_tensor;
-        auto* i_tensor = Add1DConstantLayer(i);
+        auto* i_tensor = Add1DConstantLayer(static_cast<int>(i));
         start_point_tensor = Prod(i_tensor, avg_len_tensor);
       }
......
......@@ -63,17 +63,22 @@ class SwishOpConverter : public OpConverter {
     nvinfer1::ILayer* layer = nullptr;
     if (engine_->with_dynamic_shape()) {
-#if IS_TRT_VERSION_GE(6000)
-      bool with_fp16 =
-          engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
-      plugin::SwishPluginDynamic* plugin =
-          new plugin::SwishPluginDynamic(beta, with_fp16);
-      layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
-#else
-      PADDLE_THROW(platform::errors::Fatal(
-          "You are running the TRT Dynamic Shape mode, need to confirm that "
-          "your TRT version is no less than 6.0"));
-#endif
+      int32_t rank = input->getDimensions().nbDims;
+      nvinfer1::Dims constant_shape;
+      constant_shape.nbDims = rank;
+      std::fill(constant_shape.d, constant_shape.d + rank, 1);
+      std::vector<float> weight_data{beta};
+      auto* beta_data = AddConstantLayer(weight_data.data(), constant_shape);
+      auto* input_mul_with_beta = Prod(beta_data, input);
+      auto* sigmoid = TRT_ENGINE_ADD_LAYER(engine_,
+                                           Activation,
+                                           *input_mul_with_beta,
+                                           nvinfer1::ActivationType::kSIGMOID);
+      layer = TRT_ENGINE_ADD_LAYER(engine_,
+                                   ElementWise,
+                                   *input,
+                                   *(sigmoid->getOutput(0)),
+                                   nvinfer1::ElementWiseOperation::kPROD);
     } else {
       bool with_fp16 =
           engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
......
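Reviewer note (not part of the patch): the rewritten dynamic-shape branch no longer uses `SwishPluginDynamic`; it composes swish from stock TensorRT layers — a broadcast constant `beta` multiplied into the input, a `kSIGMOID` activation, and an element-wise product with the original input. A minimal scalar sketch of the per-element math this layer graph expresses; `SwishReference` is an illustrative name, not code from the commit:

```cpp
#include <cmath>

// Per-element reference: swish(x) = x * sigmoid(beta * x).
static inline float SwishReference(float x, float beta) {
  const float sigmoid = 1.0f / (1.0f + std::exp(-beta * x));
  return x * sigmoid;
}
```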