未验证 提交 ca174025 编写于 作者: R ronnywang 提交者: GitHub

fix bug for index_sample_op_npu (#34383)

上级 a0bbc992
...@@ -20,6 +20,35 @@ namespace paddle { ...@@ -20,6 +20,35 @@ namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename IndexT>
void IndexSampleGather(const paddle::platform::NPUDeviceContext& dev_ctx,
const Tensor* index, const Tensor* input, Tensor* out) {
auto index_dims = index->dims();
auto input_dims = input->dims();
auto batch_size = input_dims[0];
auto index_length = index_dims[1];
std::vector<IndexT> gather_index_vec;
std::vector<IndexT> index_vec;
framework::TensorToVector(*index, dev_ctx, &index_vec);
for (auto i = 0; i < batch_size; ++i) {
for (auto j = 0; j < index_length; j++) {
gather_index_vec.push_back(i);
gather_index_vec.push_back(index_vec[i * index_length + j]);
}
}
Tensor gather_index;
framework::TensorFromVector(gather_index_vec, dev_ctx, &gather_index);
gather_index.Resize({batch_size, index_length, 2});
NpuOpRunner runner;
runner.SetType("GatherNd")
.AddInput(*input)
.AddInput(gather_index)
.AddOutput(*out);
runner.Run(dev_ctx.stream());
}
template <typename T> template <typename T>
class IndexSampleNPUKernel : public framework::OpKernel<T> { class IndexSampleNPUKernel : public framework::OpKernel<T> {
public: public:
...@@ -31,32 +60,12 @@ class IndexSampleNPUKernel : public framework::OpKernel<T> { ...@@ -31,32 +60,12 @@ class IndexSampleNPUKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<framework::LoDTensor>("Out"); auto* out = ctx.Output<framework::LoDTensor>("Out");
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
Tensor transformed_index;
const auto& index_type = index->type(); const auto& index_type = index->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) { if (index_type == framework::proto::VarType::INT32) {
transformed_index.mutable_data<int64_t>(index->dims(), IndexSampleGather<int32_t>(dev_ctx, index, input, out);
dev_ctx.GetPlace());
const auto& cast_runner = NpuOpRunner(
"Cast", {*index}, {transformed_index}, {{"dst_type", ACL_INT64}});
cast_runner.Run(dev_ctx.stream());
} else { } else {
transformed_index.ShareDataWith(*index); IndexSampleGather<int64_t>(dev_ctx, index, input, out);
} }
const auto& runner = NpuOpRunner(
"GatherElements", {*input, transformed_index}, {*out}, {{"dim", 1}});
runner.Run(dev_ctx.stream());
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册