未验证 提交 61ec5d82 编写于 作者: H hong19860320 提交者: GitHub

[ARM] [NPU] Fix the overflow of layer_norm op bridge, fix the registration of...

[ARM] [NPU] Fix the overflow of layer_norm op bridge, fix the registration of lstm op kernel (#4007)

* [NPU] Fix the overflow of layer_norm op bridge
test=develop

* [ARM] Fix the registration of lstm op kernel
test=develop
上级 6b203635
...@@ -208,6 +208,8 @@ REGISTER_LITE_KERNEL(lstm, ...@@ -208,6 +208,8 @@ REGISTER_LITE_KERNEL(lstm,
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Weight", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Weight", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("C0", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("H0", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Cell", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Cell", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kARM))})
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "graph/compatible/all_ops.h" #include "graph/compatible/all_ops.h"
#include "graph/op/all_ops.h"
#include "lite/core/op_lite.h" #include "lite/core/op_lite.h"
#include "lite/core/tensor.h" #include "lite/core/tensor.h"
......
...@@ -127,12 +127,11 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) { ...@@ -127,12 +127,11 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} }
// LayerNorm node // LayerNorm node
auto layer_norm_node = graph->Add<ge::op::InstanceNorm>(y_name); auto layer_norm_node = graph->Add<hiai::op::LayerNorm>(y_name);
auto layer_norm_op = layer_norm_node->data<ge::op::InstanceNorm>(); auto layer_norm_op = layer_norm_node->data<hiai::op::LayerNorm>();
layer_norm_op->set_input_x(*x_node->data()); layer_norm_op->set_input_x(*x_node->data());
layer_norm_op->set_input_scale(*scale_node->data()); layer_norm_op->set_input_gamma(*scale_node->data());
layer_norm_op->set_input_bias(*bias_node->data()); layer_norm_op->set_input_beta(*bias_node->data());
layer_norm_op->set_attr_reduction_indices(ge::AttrValue::LIST_INT({3}));
layer_norm_op->set_attr_epsilon(epsilon); layer_norm_op->set_attr_epsilon(epsilon);
// Reshaped Y node if needs // Reshaped Y node if needs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册