Unverified commit 67a094b5, authored by Aganlengzi, committed by GitHub

[NPU] add index_select_grad kernel and unit tests (#35594)

* [NPU] add index_select_grad kernel and unit tests

* dim=0 does not need a transpose
Parent e93c18a3
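
The backward of index_select is a scatter-add: slice i of the output gradient accumulates into slice index[i] of the input gradient along `dim`, and duplicate indices sum. A minimal NumPy sketch of the semantics this kernel implements (the function name and layout here are illustrative, not part of the commit):

import numpy as np

def index_select_grad_ref(out_grad, index, x_shape, dim):
    # x_grad[..., index[i], ...] += out_grad[..., i, ...] along `dim`
    x_grad = np.zeros(x_shape, dtype=out_grad.dtype)
    # Move `dim` to the front so slices can be accumulated directly; this
    # mirrors the kernel's transpose + UnsortedSegmentSum strategy.
    target = np.moveaxis(x_grad, dim, 0)  # view into x_grad
    np.add.at(target, index, np.moveaxis(out_grad, dim, 0))
    return x_grad

# e.g. index = [0, 2, 2] on dim 0: row 0 receives one gradient slice,
# row 2 accumulates two, row 1 stays zero.
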
@@ -21,12 +21,12 @@ namespace operators {
template <typename DeviceContext, typename T>
class IndexSelectNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* index = ctx.Input<Tensor>("Index");
    auto dim = ctx.Attr<int>("dim");
    auto* out = ctx.Output<Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());

    auto stream =
@@ -43,7 +43,104 @@ class IndexSelectNPUKernel : public framework::OpKernel<T> {
  }
};

template <typename DeviceContext, typename T>
class IndexSelectGradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* index = ctx.Input<Tensor>("Index");
    auto* out_grad =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));

    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    auto x_dims = x_grad->dims();
    auto out_dims = out_grad->dims();

    int dim = ctx.Attr<int>("dim");
    if (dim < 0) {
      dim += out_dims.size();
    }
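
    // The scatter-add below consumes int32 indices, so cast Index when it
    // arrives in any other integer type (e.g. int64).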
    Tensor casted_index;
    if (index->type() != framework::proto::VarType::INT32) {
      casted_index.mutable_data<int32_t>(index->dims(), ctx.GetPlace());
      const auto& cast_runner = NpuOpRunner("Cast", {*index}, {casted_index},
                                            {{"dst_type", ACL_INT32}});
      cast_runner.Run(stream);
    } else {
      casted_index.ShareDataWith(*index);
    }
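
    // Fast path: when dim == 0, slices of out_grad scatter-add directly into
    // x_grad, so a zero-fill plus one UnsortedSegmentSum suffices.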
    if (dim == 0) {
      x_grad->mutable_data<T>(ctx.GetPlace());
      const auto& zeros_runner =
          NpuOpRunner("ZerosLike", {*x_grad}, {*x_grad});
      zeros_runner.Run(stream);

      NpuOpRunner runner;
      runner.SetType("UnsortedSegmentSum")
          .AddInput(*out_grad)
          .AddInput(casted_index)
          .AddInput(std::vector<int64_t>{x_dims[dim]})
          .AddOutput(*x_grad);
      runner.Run(stream);
    } else {
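      // General path: move the indexed dimension to the front with
      // TransposeD, scatter-add there, then transpose the result back.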
      Tensor transed_out_grad;
      std::vector<int> in_trans_perm;
      in_trans_perm.push_back(dim);
      for (int i = 0; i < out_dims.size(); ++i) {
        if (i == dim) continue;
        in_trans_perm.push_back(i);
      }
      framework::DDim transed_out_dims(out_dims);
      for (size_t i = 0; i < in_trans_perm.size(); ++i) {
        transed_out_dims[i] = out_dims[in_trans_perm[i]];
      }
      transed_out_grad.mutable_data<T>(transed_out_dims, ctx.GetPlace());
      framework::NPUAttributeMap in_trans_attr = {{"perm", in_trans_perm}};
      const auto& in_trans_runner = NpuOpRunner(
          "TransposeD", {*out_grad}, {transed_out_grad}, in_trans_attr);
      in_trans_runner.Run(stream);

      Tensor sum_out;
      framework::DDim sum_dims(x_dims);
      sum_dims[0] = x_dims[dim];
      auto idx = 1;
      for (int i = 0; i < x_dims.size(); ++i) {
        if (i == dim) continue;
        sum_dims[idx++] = x_dims[i];
      }
      sum_out.mutable_data<T>(sum_dims, ctx.GetPlace());
      const auto& zeros_runner =
          NpuOpRunner("ZerosLike", {sum_out}, {sum_out});
      zeros_runner.Run(stream);

      NpuOpRunner runner;
      runner.SetType("UnsortedSegmentSum")
          .AddInput(transed_out_grad)
          .AddInput(casted_index)
          .AddInput(std::vector<int64_t>{x_dims[dim]})
          .AddOutput(sum_out);
      runner.Run(stream);
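
      // out_trans_perm is the inverse of in_trans_perm: it moves the leading
      // (indexed) dimension back to position `dim` in x_grad's layout.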
      std::vector<int> out_trans_perm;
      for (int i = 1; i < 1 + dim; ++i) {
        out_trans_perm.push_back(i);
      }
      out_trans_perm.push_back(0);
      for (int i = 1 + dim; i < x_dims.size(); ++i) {
        out_trans_perm.push_back(i);
      }
      framework::NPUAttributeMap out_trans_attr = {{"perm", out_trans_perm}};
      x_grad->mutable_data<T>(ctx.GetPlace());
      const auto& out_trans_runner =
          NpuOpRunner("TransposeD", {sum_out}, {*x_grad}, out_trans_attr);
      out_trans_runner.Run(stream);
    }
  }
};
} // namespace operators
} // namespace paddle
@@ -54,4 +151,8 @@ REGISTER_OP_NPU_KERNEL(
    ops::IndexSelectNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::IndexSelectNPUKernel<paddle::platform::NPUDeviceContext, int>,
    ops::IndexSelectNPUKernel<paddle::platform::NPUDeviceContext, int64_t>);

REGISTER_OP_NPU_KERNEL(
    index_select_grad,
    ops::IndexSelectGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::IndexSelectGradNPUKernel<paddle::platform::NPUDeviceContext, int>,
    ops::IndexSelectGradNPUKernel<paddle::platform::NPUDeviceContext, int64_t>);
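
With both kernels registered, the op can be exercised end to end in dynamic graph mode. A minimal usage sketch, assuming a Paddle build with NPU support (the 'npu' device string is only available in such builds):

import paddle

paddle.set_device('npu')
x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                     stop_gradient=False)
index = paddle.to_tensor([0, 2, 2], dtype='int32')
out = paddle.index_select(x, index, axis=0)  # forward NPU kernel
out.sum().backward()                         # backward kernel added by this PR
print(x.grad)  # row 0: ones, row 1: zeros, row 2: twos (duplicates accumulate)
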
@@ -35,7 +35,10 @@ class TestNPUIndexSelect(OpTest):
        x_np = np.random.random(self.x_shape).astype(self.x_type)
        index_np = np.random.randint(
            low=0,
            high=self.x_shape[self.dim],
            size=self.index_size,
            dtype=self.index_type)

        # compute real output as baseline.
        outer_loop = np.prod(self.x_shape[:self.dim])
@@ -56,18 +59,14 @@ class TestNPUIndexSelect(OpTest):
        self.attrs = {'dim': self.dim}
        self.outputs = {'Out': out}

    def set_npu(self):
        self.__class__.use_npu = True

    def test_check_output(self):
        self.check_output_with_place(self.place)

    def test_check_grad(self):
        self.check_grad_with_place(self.place, ['X'], 'Out')

    def config(self):
        self.x_shape = (100, 4, 5)
@@ -86,6 +85,24 @@ class TestNPUIndexSelectCase2(TestNPUIndexSelect):
        self.index_size = 10
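
# Case3 exercises the dim == 0 fast path (no transposes); Case4 exercises
# negative-dim normalization (dim = -1 maps to the last axis).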
class TestNPUIndexSelectCase3(TestNPUIndexSelect):
    def config(self):
        self.dim = 0
        self.x_type = np.float32
        self.index_type = np.int32
        self.x_shape = (10, 10, 4, 10)
        self.index_size = 10


class TestNPUIndexSelectCase4(TestNPUIndexSelect):
    def config(self):
        self.dim = -1
        self.x_type = np.float32
        self.index_type = np.int32
        self.x_shape = (10, 10, 4, 10)
        self.index_size = 10


class TestNPUIndexSelectAPI(unittest.TestCase):
    def input_data(self):
        self.data_x = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0],
......