diff --git a/paddle/fluid/inference/tensorrt/plugin/lookup_table.cu b/paddle/fluid/inference/tensorrt/plugin/lookup_table.cu index 41886d24aa144f593bc231790d111ea08b116794..31d599bd2a67cdc1b72e7ad1a2a8c7e04029844e 100644 --- a/paddle/fluid/inference/tensorrt/plugin/lookup_table.cu +++ b/paddle/fluid/inference/tensorrt/plugin/lookup_table.cu @@ -98,9 +98,11 @@ LookupTablePluginDynamic::LookupTablePluginDynamic(void const* data, deserialize_value(&data, &length, &mWeightSize); deserialize_value(&data, &length, &mWeightWidth); char const* d = static_cast(data); - cudaMalloc(&mWeightDev, mWeightSize * sizeof(mType)); - cudaMemcpy( - mWeightDev, d, mWeightSize * sizeof(mType), cudaMemcpyHostToDevice); + cudaMalloc(&mWeightDev, mWeightSize * getElementSize(mType)); + cudaMemcpy(mWeightDev, + d, + mWeightSize * getElementSize(mType), + cudaMemcpyHostToDevice); } // IPluginV2DynamicExt Methods