未验证 提交 9bd739fb 编写于 作者: H hong19860320 提交者: GitHub

[ARM] [NPU] Fix the overflow of layer_norm op bridge, fix the registration of...

[ARM] [NPU] Fix the overflow of layer_norm op bridge, fix the registration of lstm op kernel (#4007) (#4017)

* [NPU] Fix the overflow of layer_norm op bridge
test=develop

* [ARM] Fix the registration of lstm op kernel
test=develop
上级 faacbd27
......@@ -208,6 +208,8 @@ REGISTER_LITE_KERNEL(lstm,
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Weight", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("C0", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("H0", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Cell", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kARM))})
......
......@@ -20,6 +20,7 @@
#include <utility>
#include <vector>
#include "graph/compatible/all_ops.h"
#include "graph/op/all_ops.h"
#include "lite/core/op_lite.h"
#include "lite/core/tensor.h"
......
......@@ -127,12 +127,11 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
}
// LayerNorm node
auto layer_norm_node = graph->Add<ge::op::InstanceNorm>(y_name);
auto layer_norm_op = layer_norm_node->data<ge::op::InstanceNorm>();
auto layer_norm_node = graph->Add<hiai::op::LayerNorm>(y_name);
auto layer_norm_op = layer_norm_node->data<hiai::op::LayerNorm>();
layer_norm_op->set_input_x(*x_node->data());
layer_norm_op->set_input_scale(*scale_node->data());
layer_norm_op->set_input_bias(*bias_node->data());
layer_norm_op->set_attr_reduction_indices(ge::AttrValue::LIST_INT({3}));
layer_norm_op->set_input_gamma(*scale_node->data());
layer_norm_op->set_input_beta(*bias_node->data());
layer_norm_op->set_attr_epsilon(epsilon);
// Reshaped Y node if needs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册