Summary of this patch (free text ahead of the first file header; not part of
any hunk):

  paddle/fluid/operators/expand_op_npu.cc
    The tail of the ExpandNPUKernel compute path is reworked. The output
    tensor is resized and allocated first. If the output and input element
    counts are equal, the input buffer is copied to the output with
    memory::Copy (NPUPlace to NPUPlace, on the kernel's stream,
    numel * sizeof(T) bytes). Otherwise the "TileD" NpuOpRunner is run with
    multiples = expand_times, as before.
    NOTE(review): presumably the copy path exists because TileD does not
    handle the case where every entry of expand_times is 1 -- confirm against
    the upstream commit message.

  python/paddle/fluid/tests/unittests/npu/test_expand_op_npu.py
    Adds TestExpand_expand_times_all_one: input of shape (3, 1, 7),
    expand_times = [1, 1, 1], expected output np.tile(x, [1, 1, 1]).

NOTE(review): template arguments (for example after OpKernel, device_context,
mutable_data and data) appear to have been stripped from this copy of the
patch, and leading indentation was reconstructed from the hunk line counts --
verify against the upstream commit before applying.

diff --git a/paddle/fluid/operators/expand_op_npu.cc b/paddle/fluid/operators/expand_op_npu.cc
index 8ecdd5e8cb695a796511117082375264196032e1..e9f31f8ddd698c09869361882625e5c26d72dbc5 100644
--- a/paddle/fluid/operators/expand_op_npu.cc
+++ b/paddle/fluid/operators/expand_op_npu.cc
@@ -81,14 +81,30 @@ class ExpandNPUKernel : public framework::OpKernel {
       out_dims[i] *= expand_times[i];
     }
 
-    out0->Resize(out_dims);
-    out0->mutable_data(context.device_context().GetPlace());
-    const auto& runner =
-        NpuOpRunner("TileD", {*in0}, {*out0}, {{"multiples", expand_times}});
+    auto place = context.GetPlace();
     auto stream =
         context.template device_context()
             .stream();
-    runner.Run(stream);
+
+    out0->Resize(out_dims);
+    out0->mutable_data(place);
+
+    bool is_expand_times_all_one =
+        (out0->numel() == in0->numel()) ? true : false;
+
+    if (is_expand_times_all_one) {
+      memory::Copy(BOOST_GET_CONST(platform::NPUPlace, place),
+                   out0->mutable_data(place),
+                   BOOST_GET_CONST(platform::NPUPlace, place), in0->data(),
+                   in0->numel() * sizeof(T), stream);
+      if (out_dims != in_dims) {
+        out0->Resize(out_dims);
+      }
+    } else {
+      const auto& runner =
+          NpuOpRunner("TileD", {*in0}, {*out0}, {{"multiples", expand_times}});
+      runner.Run(stream);
+    }
   }
 };
 }  // namespace operators
diff --git a/python/paddle/fluid/tests/unittests/npu/test_expand_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_expand_op_npu.py
index 375003f79e500f99ceaf374ce898d998263700e0..89ac9e09aa3488c25000c7801f108e036f33934e 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_expand_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_expand_op_npu.py
@@ -132,5 +132,26 @@ class TestExpandNet(unittest.TestCase):
         self.assertTrue(np.allclose(npu_loss, cpu_loss))
 
 
+# ------------------------------------------------
+# Special Cases for NPU
+# ------------------------------------------------
+
+
+class TestExpand_expand_times_all_one(TestExpand):
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "expand"
+        self.place = paddle.NPUPlace(0)
+
+        self.init_dtype()
+        np.random.seed(SEED)
+        x = np.random.randn(3, 1, 7).astype(self.dtype)
+        out = np.tile(x, [1, 1, 1])
+
+        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.attrs = {'expand_times': [1, 1, 1]}
+        self.outputs = {'Out': out}
+
+
 if __name__ == '__main__':
     unittest.main()