未验证 提交 d7461973 编写于 作者: X xiayanming 提交者: GitHub

[NPU] Support npu kernel for gather op fix bug (#31541)

* add gather npu op

* code review done

* update python new line

* precommit

* fix review

* del commit

* update gather_grad

* fix bug

* fix bug
上级 5d22e15b
......@@ -51,10 +51,13 @@ class GatherGradOpNPUKernel : public framework::OpKernel<T> {
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
// step1: Unsqueeze index
framework::Tensor tmp_tensor(index->type());
const auto index_dims = index->dims();
if (index_dims.size() == 1) {
framework::Tensor tmp_index = UnsqueezeTo(*index, 2);
index = &tmp_index;
tmp_tensor.ShareDataWith(*index);
std::vector<int64_t> new_dim = {index_dims[0], 1};
tmp_tensor.Resize(framework::make_ddim(new_dim));
index = &tmp_tensor;
}
auto stream =
......
......@@ -109,17 +109,17 @@ void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx,
auto dout = scope->Var("DOut");
auto tensor_dout = dout->GetMutable<f::LoDTensor>();
std::vector<int> init_index = {0, 1, 2, 0};
std::vector<int> init_index = {0, 1};
paddle::framework::TensorFromVector<int>(init_index, ctx, tensor_index);
tensor_index->Resize(paddle::framework::make_ddim({2, 2}));
tensor_index->Resize(paddle::framework::make_ddim({2}));
std::vector<T> init_x = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
TensorFromVector(init_x, ctx, tensor_x);
tensor_x->Resize(paddle::framework::make_ddim({3, 2}));
std::vector<T> init_dout = {5.0, 10.0};
std::vector<T> init_dout = {5.0, 10.0, 2.0, 3.0};
TensorFromVector(init_dout, ctx, tensor_dout);
tensor_dout->Resize(paddle::framework::make_ddim({2}));
tensor_dout->Resize(paddle::framework::make_ddim({2, 2}));
ctx.Wait();
......@@ -143,7 +143,7 @@ void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx,
uint32_t expected_size = 3 * 2;
EXPECT_EQ((uint32_t)dx_vec.size(), expected_size);
std::vector<T> expected_dx_vec = {0.0, 5.0, 0.0, 0.0, 10.0, 0.0};
std::vector<T> expected_dx_vec = {5.0, 10.0, 2.0, 3.0, 0.0, 0.0};
for (uint32_t i = 0; i < dx_vec.size(); i++) {
VLOG(3) << "dx_vec[i]=" << dx_vec[i];
EXPECT_EQ(dx_vec[i], expected_dx_vec[i]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册