未验证 提交 6cd7609a 编写于 作者: W Wangzheee 提交者: GitHub

Fix GPU memory allocation: use phi::memory_utils::Alloc instead of raw cudaMalloc (#53721)

上级 13cdaab6
......@@ -18,6 +18,9 @@
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -381,18 +384,16 @@ int FusedTokenPrunePluginDynamic::enqueue(
// 3. compute new pos id
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceScan::ExclusiveSum(d_temp_storage,
temp_storage_bytes,
pruned_token_lengths_,
output3,
B + 1);
cub::DeviceScan::ExclusiveSum(
NULL, temp_storage_bytes, pruned_token_lengths_, output3, B + 1);
// Allocate temporary storage
cudaMalloc(&d_temp_storage, temp_storage_bytes);
platform::CUDAPlace place(platform::GetCurrentDeviceId());
auto d_temp_storage = phi::memory_utils::Alloc(place, temp_storage_bytes);
// Run exclusive prefix sum
cub::DeviceScan::ExclusiveSum(d_temp_storage,
cub::DeviceScan::ExclusiveSum(d_temp_storage->ptr(),
temp_storage_bytes,
pruned_token_lengths_,
output3,
......
......@@ -12,8 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "cub/cub.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h"
namespace paddle {
namespace inference {
......@@ -178,16 +182,15 @@ int TransformerInputConvertPlugin::enqueue(
const int32_t HiddenSize = input0_desc.dims.d[2]; // hidden size
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceScan::ExclusiveSum(
d_temp_storage, temp_storage_bytes, input1, output2, B + 1);
NULL, temp_storage_bytes, input1, output2, B + 1);
// Allocate temporary storage
cudaMalloc(&d_temp_storage, temp_storage_bytes);
platform::CUDAPlace place(platform::GetCurrentDeviceId());
auto d_temp_storage = phi::memory_utils::Alloc(place, temp_storage_bytes);
// Run exclusive prefix sum
cub::DeviceScan::ExclusiveSum(
d_temp_storage, temp_storage_bytes, input1, output2, B + 1);
d_temp_storage->ptr(), temp_storage_bytes, input1, output2, B + 1);
const int32_t vector_length = HiddenSize;
int32_t num_threads;
if (vector_length < 1024) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册