Commit 6a87f61b authored by: R ReeseWang

add trt stack op, test=develop

Parent f1b1c753
@@ -1035,4 +1035,5 @@ USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
USE_TRT_CONVERTER(skip_layernorm);
USE_TRT_CONVERTER(slice);
USE_TRT_CONVERTER(scale);
USE_TRT_CONVERTER(stack);
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"
#include <cassert>
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* Stack converter from fluid to TensorRT.
*/
class StackOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fluid stack op to tensorrt stack layer";
framework::OpDesc op_desc(op, nullptr);
auto input = op_desc.Input("X");
int input_num = input.size();
nvinfer1::ITensor** inputs =
(nvinfer1::ITensor**)malloc(input_num * sizeof(nvinfer1::ITensor*));
for (int i = 0; i < input_num; ++i) {
inputs[i] = engine_->GetITensor(input[i]);
}
int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));
if (axis < 0) {
axis = axis + inputs[0]->getDimensions().nbDims + 1;
}
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
plugin::StackPluginDynamic* plugin =
new plugin::StackPluginDynamic(axis, input_num);
layer = engine_->AddPluginV2(inputs, input_num, plugin);
assert(layer != nullptr);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running in TRT dynamic shape mode; please confirm that "
"your TRT version is no less than 6.0."));
#endif
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static"
"shape mode, which is not supported for the time being.\n"
"You can use the config.SetTRTDynamicShapeInfo(...) interface"
" to set the shape information to run the dynamic shape mode."));
}
auto output_name = op_desc.Output("Y").front();
RreplenishLayerAndOutput(layer, "stack", {output_name}, test_mode);
free(inputs);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(stack, StackOpConverter);
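For context (not part of this commit): stacking input_num tensors along axis inserts a new dimension of size input_num at that position, which is the shape the plugin registered above produces. A minimal, illustrative sketch of that rule, using the hypothetical helper name StackOutputShape:

#include <vector>

// Illustrative only: given the shared input shape and a non-negative axis,
// stacking num_inputs tensors inserts a new dimension of size num_inputs
// at position `axis`.
std::vector<int> StackOutputShape(const std::vector<int>& input_shape,
                                  int axis, int num_inputs) {
  std::vector<int> out(input_shape);
  out.insert(out.begin() + axis, num_inputs);
  return out;
}
// Example: stacking 12 tensors of shape {batch, seq, seq} at axis = 1
// yields {batch, 12, seq, seq}.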
@@ -55,8 +55,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"fc",
"relu6",
"concat"};
std::unordered_set<std::string> teller_set{
"mul",
std::unordered_set<std::string> teller_set{"mul",
"conv2d",
"pool2d",
"relu",
@@ -82,7 +81,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"gelu",
"layer_norm",
"scale",
};
"slice",
"stack"};
};
bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
size_t StackPluginDynamic::getSerializationSize() const { return 0; }
void StackPluginDynamic::serialize(void* buffer) const {}
nvinfer1::DimsExprs StackPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) {
nvinfer1::DimsExprs output(inputs[0]);
output.nbDims = inputs[0].nbDims + 1;
for (int i = inputs[0].nbDims; i > axis_; --i) {
output.d[i] = inputs[0].d[i - 1];
}
output.d[axis_] = expr_builder.constant(nb_inputs);
return output;
}
bool StackPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
int nb_outputs) {
PADDLE_ENFORCE_NOT_NULL(
in_out, platform::errors::InvalidArgument(
"The input of stack plugin should not be nullptr."));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc& in = in_out[pos];
if (pos == 0) {
#ifdef SUPPORTS_CUDA_FP16
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
}
const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType StackPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
"The index should be equal to 0"));
return input_types[0];
}
template <typename T>
__global__ void StackKernel(const T* const* input, T* output, int num_stack,
int base_unit) {
int stack_id = blockIdx.x;
int lead_id = blockIdx.y;
for (int i = threadIdx.x; i < base_unit; i += blockDim.x) {
output[lead_id * num_stack * base_unit + stack_id * base_unit + i] =
input[stack_id][lead_id * base_unit + i];
}
}
int StackPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs, void* const* outputs,
void* workspace, cudaStream_t stream) {
auto input_dims = input_desc[0].dims; // (batch, seq, seq)
auto out_dims = output_desc[0].dims; // (batch, num_head, seq, seq)
auto out_num_dims = out_dims.nbDims;
int base_unit = 1;
for (int i = axis_ + 1; i < out_num_dims; ++i) {
PADDLE_ENFORCE_GT(out_dims.d[i], 0,
platform::errors::InvalidArgument(
"Input dimensions should be greater than 0"));
base_unit *= out_dims.d[i];
}
int lead_unit = 1;
for (int i = 0; i < axis_; ++i) {
PADDLE_ENFORCE_GT(out_dims.d[i], 0,
platform::errors::InvalidArgument(
"Input dimensions should be greater than 0"));
lead_unit *= out_dims.d[i];
}
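// Copy the host-side array of per-input device pointers to GPU memory so
// the kernel can index the inputs by stack_id; out_dims.d[axis_] equals
// the number of stacked inputs.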
cudaMemcpyAsync(reinterpret_cast<void*>(in_ptr_gpu_),
reinterpret_cast<const void* const>(inputs),
sizeof(void*) * out_dims.d[axis_], cudaMemcpyHostToDevice,
stream);
int num_stacks = out_dims.d[axis_];
dim3 num_blocks(num_stacks, lead_unit);
int num_threads = 256;
auto infer_type = input_desc[0].type;
if (infer_type == nvinfer1::DataType::kFLOAT) {
float* output = static_cast<float*>(outputs[0]);
StackKernel<float><<<num_blocks, num_threads, 0, stream>>>(
reinterpret_cast<const float* const*>(in_ptr_gpu_), output, num_stacks,
base_unit);
} else if (infer_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
__half* output = static_cast<__half*>(outputs[0]);
StackKernel<__half><<<num_blocks, num_threads, 0, stream>>>(
reinterpret_cast<const __half* const*>(in_ptr_gpu_), output, num_stacks,
base_unit);
#else
PADDLE_THROW(platform::errors::Fatal(
"The cuda archs you specific should greater than 600."));
#endif
} else {
PADDLE_THROW(
platform::errors::Fatal("The Stack TRT Plugin's input type only "
"support float or half currently."));
}
return cudaGetLastError() != cudaSuccess;
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
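For reference, a CPU sketch of the indexing scheme that StackKernel implements (illustrative only; the name StackCPUReference is hypothetical): base_unit is the product of the output dimensions after axis_, lead_unit is the product of the dimensions before it, and the number of stacked inputs is out_dims.d[axis_].

template <typename T>
void StackCPUReference(const T* const* inputs, T* output, int lead_unit,
                       int num_stack, int base_unit) {
  for (int lead_id = 0; lead_id < lead_unit; ++lead_id) {
    for (int stack_id = 0; stack_id < num_stack; ++stack_id) {
      for (int i = 0; i < base_unit; ++i) {
        // Mirrors StackKernel: output element (lead_id, stack_id, i) comes
        // from input `stack_id` at flat offset lead_id * base_unit + i.
        output[(lead_id * num_stack + stack_id) * base_unit + i] =
            inputs[stack_id][lead_id * base_unit + i];
      }
    }
  }
}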
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class StackPluginDynamic : public DynamicPluginTensorRT {
public:
StackPluginDynamic(int axis, int num_stack)
: axis_(axis), num_stack_(num_stack) {
int device_id;
cudaGetDevice(&device_id);
in_ptr_tensor_.Resize({num_stack});
in_ptr_gpu_ =
in_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id));
}
StackPluginDynamic(void const* serialData, size_t serialLength) {}
~StackPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* clone() const override {
return new StackPluginDynamic(axis_, num_stack_);
}
const char* getPluginType() const override { return "stack_plugin"; }
int getNbOutputs() const override { return 1; }
int initialize() override { return 0; }
size_t getSerializationSize() const override;
void serialize(void* buffer) const override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
void destroy() override { delete this; }
private:
int axis_;
int num_stack_;
framework::Tensor in_ptr_tensor_;
int64_t* in_ptr_gpu_;
};
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle