未验证 提交 ca5567e1 作者: zhangbo9674 提交者: GitHub

refine expand_as_v2 XPU kernel, test=kunlun (#45501)

上级 fbd83812
...@@ -23,11 +23,11 @@ namespace phi { ...@@ -23,11 +23,11 @@ namespace phi {
template <typename Context, typename T> template <typename Context, typename T>
void ExpandAs(const Context& context, void ExpandAs(const Context& context,
const DenseTensor& in0, const DenseTensor& x,
const std::vector<int>& target_shape, const std::vector<int>& target_shape,
DenseTensor* out0) { DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
auto in_dims = in0.dims(); auto in_dims = x.dims();
auto vec_in_dims = phi::vectorize<int>(in_dims); auto vec_in_dims = phi::vectorize<int>(in_dims);
auto diff = target_shape.size() - vec_in_dims.size(); auto diff = target_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1); vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
...@@ -50,23 +50,23 @@ void ExpandAs(const Context& context, ...@@ -50,23 +50,23 @@ void ExpandAs(const Context& context,
} }
} }
phi::DDim out_dims = phi::make_ddim(target_shape); phi::DDim out_dims = phi::make_ddim(target_shape);
out0->Resize(out_dims); out->Resize(out_dims);
context.template Alloc<T>(out0); context.template Alloc<T>(out);
auto& in0_shape = vec_in_dims; auto& x_shape = vec_in_dims;
auto out0_shape = phi::vectorize<int>(out_dims); auto out_shape = phi::vectorize<int>(out_dims);
int r = XPU_SUCCESS; int r = XPU_SUCCESS;
if (std::is_same<T, bool>::value) { if (std::is_same<T, bool>::value) {
auto in0_data = reinterpret_cast<const int8_t*>(in0.data<T>()); auto x_data = reinterpret_cast<const int8_t*>(x.data<T>());
auto out0_data = reinterpret_cast<int8_t*>(out0->data<T>()); auto out_data = reinterpret_cast<int8_t*>(out->data<T>());
r = xpu::broadcast<int8_t>( r = xpu::broadcast<int8_t>(
context.x_context(), in0_data, out0_data, in0_shape, out0_shape); context.x_context(), x_data, out_data, x_shape, out_shape);
} else { } else {
auto in0_data = reinterpret_cast<const XPUType*>(in0.data<T>()); auto x_data = reinterpret_cast<const XPUType*>(x.data<T>());
auto out0_data = reinterpret_cast<XPUType*>(out0->data<T>()); auto out_data = reinterpret_cast<XPUType*>(out->data<T>());
r = xpu::broadcast<XPUType>( r = xpu::broadcast<XPUType>(
context.x_context(), in0_data, out0_data, in0_shape, out0_shape); context.x_context(), x_data, out_data, x_shape, out_shape);
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
r, r,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册