From 9bd739fb48f89d5b1f0db4d525d5c0b23989b47c Mon Sep 17 00:00:00 2001
From: hong19860320 <9973393+hong19860320@users.noreply.github.com>
Date: Thu, 30 Jul 2020 09:00:11 +0800
Subject: [PATCH] [ARM] [NPU] Fix the overflow of layer_norm op bridge, fix the registration of lstm op kernel (#4007) (#4017)

* [NPU] Fix the overflow of layer_norm op bridge test=develop

* [ARM] Fix the registration of lstm op kernel test=develop
---
 lite/kernels/arm/lstm_compute.cc          | 2 ++
 lite/kernels/npu/bridges/graph.h          | 1 +
 lite/kernels/npu/bridges/layer_norm_op.cc | 9 ++++-----
 3 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/lite/kernels/arm/lstm_compute.cc b/lite/kernels/arm/lstm_compute.cc
index 5335e230a0..5cc0a995da 100644
--- a/lite/kernels/arm/lstm_compute.cc
+++ b/lite/kernels/arm/lstm_compute.cc
@@ -208,6 +208,8 @@ REGISTER_LITE_KERNEL(lstm,
     .BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Weight", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("C0", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("H0", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Cell", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kARM))})
diff --git a/lite/kernels/npu/bridges/graph.h b/lite/kernels/npu/bridges/graph.h
index 1bc588496a..18a0093ffd 100644
--- a/lite/kernels/npu/bridges/graph.h
+++ b/lite/kernels/npu/bridges/graph.h
@@ -20,6 +20,7 @@
 #include
 #include
 #include "graph/compatible/all_ops.h"
+#include "graph/op/all_ops.h"
 #include "lite/core/op_lite.h"
 #include "lite/core/tensor.h"
diff --git a/lite/kernels/npu/bridges/layer_norm_op.cc b/lite/kernels/npu/bridges/layer_norm_op.cc
index 8c12724a14..235e97dc72 100644
--- a/lite/kernels/npu/bridges/layer_norm_op.cc
+++ b/lite/kernels/npu/bridges/layer_norm_op.cc
@@ -127,12 +127,11 @@ int LayerNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   }
 
   // LayerNorm node
-  auto layer_norm_node = graph->Add(y_name);
-  auto layer_norm_op = layer_norm_node->data();
+  auto layer_norm_node = graph->Add(y_name);
+  auto layer_norm_op = layer_norm_node->data();
   layer_norm_op->set_input_x(*x_node->data());
-  layer_norm_op->set_input_scale(*scale_node->data());
-  layer_norm_op->set_input_bias(*bias_node->data());
-  layer_norm_op->set_attr_reduction_indices(ge::AttrValue::LIST_INT({3}));
+  layer_norm_op->set_input_gamma(*scale_node->data());
+  layer_norm_op->set_input_beta(*bias_node->data());
   layer_norm_op->set_attr_epsilon(epsilon);
 
   // Reshaped Y node if needs
--
GitLab
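
For context on the layer_norm change: the bridge now feeds the op's Scale input to the NPU node's gamma input and the Bias input to beta, and forwards epsilon, instead of setting reduction indices on the old node type. Below is a minimal standalone C++ sketch of the computation such a layer-norm node is expected to perform over the innermost axis; it is illustrative only, uses no Paddle-Lite or HiAI API, and all names in it are made up for this sketch.

// Illustrative only: layer normalization over the innermost axis,
// matching the roles of gamma (scale), beta (bias) and epsilon above.
#include <cmath>
#include <cstddef>
#include <vector>

// x is laid out as [outer, inner]; gamma and beta have length `inner`.
std::vector<float> LayerNormLastAxis(const std::vector<float>& x,
                                     const std::vector<float>& gamma,
                                     const std::vector<float>& beta,
                                     size_t inner,
                                     float epsilon) {
  std::vector<float> y(x.size());
  const size_t outer = x.size() / inner;
  for (size_t i = 0; i < outer; ++i) {
    const float* xi = x.data() + i * inner;
    float* yi = y.data() + i * inner;
    // Mean and variance over the normalized (innermost) axis.
    float mean = 0.f;
    for (size_t j = 0; j < inner; ++j) mean += xi[j];
    mean /= inner;
    float var = 0.f;
    for (size_t j = 0; j < inner; ++j) var += (xi[j] - mean) * (xi[j] - mean);
    var /= inner;
    const float rstd = 1.f / std::sqrt(var + epsilon);
    // y = (x - mean) / sqrt(var + epsilon) * gamma + beta
    for (size_t j = 0; j < inner; ++j) {
      yi[j] = (xi[j] - mean) * rstd * gamma[j] + beta[j];
    }
  }
  return y;
}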