未验证 提交 d7461973 编写于 作者: X xiayanming 提交者: GitHub

[NPU] Support NPU kernel for gather op and fix bug (#31541)

* add gather npu op

* code review done

* update python new line

* precommit

* fix review

* del commit

* update gather_grad

* fix bug

* fix bug
上级 5d22e15b
...@@ -51,10 +51,13 @@ class GatherGradOpNPUKernel : public framework::OpKernel<T> { ...@@ -51,10 +51,13 @@ class GatherGradOpNPUKernel : public framework::OpKernel<T> {
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
// step1: Unsqueeze index // step1: Unsqueeze index
framework::Tensor tmp_tensor(index->type());
const auto index_dims = index->dims(); const auto index_dims = index->dims();
if (index_dims.size() == 1) { if (index_dims.size() == 1) {
framework::Tensor tmp_index = UnsqueezeTo(*index, 2); tmp_tensor.ShareDataWith(*index);
index = &tmp_index; std::vector<int64_t> new_dim = {index_dims[0], 1};
tmp_tensor.Resize(framework::make_ddim(new_dim));
index = &tmp_tensor;
} }
auto stream = auto stream =
......
...@@ -109,17 +109,17 @@ void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx, ...@@ -109,17 +109,17 @@ void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx,
auto dout = scope->Var("DOut"); auto dout = scope->Var("DOut");
auto tensor_dout = dout->GetMutable<f::LoDTensor>(); auto tensor_dout = dout->GetMutable<f::LoDTensor>();
std::vector<int> init_index = {0, 1, 2, 0}; std::vector<int> init_index = {0, 1};
paddle::framework::TensorFromVector<int>(init_index, ctx, tensor_index); paddle::framework::TensorFromVector<int>(init_index, ctx, tensor_index);
tensor_index->Resize(paddle::framework::make_ddim({2, 2})); tensor_index->Resize(paddle::framework::make_ddim({2}));
std::vector<T> init_x = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; std::vector<T> init_x = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
TensorFromVector(init_x, ctx, tensor_x); TensorFromVector(init_x, ctx, tensor_x);
tensor_x->Resize(paddle::framework::make_ddim({3, 2})); tensor_x->Resize(paddle::framework::make_ddim({3, 2}));
std::vector<T> init_dout = {5.0, 10.0}; std::vector<T> init_dout = {5.0, 10.0, 2.0, 3.0};
TensorFromVector(init_dout, ctx, tensor_dout); TensorFromVector(init_dout, ctx, tensor_dout);
tensor_dout->Resize(paddle::framework::make_ddim({2})); tensor_dout->Resize(paddle::framework::make_ddim({2, 2}));
ctx.Wait(); ctx.Wait();
...@@ -143,7 +143,7 @@ void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx, ...@@ -143,7 +143,7 @@ void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx,
uint32_t expected_size = 3 * 2; uint32_t expected_size = 3 * 2;
EXPECT_EQ((uint32_t)dx_vec.size(), expected_size); EXPECT_EQ((uint32_t)dx_vec.size(), expected_size);
std::vector<T> expected_dx_vec = {0.0, 5.0, 0.0, 0.0, 10.0, 0.0}; std::vector<T> expected_dx_vec = {5.0, 10.0, 2.0, 3.0, 0.0, 0.0};
for (uint32_t i = 0; i < dx_vec.size(); i++) { for (uint32_t i = 0; i < dx_vec.size(); i++) {
VLOG(3) << "dx_vec[i]=" << dx_vec[i]; VLOG(3) << "dx_vec[i]=" << dx_vec[i];
EXPECT_EQ(dx_vec[i], expected_dx_vec[i]); EXPECT_EQ(dx_vec[i], expected_dx_vec[i]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册