未验证 提交 3e088aaf 编写于 作者: F furnace 提交者: GitHub

[NPU] add int64 support for argsort op (#37434)

* [NPU] add int64 support for argsort op

* [NPU] delete debug codes
上级 1127fecb
......@@ -46,6 +46,18 @@ static void CastToInt64(const framework::ExecutionContext& ctx,
.Run(stream);
}
static void CastToFP32(const framework::ExecutionContext& ctx,
const aclrtStream& stream, const Tensor& in,
Tensor* out) {
out->mutable_data<float>(ctx.GetPlace());
NpuOpRunner runner;
runner.SetType("Cast")
.AddInput(in)
.AddOutput(*out)
.AddAttr("dst_type", ACL_FLOAT)
.Run(stream);
}
template <typename T>
class ArgsortNPUKernel : public framework::OpKernel<T> {
public:
......@@ -66,41 +78,91 @@ class ArgsortNPUKernel : public framework::OpKernel<T> {
Tensor indices_tmp(framework::proto::VarType::INT32);
indices_tmp.Resize(indices->dims());
if (axis == -1 || axis + 1 == in_dims.size()) {
output->mutable_data<T>(ctx.GetPlace());
indices_tmp.mutable_data<int32_t>(ctx.GetPlace());
const auto& runner =
NpuOpRunner("Sort", {*input}, {*output, indices_tmp}, attr);
runner.Run(stream);
} else {
std::vector<int64_t> perm;
for (int64_t i = 0; i < in_dims.size(); i++) {
perm.emplace_back(i);
if (input->type() == framework::proto::VarType::INT64) {
Tensor input_fp32(framework::proto::VarType::FP32);
input_fp32.Resize(input->dims());
CastToFP32(ctx, stream, *input, &input_fp32);
Tensor output_fp32(framework::proto::VarType::FP32);
output_fp32.Resize(output->dims());
if (axis == -1 || axis + 1 == in_dims.size()) {
output_fp32.mutable_data<float>(ctx.GetPlace());
indices_tmp.mutable_data<int32_t>(ctx.GetPlace());
const auto& runner =
NpuOpRunner("Sort", {input_fp32}, {output_fp32, indices_tmp}, attr);
runner.Run(stream);
CastToInt64(ctx, stream, output_fp32, output);
} else {
std::vector<int64_t> perm;
for (int64_t i = 0; i < in_dims.size(); i++) {
perm.emplace_back(i);
}
std::swap(perm[axis], perm[in_dims.size() - 1]);
std::vector<int64_t> shape;
for (size_t i = 0; i < perm.size(); i++) {
shape.emplace_back(in_dims[perm[i]]);
}
auto trans_dims = framework::make_ddim(shape);
Tensor trans_input(input_fp32.type());
trans_input.Resize(trans_dims);
TranposeNPU<float>(ctx, stream, &perm, input_fp32, &trans_input);
Tensor trans_output(input_fp32.type());
Tensor trans_indices(framework::proto::VarType::INT32);
trans_output.mutable_data<float>(trans_dims, ctx.GetPlace());
trans_indices.mutable_data<int32_t>(trans_dims, ctx.GetPlace());
const auto& runner = NpuOpRunner("Sort", {trans_input},
{trans_output, trans_indices}, attr);
runner.Run(stream);
TranposeNPU<float>(ctx, stream, &perm, trans_output, &output_fp32);
TranposeNPU<int32_t>(ctx, stream, &perm, trans_indices, &indices_tmp);
CastToInt64(ctx, stream, output_fp32, output);
}
std::swap(perm[axis], perm[in_dims.size() - 1]);
std::vector<int64_t> shape;
for (size_t i = 0; i < perm.size(); i++) {
shape.emplace_back(in_dims[perm[i]]);
} else {
if (axis == -1 || axis + 1 == in_dims.size()) {
output->mutable_data<T>(ctx.GetPlace());
indices_tmp.mutable_data<int32_t>(ctx.GetPlace());
const auto& runner =
NpuOpRunner("Sort", {*input}, {*output, indices_tmp}, attr);
runner.Run(stream);
} else {
std::vector<int64_t> perm;
for (int64_t i = 0; i < in_dims.size(); i++) {
perm.emplace_back(i);
}
std::swap(perm[axis], perm[in_dims.size() - 1]);
std::vector<int64_t> shape;
for (size_t i = 0; i < perm.size(); i++) {
shape.emplace_back(in_dims[perm[i]]);
}
auto trans_dims = framework::make_ddim(shape);
Tensor trans_input(input->type());
trans_input.Resize(trans_dims);
TranposeNPU<T>(ctx, stream, &perm, *input, &trans_input);
Tensor trans_output(input->type());
Tensor trans_indices(framework::proto::VarType::INT32);
trans_output.mutable_data<T>(trans_dims, ctx.GetPlace());
trans_indices.mutable_data<int32_t>(trans_dims, ctx.GetPlace());
const auto& runner = NpuOpRunner("Sort", {trans_input},
{trans_output, trans_indices}, attr);
runner.Run(stream);
TranposeNPU<T>(ctx, stream, &perm, trans_output, output);
TranposeNPU<int32_t>(ctx, stream, &perm, trans_indices, &indices_tmp);
}
auto trans_dims = framework::make_ddim(shape);
Tensor trans_input(input->type());
trans_input.Resize(trans_dims);
TranposeNPU<T>(ctx, stream, &perm, *input, &trans_input);
Tensor trans_output(input->type());
Tensor trans_indices(framework::proto::VarType::INT32);
trans_output.mutable_data<T>(trans_dims, ctx.GetPlace());
trans_indices.mutable_data<int32_t>(trans_dims, ctx.GetPlace());
const auto& runner = NpuOpRunner("Sort", {trans_input},
{trans_output, trans_indices}, attr);
runner.Run(stream);
TranposeNPU<T>(ctx, stream, &perm, trans_output, output);
TranposeNPU<int32_t>(ctx, stream, &perm, trans_indices, &indices_tmp);
}
CastToInt64(ctx, stream, indices_tmp, indices);
}
};
......@@ -208,6 +270,9 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(argsort, ops::ArgsortNPUKernel<float>,
#ifdef PADDLE_WITH_ASCEND_INT64
ops::ArgsortNPUKernel<int64_t>,
#endif
ops::ArgsortNPUKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL(argsort_grad, ops::ArgsortGradNPUKernel<float>,
......
......@@ -209,5 +209,87 @@ class TestArgsortOpDescendingAxisNeg2NPUFP32(TestArgsortOpAxisNeg2NPUFP32):
self.descending = True
# test cases for int64
class TestArgsortOpAxis0NPUINT64(TestArgsortOp):
def setUp(self):
self.set_npu()
self.op_type = "argsort"
self.place = paddle.NPUPlace(0)
self.init_dtype()
self.init_inputshape()
self.init_axis()
self.init_direction()
self.x = np.random.randint(
low=-100, high=100, size=self.input_shape,
dtype=self.dtype).astype(self.dtype)
self.inputs = {"X": self.x}
self.attrs = {"axis": self.axis, "descending": self.descending}
self.get_output()
self.outputs = {"Out": self.sorted_x, "Indices": self.indices}
def init_axis(self):
self.axis = 0
def init_dtype(self):
self.dtype = np.int64
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-2)
def set_npu(self):
self.__class__.use_npu = True
class TestArgsortOpAxis1NPUINT64(TestArgsortOpAxis0NPUINT64):
def init_axis(self):
self.axis = 1
class TestArgsortOpAxis2NPUINT64(TestArgsortOpAxis0NPUINT64):
def init_axis(self):
self.axis = 2
class TestArgsortOpAxisNeg1NPUINT64(TestArgsortOpAxis0NPUINT64):
def init_axis(self):
self.axis = -1
class TestArgsortOpAxisNeg2NPUINT64(TestArgsortOpAxis0NPUINT64):
def init_axis(self):
self.axis = -2
class TestArgsortOpDescendingAxisNPUINT64(TestArgsortOpAxis0NPUINT64):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxis0NPUINT64(TestArgsortOpAxis0NPUINT64):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxis1NPUINT64(TestArgsortOpAxis1NPUINT64):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxis2NPUINT64(TestArgsortOpAxis2NPUINT64):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxisNeg1NPUINT64(TestArgsortOpAxisNeg1NPUINT64):
def init_direction(self):
self.descending = True
class TestArgsortOpDescendingAxisNeg2NPUINT64(TestArgsortOpAxisNeg2NPUINT64):
def init_direction(self):
self.descending = True
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册