Unverified · Commit b7a86e92 authored by Zhaolong Xing, committed by GitHub

fix dy shape bug in trt7.1 (#26273)

test=develop
Parent c45481d7
@@ -83,7 +83,12 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape, std::string input,
     } else if (shape.size() == 3UL) {
       return nvinfer1::Dims3(shape[0], shape[1], shape[2]);
     }
-    return nvinfer1::Dims4(shape[0], shape[1], 1, 1);
+    nvinfer1::Dims dims;
+    dims.nbDims = shape.size();
+    for (size_t i = 0; i < shape.size(); i++) {
+      dims.d[i] = shape[i];
+    }
+    return dims;
   }
 }
 }  // NOLINT
......
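The hunk above replaces the hard-coded `Dims4(shape[0], shape[1], 1, 1)` fallback with a conversion that copies every dimension, so inputs whose rank is not 3 or 4 keep their full shape under TensorRT 7.1 dynamic shape. Below is a minimal standalone sketch of the same pattern, not the Paddle function itself; it assumes `NvInfer.h` is on the include path, and `ToDims` and the test shape are illustrative names:

```cpp
#include <NvInfer.h>

#include <cassert>
#include <vector>

// Copy an arbitrary-rank shape into nvinfer1::Dims
// (rank must not exceed nvinfer1::Dims::MAX_DIMS, which is 8).
nvinfer1::Dims ToDims(const std::vector<int>& shape) {
  nvinfer1::Dims dims;
  dims.nbDims = static_cast<int>(shape.size());
  for (size_t i = 0; i < shape.size(); i++) {
    dims.d[i] = shape[i];
  }
  return dims;
}

int main() {
  // A 5-D shape is no longer truncated into a 4-D Dims4.
  auto dims = ToDims({2, 3, 4, 5, 6});
  assert(dims.nbDims == 5);
  assert(dims.d[4] == 6);
  return 0;
}
```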
@@ -76,6 +76,16 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions(
   return ret;
 }
 
+template <typename T>
+void EmbEltwiseLayernormPluginDynamic<T>::terminate() {
+  for (auto ptr : embs_gpu_) {
+    if (ptr) cudaFree(ptr);
+  }
+
+  if (bias_gpu_) cudaFree(bias_gpu_);
+  if (scale_gpu_) cudaFree(scale_gpu_);
+}
+
 template <typename T>
 bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
     int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
@@ -153,7 +163,7 @@ int EmbEltwiseLayernormPluginDynamic<T>::enqueue(
   int64_t *emb_ptr_gpu_d =
       emb_ptr_tensor.mutable_data<int64_t>(platform::CUDAPlace(device_id));
 
-  std::vector<int64_t> in_ptr, emb_ptr;
+  std::vector<uintptr_t> in_ptr, emb_ptr;
   for (int i = 0; i < input_num; i++) {
     in_ptr.push_back(reinterpret_cast<uintptr_t>(inputs[i]));
     emb_ptr.push_back(reinterpret_cast<uintptr_t>(embs_gpu_[i]));
......
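In the enqueue hunk, the host-side vectors that collect raw device addresses switch from `int64_t` to `uintptr_t`, matching the `reinterpret_cast` of the pointers before the table is copied to the GPU. A minimal sketch of this pointer-table pattern follows; it assumes a CUDA toolchain, and the buffer count, sizes, and names are illustrative rather than the plugin's actual fields:

```cpp
#include <cuda_runtime.h>

#include <cstdint>
#include <vector>

int main() {
  const int num_inputs = 3;

  // Stand-ins for the per-input device buffers the plugin already owns.
  std::vector<void*> device_buffers(num_inputs);
  for (int i = 0; i < num_inputs; i++) {
    cudaMalloc(&device_buffers[i], 16 * sizeof(float));
  }

  // Pack the raw device addresses into an unsigned integer table on the host.
  std::vector<uintptr_t> ptr_table;
  for (int i = 0; i < num_inputs; i++) {
    ptr_table.push_back(reinterpret_cast<uintptr_t>(device_buffers[i]));
  }

  // Copy the table itself to device memory; a kernel can then read
  // ptr_table_gpu[i] and cast it back to a typed pointer.
  uintptr_t* ptr_table_gpu = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&ptr_table_gpu),
             num_inputs * sizeof(uintptr_t));
  cudaMemcpy(ptr_table_gpu, ptr_table.data(), num_inputs * sizeof(uintptr_t),
             cudaMemcpyHostToDevice);

  cudaFree(ptr_table_gpu);
  for (auto* p : device_buffers) cudaFree(p);
  return 0;
}
```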
@@ -81,9 +81,13 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
   }
 
   nvinfer1::IPluginV2DynamicExt* clone() const override {
-    return new EmbEltwiseLayernormPluginDynamic(
+    auto ptr = new EmbEltwiseLayernormPluginDynamic(
         embs_, bias_, scale_, emb_sizes_, bias_size_, scale_size_, hidden_size_,
         eps_);
+    ptr->embs_gpu_ = embs_gpu_;
+    ptr->bias_gpu_ = bias_gpu_;
+    ptr->scale_gpu_ = scale_gpu_;
+    return ptr;
   }
 
   const char* getPluginType() const override {
@@ -111,6 +115,7 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
     return sum_num;
   }
 
+  void terminate() override;
   void serialize(void* buffer) const override {
     // SerializeValue(&buffer, with_fp16_);
     SerializeValue(&buffer, emb_sizes_);
......
@@ -80,6 +80,12 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
 
 #if IS_TRT_VERSION_GE(6000)
 
+void PReluPluginDynamic::terminate() {
+  if (p_gpu_weight_) {
+    cudaFree(p_gpu_weight_);
+  }
+}
+
 int PReluPluginDynamic::initialize() {
   cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size());
   cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float),
......
@@ -102,12 +102,15 @@ class PReluPluginDynamic : public DynamicPluginTensorRT {
   }
   ~PReluPluginDynamic() { cudaFree(p_gpu_weight_); }
   nvinfer1::IPluginV2DynamicExt* clone() const override {
-    return new PReluPluginDynamic(weight_.data(), weight_.size(), mode_);
+    auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_);
+    ptr->p_gpu_weight_ = p_gpu_weight_;
+    return ptr;
   }
 
   const char* getPluginType() const override { return "prelu_plugin"; }
   int getNbOutputs() const override { return 1; }
   int initialize() override;
+  void terminate() override;
 
   size_t getSerializationSize() const override;
   void serialize(void* buffer) const override;
......
@@ -51,8 +51,11 @@ class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT {
   }
 
   nvinfer1::IPluginV2DynamicExt* clone() const override {
-    return new SkipLayerNormPluginDynamic(
+    auto ptr = new SkipLayerNormPluginDynamic(
         bias_.data(), scale_.data(), bias_size_, scale_size_, eps_, ban_fp16_);
+    ptr->bias_gpu_ = bias_gpu_;
+    ptr->scale_gpu_ = scale_gpu_;
+    return ptr;
   }
 
   const char* getPluginType() const override { return "skip_layernorm_plugin"; }
......
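The plugin changes above all follow the same lifecycle: initialize() uploads the host weights to the GPU, the new terminate() frees the device copy, and clone() hands the already-uploaded device pointers to the new instance so the clone is usable without another initialize() call. A stripped-down sketch of that pattern, using a hypothetical WeightHolder class rather than the Paddle plugins themselves:

```cpp
#include <cuda_runtime.h>

#include <utility>
#include <vector>

class WeightHolder {
 public:
  explicit WeightHolder(std::vector<float> w) : weight_(std::move(w)) {}

  // Upload the host weights to the GPU (mirrors the plugins' initialize()).
  int initialize() {
    cudaMalloc(reinterpret_cast<void**>(&gpu_weight_),
               weight_.size() * sizeof(float));
    cudaMemcpy(gpu_weight_, weight_.data(), weight_.size() * sizeof(float),
               cudaMemcpyHostToDevice);
    return 0;
  }

  // Release the device copy (mirrors the plugins' terminate()).
  void terminate() {
    if (gpu_weight_) {
      cudaFree(gpu_weight_);
      gpu_weight_ = nullptr;
    }
  }

  // The clone shares the already-uploaded buffer instead of re-allocating it.
  WeightHolder* clone() const {
    auto* ptr = new WeightHolder(weight_);
    ptr->gpu_weight_ = gpu_weight_;
    return ptr;
  }

 private:
  std::vector<float> weight_;
  float* gpu_weight_ = nullptr;
};
```

Because the clone shares the device buffer rather than copying it, a production implementation also has to coordinate ownership so that terminate() does not free the same pointer from both the original and the clone.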