未验证 提交 ca5567e1 编写于 作者: Z zhangbo9674 提交者: GitHub

refine expand_as_v2 XPU kernel, test=kunlun (#45501)

上级 fbd83812
......@@ -23,11 +23,11 @@ namespace phi {
template <typename Context, typename T>
void ExpandAs(const Context& context,
const DenseTensor& in0,
const DenseTensor& x,
const std::vector<int>& target_shape,
DenseTensor* out0) {
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto in_dims = in0.dims();
auto in_dims = x.dims();
auto vec_in_dims = phi::vectorize<int>(in_dims);
auto diff = target_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
......@@ -50,23 +50,23 @@ void ExpandAs(const Context& context,
}
}
phi::DDim out_dims = phi::make_ddim(target_shape);
out0->Resize(out_dims);
context.template Alloc<T>(out0);
auto& in0_shape = vec_in_dims;
auto out0_shape = phi::vectorize<int>(out_dims);
out->Resize(out_dims);
context.template Alloc<T>(out);
auto& x_shape = vec_in_dims;
auto out_shape = phi::vectorize<int>(out_dims);
int r = XPU_SUCCESS;
if (std::is_same<T, bool>::value) {
auto in0_data = reinterpret_cast<const int8_t*>(in0.data<T>());
auto out0_data = reinterpret_cast<int8_t*>(out0->data<T>());
auto x_data = reinterpret_cast<const int8_t*>(x.data<T>());
auto out_data = reinterpret_cast<int8_t*>(out->data<T>());
r = xpu::broadcast<int8_t>(
context.x_context(), in0_data, out0_data, in0_shape, out0_shape);
context.x_context(), x_data, out_data, x_shape, out_shape);
} else {
auto in0_data = reinterpret_cast<const XPUType*>(in0.data<T>());
auto out0_data = reinterpret_cast<XPUType*>(out0->data<T>());
auto x_data = reinterpret_cast<const XPUType*>(x.data<T>());
auto out_data = reinterpret_cast<XPUType*>(out->data<T>());
r = xpu::broadcast<XPUType>(
context.x_context(), in0_data, out0_data, in0_shape, out0_shape);
context.x_context(), x_data, out_data, x_shape, out_shape);
}
PADDLE_ENFORCE_EQ(
r,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册