未验证 提交 7a5af630 编写于 作者: F furnace 提交者: GitHub

[NPU] fix expand op (#38526)

* [NPU] fix expand op

* [NPU] optimize codes

* [NPU] optimize codes
上级 23aa7b08
......@@ -81,15 +81,31 @@ class ExpandNPUKernel : public framework::OpKernel<T> {
out_dims[i] *= expand_times[i];
}
out0->Resize(out_dims);
out0->mutable_data<T>(context.device_context().GetPlace());
const auto& runner =
NpuOpRunner("TileD", {*in0}, {*out0}, {{"multiples", expand_times}});
auto place = context.GetPlace();
auto stream =
context.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
out0->Resize(out_dims);
out0->mutable_data<T>(place);
bool is_expand_times_all_one =
(out0->numel() == in0->numel()) ? true : false;
if (is_expand_times_all_one) {
memory::Copy(BOOST_GET_CONST(platform::NPUPlace, place),
out0->mutable_data<T>(place),
BOOST_GET_CONST(platform::NPUPlace, place), in0->data<T>(),
in0->numel() * sizeof(T), stream);
if (out_dims != in_dims) {
out0->Resize(out_dims);
}
} else {
const auto& runner =
NpuOpRunner("TileD", {*in0}, {*out0}, {{"multiples", expand_times}});
runner.Run(stream);
}
}
};
} // namespace operators
} // namespace paddle
......
......@@ -132,5 +132,26 @@ class TestExpandNet(unittest.TestCase):
self.assertTrue(np.allclose(npu_loss, cpu_loss))
# ------------------------------------------------
# Special Cases for NPU
# ------------------------------------------------
class TestExpand_expand_times_all_one(TestExpand):
def setUp(self):
self.set_npu()
self.op_type = "expand"
self.place = paddle.NPUPlace(0)
self.init_dtype()
np.random.seed(SEED)
x = np.random.randn(3, 1, 7).astype(self.dtype)
out = np.tile(x, [1, 1, 1])
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.attrs = {'expand_times': [1, 1, 1]}
self.outputs = {'Out': out}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册