未验证 提交 6cd7609a 编写于 作者: W Wangzheee 提交者: GitHub

fix gpu mem alloc: use phi::memory_utils::Alloc (#53721)

上级 13cdaab6
...@@ -18,6 +18,9 @@ ...@@ -18,6 +18,9 @@
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
...@@ -381,18 +384,16 @@ int FusedTokenPrunePluginDynamic::enqueue( ...@@ -381,18 +384,16 @@ int FusedTokenPrunePluginDynamic::enqueue(
// 3. compute new pos id // 3. compute new pos id
// Determine temporary device storage requirements // Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0; size_t temp_storage_bytes = 0;
cub::DeviceScan::ExclusiveSum(d_temp_storage, cub::DeviceScan::ExclusiveSum(
temp_storage_bytes, NULL, temp_storage_bytes, pruned_token_lengths_, output3, B + 1);
pruned_token_lengths_,
output3,
B + 1);
// Allocate temporary storage // Allocate temporary storage
cudaMalloc(&d_temp_storage, temp_storage_bytes);
platform::CUDAPlace place(platform::GetCurrentDeviceId());
auto d_temp_storage = phi::memory_utils::Alloc(place, temp_storage_bytes);
// Run exclusive prefix sum // Run exclusive prefix sum
cub::DeviceScan::ExclusiveSum(d_temp_storage, cub::DeviceScan::ExclusiveSum(d_temp_storage->ptr(),
temp_storage_bytes, temp_storage_bytes,
pruned_token_lengths_, pruned_token_lengths_,
output3, output3,
......
...@@ -12,8 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "cub/cub.cuh" #include "cub/cub.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -178,16 +182,15 @@ int TransformerInputConvertPlugin::enqueue( ...@@ -178,16 +182,15 @@ int TransformerInputConvertPlugin::enqueue(
const int32_t HiddenSize = input0_desc.dims.d[2]; // hidden size const int32_t HiddenSize = input0_desc.dims.d[2]; // hidden size
// Determine temporary device storage requirements // Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0; size_t temp_storage_bytes = 0;
cub::DeviceScan::ExclusiveSum( cub::DeviceScan::ExclusiveSum(
d_temp_storage, temp_storage_bytes, input1, output2, B + 1); NULL, temp_storage_bytes, input1, output2, B + 1);
// Allocate temporary storage // Allocate temporary storage
cudaMalloc(&d_temp_storage, temp_storage_bytes); platform::CUDAPlace place(platform::GetCurrentDeviceId());
auto d_temp_storage = phi::memory_utils::Alloc(place, temp_storage_bytes);
// Run exclusive prefix sum // Run exclusive prefix sum
cub::DeviceScan::ExclusiveSum( cub::DeviceScan::ExclusiveSum(
d_temp_storage, temp_storage_bytes, input1, output2, B + 1); d_temp_storage->ptr(), temp_storage_bytes, input1, output2, B + 1);
const int32_t vector_length = HiddenSize; const int32_t vector_length = HiddenSize;
int32_t num_threads; int32_t num_threads;
if (vector_length < 1024) { if (vector_length < 1024) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册