Unverified · Commit cc2059a0 · authored by jiangfan06, committed by GitHub

[XPU] Fix the topk and set_value ops, which use temporary tensors, to avoid memory overlaps during multi-stream inference (#54851)
Parent c6bd9fb8
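The kernel changes below share one pattern: temporaries that used to be materialized as DenseTensors (or obtained through the old Alloc_l3_or_gm<Context, T> helper) are now taken from an xpu::ctx_guard bound to the kernel's own XPU context, so scratch memory stays private to the op's stream and is released when the guard goes out of scope. A minimal sketch of that pattern, assuming the phi XPU kernel environment visible in the hunks (the xpu:: XDNN API and the PADDLE_ENFORCE_* macros); UseScratchBuffer is a hypothetical name, not code from this commit:

// Sketch only: the scratch-buffer pattern the kernels below move to.
template <typename Context>
void UseScratchBuffer(const Context& dev_ctx, int64_t n) {
  // RAII guard tied to this kernel's XPU context; every buffer taken from
  // it lives exactly as long as the guard.
  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());

  // alloc_l3_or_gm tries on-chip L3 memory first and falls back to global
  // memory; no DenseTensor or generic allocator is involved.
  float* tmp = RAII_GUARD.alloc_l3_or_gm<float>(n);
  PADDLE_ENFORCE_XDNN_NOT_NULL(tmp);

  int r = xpu::constant<float>(dev_ctx.x_context(), tmp, n, 1.0f);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");

  // ... use tmp with further xpu:: calls before the guard is destroyed.
}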
@@ -2621,7 +2621,9 @@ Scope* OperatorWithKernel::PrepareData(
       if (kernel_type_for_var.backend() == phi::Backend::GPU ||
           kernel_type_for_var.backend() == phi::Backend::GPUDNN ||
           new_expected_kernel_key->backend() == phi::Backend::GPU ||
-          new_expected_kernel_key->backend() == phi::Backend::GPUDNN) {
+          new_expected_kernel_key->backend() == phi::Backend::GPUDNN ||
+          kernel_type_for_var.backend() == phi::Backend::XPU ||
+          new_expected_kernel_key->backend() == phi::Backend::XPU) {
         new_scope = TryCreateTransferScope(
             kernel_type_for_var, *new_expected_kernel_key, &scope);
         enable_cache_transfer_scope_ = true;
@@ -2629,7 +2631,9 @@ Scope* OperatorWithKernel::PrepareData(
       } else if (kernel_type_for_var.backend() == phi::Backend::GPU ||
                  kernel_type_for_var.backend() == phi::Backend::GPUDNN ||
                  expected_kernel_key.backend() == phi::Backend::GPU ||
-                 expected_kernel_key.backend() == phi::Backend::GPUDNN) {
+                 expected_kernel_key.backend() == phi::Backend::GPUDNN ||
+                 kernel_type_for_var.backend() == phi::Backend::XPU ||
+                 expected_kernel_key.backend() == phi::Backend::XPU) {
         new_scope = TryCreateTransferScope(
             kernel_type_for_var, expected_kernel_key, &scope);
         enable_cache_transfer_scope_ = true;
@@ -780,7 +780,8 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16,
                    phi::DataType::INT16,
-                   phi::DataType::INT32})},
+                   phi::DataType::INT32,
+                   phi::DataType::INT64})},
     {"strided_slice_grad",
      XPUKernelSet({phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16,
@@ -29,19 +29,18 @@ namespace phi {
 template <typename T, typename XPUType>
 void XPUElementwise(const XPUContext& dev_ctx,
-                    const DenseTensor& x,
-                    const DenseTensor& y,
+                    const T* x_data,
+                    const DDim& x_dims,
+                    const T* y_data,
+                    const DDim& y_dims,
                     int axis,
-                    DenseTensor* z,
+                    T* z_data,
                     std::function<int(xpu::Context*,
                                       const XPUType*,
                                       const XPUType*,
                                       XPUType*,
                                       const std::vector<int>&,
                                       const std::vector<int>&)> func) {
-  dev_ctx.template Alloc<T>(z);
-  auto x_dims = x.dims();
-  auto y_dims = y.dims();
   int max_dim = std::max(x_dims.size(), y_dims.size());
   axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
@@ -78,9 +77,6 @@ void XPUElementwise(const XPUContext& dev_ctx,
       y_dims_vec[i + axis] = y_dims[i];
     }
   }
 
-  const T* x_data = x.data<T>();
-  const T* y_data = y.data<T>();
-  T* z_data = z->data<T>();
 
   int ret = xpu::SUCCESS;
@@ -104,6 +100,30 @@ void XPUElementwise(const XPUContext& dev_ctx,
   PADDLE_ENFORCE_XDNN_SUCCESS(ret, "elementwise");
 }
 
+template <typename T, typename XPUType>
+void XPUElementwise(const XPUContext& dev_ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& y,
+                    int axis,
+                    DenseTensor* z,
+                    std::function<int(xpu::Context*,
+                                      const XPUType*,
+                                      const XPUType*,
+                                      XPUType*,
+                                      const std::vector<int>&,
+                                      const std::vector<int>&)> func) {
+  dev_ctx.template Alloc<T>(z);
+  const DDim& x_dims = x.dims();
+  const DDim& y_dims = y.dims();
+  const T* x_data = x.data<T>();
+  const T* y_data = y.data<T>();
+  T* z_data = z->data<T>();
+  XPUElementwise<T, XPUType>(
+      dev_ctx, x_data, x_dims, y_data, y_dims, axis, z_data, func);
+}
+
 template <typename T, typename XPUType>
 void XPUElementwiseGrad(const XPUContext& dev_ctx,
                         const DenseTensor& x,
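The new pointer-based overload lets a caller that already holds device buffers (for example the reworked set_value kernel later in this diff) broadcast without wrapping its scratch memory in DenseTensors. A hedged sketch of such a call site, modeled on the set_value hunk below; BroadcastAddRaw is a hypothetical helper name:

// Sketch: invoking the raw-pointer XPUElementwise overload directly.
template <typename T, typename XPUType>
void BroadcastAddRaw(const phi::XPUContext& dev_ctx,
                     const T* x_data, const phi::DDim& x_dims,
                     const T* y_data, const phi::DDim& y_dims,
                     T* z_data) {
  auto f = [](xpu::Context* ctx,
              const XPUType* x,
              const XPUType* y,
              XPUType* z,
              const std::vector<int>& xshape,
              const std::vector<int>& yshape) {
    return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
  };
  // axis == -1 lets the helper infer the broadcast axis from the ranks;
  // the result is written straight into z_data, with no allocation inside.
  phi::XPUElementwise<T, XPUType>(
      dev_ctx, x_data, x_dims, y_data, y_dims, /*axis=*/-1, z_data, f);
}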
@@ -38,29 +38,53 @@ void InstanceNormKernel(const Context& dev_ctx,
   dev_ctx.template Alloc<T>(y);
   dev_ctx.template Alloc<float>(saved_mean);
   dev_ctx.template Alloc<float>(saved_var);
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
 
   // scale
   const auto scale_ptr = scale.get_ptr();
   const float* scale_data_fp32 = nullptr;
-  DenseTensor scale_data;
   if (scale_ptr == nullptr) {
-    scale_data.Resize({c});
-    dev_ctx.template Alloc<float>(&scale_data);
-    phi::funcs::set_constant(dev_ctx, &scale_data, static_cast<float>(1));
-    scale_data_fp32 = scale_data.data<float>();
+    float* scale_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(c);
+    int r = xpu::constant<float>(dev_ctx.x_context(), scale_data_temp, c, 1.f);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+    scale_data_fp32 = scale_data_temp;
+  } else if (scale_ptr->dtype() ==
+             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
+    float* scale_data_temp =
+        RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
+    int r = xpu::cast<XPUType, float>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
+        scale_data_temp,
+        scale_ptr->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    scale_data_fp32 = scale_data_temp;
   } else {
     // no need to cast
     scale_data_fp32 = scale_ptr->data<float>();
   }
 
   // bias
   const float* bias_data_fp32 = nullptr;
   const auto* bias_ptr = bias.get_ptr();
-  DenseTensor bias_data;
   if (bias_ptr == nullptr) {
-    bias_data.Resize({c});
-    dev_ctx.template Alloc<float>(&bias_data);
-    phi::funcs::set_constant(dev_ctx, &bias_data, static_cast<float>(0));
-    bias_data_fp32 = bias_data.data<float>();
+    float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(c);
+    int r = xpu::constant<float>(dev_ctx.x_context(), bias_data_temp, c, 1.f);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+    bias_data_fp32 = bias_data_temp;
+  } else if (bias_ptr->dtype() ==
+             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
+    float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
+    int r = xpu::cast<XPUType, float>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(bias_ptr->data<T>()),
+        bias_data_temp,
+        bias_ptr->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    bias_data_fp32 = bias_data_temp;
   } else {
+    // no need to cast
     bias_data_fp32 = bias_ptr->data<float>();
   }
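Besides moving the scale/bias temporaries onto the ctx_guard, the kernel now accepts float16 scale and bias by casting them to float32 on device before the instance-norm call. The three branches (parameter absent, fp16 parameter, fp32 parameter) are duplicated for scale and bias; a hypothetical helper, sketched only from the calls that appear in this hunk, could factor them out:

// Hypothetical refactoring sketch, not code from this commit: always yield
// a float32 device pointer that stays valid as long as `guard` does.
template <typename T, typename XPUType, typename Context>
const float* PrepareFp32Param(const Context& dev_ctx,
                              xpu::ctx_guard& guard,
                              const phi::DenseTensor* param,  // may be null
                              int64_t n,
                              float default_value) {
  if (param == nullptr) {
    // Parameter not provided: fill a scratch buffer with the default.
    float* tmp = guard.alloc_l3_or_gm<float>(n);
    int r = xpu::constant<float>(dev_ctx.x_context(), tmp, n, default_value);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
    return tmp;
  }
  if (param->dtype() == phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
    // fp16 parameter: cast to fp32 into a scratch buffer.
    float* tmp = guard.alloc_l3_or_gm<float>(param->numel());
    int r = xpu::cast<XPUType, float>(
        dev_ctx.x_context(),
        reinterpret_cast<const XPUType*>(param->data<T>()),
        tmp,
        param->numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
    return tmp;
  }
  // Already fp32: no cast needed.
  return param->data<float>();
}

Returning a raw pointer is only safe because the guard outlives the caller's use of it, which matches how RAII_GUARD is scoped in the kernel body above.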
@@ -83,7 +83,7 @@ void ScatterKernel(const Context &ctx,
       static_cast<int>(phi::product(phi::slice_ddim(x_dims, 1, x_dims.size())));
 
   DenseTensor indices_cpu(index.type());
-  phi::Copy(ctx, index, phi::CPUPlace(), false, &indices_cpu);
+  phi::Copy(ctx, index, phi::CPUPlace(), true, &indices_cpu);
 
   int r = 0;
   if (index_type == phi::DataType::INT32) {
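The only change in the scatter kernel is the blocking flag of phi::Copy. The index tensor is copied to the CPU and then consumed on the host, so with multiple streams in flight the copy has to be synchronous; a non-blocking copy could let the host read indices_cpu before the device-to-host transfer finishes. Illustration of the intent (same call as in the hunk, comment added):

// The fourth argument of phi::Copy is the blocking flag; `true` makes the
// device-to-host copy finish before the host-side code reads indices_cpu.
DenseTensor indices_cpu(index.type());
phi::Copy(ctx, index, phi::CPUPlace(), /*blocking=*/true, &indices_cpu);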
@@ -73,7 +73,8 @@ inline void CheckIsDimsMatch(const DDim& first, const DDim& second) {
 template <typename T, typename Context, size_t RANK>
 void SetValueImpl(const Context& dev_ctx,
                   const DenseTensor& in,
-                  const DenseTensor& value,
+                  const T* value_data,
+                  const DDim& value_dims,
                   const IntArray& starts,
                   const IntArray& ends,
                   const IntArray& steps,
@@ -139,8 +140,9 @@ void SetValueImpl(const Context& dev_ctx,
                 in.numel());
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
 
-  DenseTensor slice_tensor =
-      Empty<T>(dev_ctx, IntArray{slice_dims.Get(), slice_dims.size()});
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  int64_t slice_numels = phi::product(slice_dims);
+  XPUType* slice_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(slice_numels);
 
   int in_size = in_dims.size();
   std::vector<int> starts_indices(in_size, 0);
@@ -186,17 +188,14 @@ void SetValueImpl(const Context& dev_ctx,
   auto slice_shape = phi::vectorize<int>(slice_dims);
   r = xpu::strided_slice(dev_ctx.x_context(),
                          reinterpret_cast<const XPUType*>(out->data<T>()),
-                         reinterpret_cast<XPUType*>(slice_tensor.data<T>()),
+                         slice_data,
                          out_shape,
                          starts_indices,
                          ends_indices,
                          strides_indices);
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice");
 
-  r = xpu::constant(dev_ctx.x_context(),
-                    reinterpret_cast<XPUType*>(slice_tensor.data<T>()),
-                    slice_tensor.numel(),
-                    XPUType(0));
+  r = xpu::constant(dev_ctx.x_context(), slice_data, slice_numels, XPUType(0));
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
 
   // Step 2: Set a tensor with the same shape as out tensor. And its data at
@@ -216,8 +215,7 @@ void SetValueImpl(const Context& dev_ctx,
   // If do broadcasting on Tensor with shape [3] and [3], the result's shape
   // is [3], which is right.
 
-  slice_tensor.Resize(slice_dims_for_assign);
-  CheckIsDimsMatch(slice_dims_for_assign, value.dims());
+  CheckIsDimsMatch(slice_dims_for_assign, value_dims);
   // XPUElementwise can do broadcasting
   auto f = [](xpu::Context* ctx,
               const XPUType* x,
@@ -227,16 +225,20 @@ void SetValueImpl(const Context& dev_ctx,
               const std::vector<int>& yshape) {
     return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
   };
-  XPUElementwise<T, XPUType>(
-      dev_ctx, slice_tensor, value, -1, &slice_tensor, f);
-
-  slice_tensor.Resize(slice_dims);
+  XPUElementwise<T, XPUType>(dev_ctx,
+                             reinterpret_cast<const T*>(slice_data),
+                             slice_dims_for_assign,
+                             value_data,
+                             value_dims,
+                             -1,
+                             reinterpret_cast<T*>(slice_data),
+                             f);
 
   // - Step 2.2 If stride < 0, flip the slice_tensor.
   if (need_flip) {
     r = xpu::flip(dev_ctx.x_context(),
-                  reinterpret_cast<const XPUType*>(slice_tensor.data<T>()),
-                  reinterpret_cast<XPUType*>(slice_tensor.data<T>()),
+                  reinterpret_cast<const XPUType*>(slice_data),
+                  slice_data,
                   slice_shape,
                   flip_axis);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "flip");
@@ -244,7 +246,7 @@ void SetValueImpl(const Context& dev_ctx,
   // Step 3: Set out tensor with value
   r = xpu::strided_slice_view_update(
       dev_ctx.x_context(),
-      reinterpret_cast<const XPUType*>(slice_tensor.data<T>()),
+      reinterpret_cast<const XPUType*>(slice_data),
       reinterpret_cast<XPUType*>(out->data<T>()),
       slice_shape,
       out_shape,
@@ -255,9 +257,10 @@ void SetValueImpl(const Context& dev_ctx,
 }
 
 template <typename T, typename Context>
-void SetTensorValueKernel(const Context& dev_ctx,
-                          const DenseTensor& x,
-                          const DenseTensor& value,
+void SetValueKernelImpl(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const T* value_data,
+                        const DDim& value_dims,
                         const IntArray& starts,
                         const IntArray& ends,
                         const IntArray& steps,
@@ -272,7 +275,8 @@ void SetTensorValueKernel(const Context& dev_ctx,
     case 1:
      SetValueImpl<T, Context, 1>(dev_ctx,
                                   x,
-                                  value,
+                                  value_data,
+                                  value_dims,
                                   starts,
                                   ends,
                                   steps,
@@ -284,7 +288,8 @@ void SetTensorValueKernel(const Context& dev_ctx,
     case 2:
      SetValueImpl<T, Context, 2>(dev_ctx,
                                   x,
-                                  value,
+                                  value_data,
+                                  value_dims,
                                   starts,
                                   ends,
                                   steps,
@@ -296,7 +301,8 @@ void SetTensorValueKernel(const Context& dev_ctx,
     case 3:
      SetValueImpl<T, Context, 3>(dev_ctx,
                                   x,
-                                  value,
+                                  value_data,
+                                  value_dims,
                                   starts,
                                   ends,
                                   steps,
@@ -308,7 +314,8 @@ void SetTensorValueKernel(const Context& dev_ctx,
     case 4:
      SetValueImpl<T, Context, 4>(dev_ctx,
                                   x,
-                                  value,
+                                  value_data,
+                                  value_dims,
                                   starts,
                                   ends,
                                   steps,
@@ -320,7 +327,8 @@ void SetTensorValueKernel(const Context& dev_ctx,
     case 5:
      SetValueImpl<T, Context, 5>(dev_ctx,
                                   x,
-                                  value,
+                                  value_data,
+                                  value_dims,
                                   starts,
                                   ends,
                                   steps,
@@ -332,7 +340,8 @@ void SetTensorValueKernel(const Context& dev_ctx,
     case 6:
      SetValueImpl<T, Context, 6>(dev_ctx,
                                   x,
-                                  value,
+                                  value_data,
+                                  value_dims,
                                   starts,
                                   ends,
                                   steps,
@@ -347,6 +356,30 @@ void SetTensorValueKernel(const Context& dev_ctx,
   }
 }
 
+template <typename T, typename Context>
+void SetTensorValueKernel(const Context& dev_ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& value,
+                          const IntArray& starts,
+                          const IntArray& ends,
+                          const IntArray& steps,
+                          const std::vector<int64_t>& axes,
+                          const std::vector<int64_t>& decrease_axes,
+                          const std::vector<int64_t>& none_axes,
+                          DenseTensor* out) {
+  SetValueKernelImpl<T, Context>(dev_ctx,
+                                 x,
+                                 value.data<T>(),
+                                 value.dims(),
+                                 starts,
+                                 ends,
+                                 steps,
+                                 axes,
+                                 decrease_axes,
+                                 none_axes,
+                                 out);
+}
+
 template <typename T, typename Context>
 void SetValueKernel(const Context& dev_ctx,
                     const DenseTensor& x,
@@ -359,18 +392,30 @@ void SetValueKernel(const Context& dev_ctx,
                     const std::vector<int64_t>& shape,
                     const std::vector<Scalar>& values,
                     DenseTensor* out) {
-  std::vector<T> assgin_values;
-  assgin_values.reserve(values.size());
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  std::vector<T> assign_values;
+  assign_values.reserve(values.size());
   for (const auto& val : values) {
-    assgin_values.push_back(val.to<T>());
+    assign_values.push_back(val.to<T>());
   }
-  DenseTensor value_tensor = Empty<T>(dev_ctx, shape);
-  phi::TensorFromVector(assgin_values, dev_ctx, &value_tensor);
-  value_tensor.Resize(phi::make_ddim(shape));
 
-  SetTensorValueKernel<T, Context>(dev_ctx,
-                                   x,
-                                   value_tensor,
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  auto value_dims = phi::make_ddim(shape);
+  XPUType* value_data =
+      RAII_GUARD.alloc_l3_or_gm<XPUType>(phi::product(value_dims));
+  phi::CPUPlace src_place;
+  auto dst_place = dev_ctx.GetPlace();
+  memory_utils::Copy(dst_place,
+                     value_data,
+                     src_place,
+                     assign_values.data(),
+                     assign_values.size() * sizeof(T));
+
+  SetValueKernelImpl<T, Context>(dev_ctx,
+                                 x,
+                                 reinterpret_cast<const T*>(value_data),
+                                 value_dims,
                                  starts,
                                  ends,
                                  steps,
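For reference, the two entry points keep the same slicing attributes as before; only the value now reaches the implementation as a raw pointer plus a DDim. A rough sketch of the arguments SetTensorValueKernel receives for the Python-style assignment x[0:2] = value on a rank-2 tensor, assuming the usual mapping of slicing to set_value attributes (axes lists the sliced dimensions, decrease_axes the scalar-indexed ones, none_axes the inserted ones):

// Hedged example: arguments corresponding to x[0:2] = value.
std::vector<int64_t> axes = {0};            // only dimension 0 is sliced
std::vector<int64_t> decrease_axes = {};    // no scalar indexing
std::vector<int64_t> none_axes = {};        // no new axes inserted
phi::IntArray starts(std::vector<int64_t>{0});
phi::IntArray ends(std::vector<int64_t>{2});
phi::IntArray steps(std::vector<int64_t>{1});
phi::SetTensorValueKernel<float, phi::XPUContext>(
    dev_ctx, x, value, starts, ends, steps, axes, decrease_axes, none_axes,
    &out);  // `out` is the destination DenseTensor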
@@ -117,5 +117,6 @@ PD_REGISTER_KERNEL(strided_slice_raw,
                    phi::StridedSliceRawKernel,
                    int,
                    int16_t,
+                   int64_t,
                    float,
                    phi::dtype::float16) {}
@@ -60,8 +60,8 @@ void TopkKernel(const Context& dev_ctx,
   size_t k = k_scalar.to<int>();
   if (axis + 1 == in_dims.size()) {
     xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-    int32_t* indices_int_data = Alloc_l3_or_gm<Context, int32_t>(
-        dev_ctx, &RAII_GUARD, indices->numel());
+    int32_t* indices_int_data =
+        RAII_GUARD.alloc_l3_or_gm<int32_t>(indices->numel());
     PADDLE_ENFORCE_XDNN_NOT_NULL(indices_int_data);
 
     const size_t row =
@@ -106,8 +106,7 @@ void TopkKernel(const Context& dev_ctx,
   }
 
   xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-  XPUType* trans_in_data =
-      Alloc_l3_or_gm<Context, XPUType>(dev_ctx, &RAII_GUARD, x.numel());
+  XPUType* trans_in_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(x.numel());
   PADDLE_ENFORCE_XDNN_NOT_NULL(trans_in_data);
 
   // Transpose and save interval output to trans_in
@@ -123,16 +122,14 @@ void TopkKernel(const Context& dev_ctx,
           r,
           XPUAPIErrorMsg[r]));
 
-  XPUType* trans_out_data =
-      Alloc_l3_or_gm<Context, XPUType>(dev_ctx, &RAII_GUARD, out->numel());
+  XPUType* trans_out_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(out->numel());
   PADDLE_ENFORCE_XDNN_NOT_NULL(trans_out_data);
 
-  int64_t* trans_idx_data =
-      Alloc_l3_or_gm<Context, int64_t>(dev_ctx, &RAII_GUARD, out->numel());
+  int64_t* trans_idx_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(out->numel());
   PADDLE_ENFORCE_XDNN_NOT_NULL(trans_idx_data);
 
   int32_t* trans_idx_int32_data =
-      Alloc_l3_or_gm<Context, int32_t>(dev_ctx, &RAII_GUARD, out->numel());
+      RAII_GUARD.alloc_l3_or_gm<int32_t>(out->numel());
   PADDLE_ENFORCE_XDNN_NOT_NULL(trans_idx_int32_data);
 
   const size_t row =
       phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1));