未验证 提交 0fd70d71 编写于 作者: W Wangzheee 提交者: GitHub

fix_matmul_op_int8_plugin (#37525)

上级 2a905f6b
@@ -299,13 +299,13 @@ void MatmulPlugin::configurePlugin(const nvinfer1::PluginTensorDesc* inputs,
 matmulDesc_, CUBLASLT_MATMUL_DESC_POINTER_MODE, &matmul_model,
 sizeof(matmul_model)));
-float alpha_tem[n_];
+std::vector<float> alpha_tem(n_, 0);
 for (int i = 0; i < n_; i++) {
 alpha_tem[i] = alpha_ * inscale_0 * inscale_1 / outscale;
 }
 PADDLE_ENFORCE_CUDA_SUCCESS(
 cudaMalloc((void**)&alpha_scale_, n_ * sizeof(float)));
-cudaMemcpyAsync(alpha_scale_, alpha_tem, n_ * sizeof(float),
+cudaMemcpyAsync(alpha_scale_, &alpha_tem[0], n_ * sizeof(float),
 cudaMemcpyHostToDevice);
 float zero_tem = zero;
 PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -624,13 +624,13 @@ void MatmulPluginDynamic::configurePlugin(
 sizeof(int8_t) * ((m_max + 32 - 1) / 32 * 32) / 32 * ldctransform));
 if (type_ == nvinfer1::DataType::kINT8) {
-float alpha_tem[n_max];
+std::vector<float> alpha_tem(n_max, 0);
 for (int i = 0; i < n_max; i++) {
 alpha_tem[i] = alpha_ * inscale_0 * inscale_1 / outscale;
 }
 PADDLE_ENFORCE_CUDA_SUCCESS(
 cudaMalloc((void**)&alpha_scale_, n_max * sizeof(float)));
-cudaMemcpyAsync(alpha_scale_, alpha_tem, n_max * sizeof(float),
+cudaMemcpyAsync(alpha_scale_, &alpha_tem[0], n_max * sizeof(float),
 cudaMemcpyHostToDevice);
 float zero_tem = zero;
 PADDLE_ENFORCE_CUDA_SUCCESS(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册