diff --git a/paddle/fluid/operators/scatter_op_npu.cc b/paddle/fluid/operators/scatter_op_npu.cc
index de368e6e802193b1b3e8059a0911ce063663d2dc..8d92ea41665135f0fbc8ab7f0d4bb9678d16249c 100644
--- a/paddle/fluid/operators/scatter_op_npu.cc
+++ b/paddle/fluid/operators/scatter_op_npu.cc
@@ -48,18 +48,49 @@ class ScatterNPUKernel : public framework::OpKernel<T> {
       index = &tmp_tensor;
     }
 
-    auto stream =
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .stream();
+    const auto& dev_ctx =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>();
+    auto op_func_update = [](const std::vector<Tensor>& inputs,
+                             const std::vector<Tensor>& outputs,
+                             const NPUAttributeMap& attrs,
+                             const platform::NPUDeviceContext& dev_ctx) {
+      const auto& runner =
+          NpuOpRunner("TensorScatterUpdate", inputs, outputs, attrs);
+      runner.Run(dev_ctx.stream());
+    };
+    auto op_func_add = [](const std::vector<Tensor>& inputs,
+                          const std::vector<Tensor>& outputs,
+                          const NPUAttributeMap& attrs,
+                          const platform::NPUDeviceContext& dev_ctx) {
+      const auto& runner =
+          NpuOpRunner("TensorScatterAdd", inputs, outputs, attrs);
+      runner.Run(dev_ctx.stream());
+    };
 
     if (overwrite) {
-      const auto& runner_update = NpuOpRunner(
-          "TensorScatterUpdate", {*x, *index, *updates}, {*out}, {});
-      runner_update.Run(stream);
+      if (x->type() == framework::proto::VarType::INT64) {
+        NpuOpRunner::TypeAdapter(
+            {*x, *index, *updates}, {*out}, {}, dev_ctx, op_func_update,
+            {framework::proto::VarType::INT32, framework::proto::VarType::INT32,
+             framework::proto::VarType::INT32},
+            {framework::proto::VarType::INT32});
+      } else {
+        const auto& runner_update = NpuOpRunner(
+            "TensorScatterUpdate", {*x, *index, *updates}, {*out}, {});
+        runner_update.Run(dev_ctx.stream());
+      }
     } else {
-      const auto& runner_add =
-          NpuOpRunner("TensorScatterAdd", {*x, *index, *updates}, {*out}, {});
-      runner_add.Run(stream);
+      if (x->type() == framework::proto::VarType::INT64) {
+        NpuOpRunner::TypeAdapter(
+            {*x, *index, *updates}, {*out}, {}, dev_ctx, op_func_add,
+            {framework::proto::VarType::INT32, framework::proto::VarType::INT32,
+             framework::proto::VarType::INT32},
+            {framework::proto::VarType::INT32});
+      } else {
+        const auto& runner_add =
+            NpuOpRunner("TensorScatterAdd", {*x, *index, *updates}, {*out}, {});
+        runner_add.Run(dev_ctx.stream());
+      }
     }
   }
 };
@@ -70,6 +101,10 @@ namespace ops = paddle::operators;
 
 REGISTER_OP_NPU_KERNEL(
     scatter, ops::ScatterNPUKernel<float>,
+#ifdef PADDLE_WITH_ASCEND_INT64
+    ops::ScatterNPUKernel<int64_t>,
+#endif
+    ops::ScatterNPUKernel<int>,
     ops::ScatterNPUKernel<paddle::platform::float16>);
 
 #endif
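Note: for INT64 inputs the kernel above no longer calls TensorScatterUpdate/TensorScatterAdd directly; it routes them through NpuOpRunner::TypeAdapter, and the dtype lists in those calls indicate the operands are cast to INT32, the op is run, and the result is cast back. The snippet below is a host-side numpy sketch of the numeric behaviour being delegated; the helper name and the cast-back step are illustrative, not part of the patch.

import numpy as np

def scatter_reference(x, index, updates, overwrite=True):
    # Illustrative host-side reference only; the NPU kernel runs
    # TensorScatterUpdate / TensorScatterAdd on device.
    compute_dtype = np.int32 if x.dtype == np.int64 else x.dtype
    out = x.astype(compute_dtype)          # int64 operands are computed in int32
    upd = updates.astype(compute_dtype)
    if overwrite:
        out[index] = upd                   # TensorScatterUpdate semantics
    else:
        np.add.at(out, index, upd)         # TensorScatterAdd semantics
    return out.astype(x.dtype)             # cast back to the original dtype

x = np.ones((3, 2), dtype=np.int64)
print(scatter_reference(x, np.array([1]), np.zeros((1, 2), dtype=np.int64)))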
diff --git a/python/paddle/fluid/tests/unittests/npu/test_scatter_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_scatter_op_npu.py
index c05b53d9a48621c61c8c958c360a13d2e9bf9466..c353654641932ecc4389e11c35e4b15d1f3db3cc 100755
--- a/python/paddle/fluid/tests/unittests/npu/test_scatter_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_scatter_op_npu.py
@@ -27,7 +27,7 @@ paddle.enable_static()
 SEED = 2021
 
 
-class TestCast1(OpTest):
+class TestCast1_FP32(OpTest):
     def setUp(self):
         self.set_npu()
         self.op_type = "scatter"
@@ -50,7 +50,7 @@ class TestCast1(OpTest):
         self.check_output_with_place(self.place)
 
 
-class TestCast2(OpTest):
+class TestCast_INT32(OpTest):
     def setUp(self):
         self.set_npu()
         self.op_type = "scatter"
@@ -73,7 +73,7 @@ class TestCast2(OpTest):
         self.check_output_with_place(self.place)
 
 
-class TestCast3(OpTest):
+class TestCast2_FP32(OpTest):
     def setUp(self):
         self.set_npu()
         self.op_type = "scatter"
@@ -96,7 +96,7 @@ class TestCast3(OpTest):
         self.check_output_with_place(self.place)
 
 
-class TestCast4(OpTest):
+class TestCast3_FP32(OpTest):
     def setUp(self):
         self.set_npu()
         self.op_type = "scatter"
@@ -120,5 +120,28 @@ class TestCast4(OpTest):
         self.check_output_with_place(self.place)
 
 
+class TestCast_INT64(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "scatter"
+        self.place = paddle.NPUPlace(0)
+
+        ref_np = np.ones((3, 2)).astype("int64")
+        index_np = np.array([1]).astype("int32")
+        updates_np = np.zeros((1, 2)).astype("int64")
+
+        output_np = np.copy(ref_np)
+        output_np[index_np] = updates_np
+        self.inputs = {'X': ref_np, 'Ids': index_np, 'Updates': updates_np}
+        self.outputs = {'Out': output_np}
+        self.attrs = {'overwrite': True}
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+
 if __name__ == '__main__':
     unittest.main()
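For reference, the new int64 path can also be exercised through the public API. A minimal sketch mirroring TestCast_INT64 follows; it assumes an Ascend NPU build with PADDLE_WITH_ASCEND_INT64 defined (so the int64 kernel above is registered) and that "npu:0" is a valid device string for paddle.set_device in that build.

import numpy as np
import paddle

paddle.set_device("npu:0")  # assumption: an NPU device is available in this build

x = paddle.to_tensor(np.ones((3, 2), dtype="int64"))
index = paddle.to_tensor(np.array([1], dtype="int32"))
updates = paddle.to_tensor(np.zeros((1, 2), dtype="int64"))

out = paddle.scatter(x, index, updates, overwrite=True)
print(out.numpy())  # row 1 overwritten with zeros, as checked by TestCast_INT64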