From ca174025911891256dedda6d876bb41fe518ee4b Mon Sep 17 00:00:00 2001 From: ronnywang <524019753@qq.com> Date: Mon, 26 Jul 2021 07:42:05 -0500 Subject: [PATCH] fix bug for index_sample_op_npu (#34383) --- paddle/fluid/operators/index_sample_op_npu.cc | 53 +++++++++++-------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/index_sample_op_npu.cc b/paddle/fluid/operators/index_sample_op_npu.cc index f5a4100c635..ef7e8583e30 100644 --- a/paddle/fluid/operators/index_sample_op_npu.cc +++ b/paddle/fluid/operators/index_sample_op_npu.cc @@ -20,6 +20,35 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; +template +void IndexSampleGather(const paddle::platform::NPUDeviceContext& dev_ctx, + const Tensor* index, const Tensor* input, Tensor* out) { + auto index_dims = index->dims(); + auto input_dims = input->dims(); + auto batch_size = input_dims[0]; + auto index_length = index_dims[1]; + + std::vector gather_index_vec; + std::vector index_vec; + framework::TensorToVector(*index, dev_ctx, &index_vec); + for (auto i = 0; i < batch_size; ++i) { + for (auto j = 0; j < index_length; j++) { + gather_index_vec.push_back(i); + gather_index_vec.push_back(index_vec[i * index_length + j]); + } + } + Tensor gather_index; + framework::TensorFromVector(gather_index_vec, dev_ctx, &gather_index); + gather_index.Resize({batch_size, index_length, 2}); + + NpuOpRunner runner; + runner.SetType("GatherNd") + .AddInput(*input) + .AddInput(gather_index) + .AddOutput(*out); + runner.Run(dev_ctx.stream()); +} + template class IndexSampleNPUKernel : public framework::OpKernel { public: @@ -31,32 +60,12 @@ class IndexSampleNPUKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); - Tensor transformed_index; const auto& index_type = index->type(); - bool index_type_match = index_type == framework::proto::VarType::INT32 || - index_type == framework::proto::VarType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, true, - platform::errors::InvalidArgument( - "Input(Index) holds the wrong type, it holds %s, but " - "desires to be %s or %s", - paddle::framework::DataTypeToString(index_type), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT32), - paddle::framework::DataTypeToString( - framework::proto::VarType::INT64))); if (index_type == framework::proto::VarType::INT32) { - transformed_index.mutable_data(index->dims(), - dev_ctx.GetPlace()); - const auto& cast_runner = NpuOpRunner( - "Cast", {*index}, {transformed_index}, {{"dst_type", ACL_INT64}}); - cast_runner.Run(dev_ctx.stream()); + IndexSampleGather(dev_ctx, index, input, out); } else { - transformed_index.ShareDataWith(*index); + IndexSampleGather(dev_ctx, index, input, out); } - - const auto& runner = NpuOpRunner( - "GatherElements", {*input, transformed_index}, {*out}, {{"dim", 1}}); - runner.Run(dev_ctx.stream()); } }; -- GitLab