diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 80aacbf7efe2f13a6cb2b04201e036561e682bf1..87eab5cb7ea462c792ad351bf72a934668318195 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -46,7 +46,7 @@ add_kernel(reduce_max_compute_arm ARM basic SRCS reduce_max_compute.cc DEPS ${li add_kernel(sequence_expand_compute_arm ARM basic SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(im2sequence_compute_arm ARM basic SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(sequence_pool_compute_arm ARM basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm) -add_kernel(layer_norm_compute_arm ARM extra SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) +#add_kernel(layer_norm_compute_arm ARM extra SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(gather_compute_arm ARM extra SRCS gather_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(reduce_mean_compute_arm ARM extra SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(stack_compute_arm ARM extra SRCS stack_compute.cc DEPS ${lite_kernel_deps} math_arm) @@ -100,5 +100,5 @@ lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_ lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm COMPILE_LEVEL extra) lite_cc_test(test_argmax_compute_arm SRCS argmax_compute_test.cc DEPS argmax_compute_arm) lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm) -lite_cc_test(test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm) +#lite_cc_test(test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm) lite_cc_test(test_conv_transpose_compute_arm SRCS conv_transpose_compute_test.cc DEPS conv_transpose_compute_arm) diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 2381f3f9265291bf7a436877c60997bd0d1f3498..981dab38318314400576eb83cd35132497ece2b4 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -317,7 +317,7 @@ struct SplitParam { lite::Tensor* x{}; std::vector output{}; lite::Tensor* axis_tensor; - std::vector* sections_tensor_list{}; + std::vector sections_tensor_list{}; int axis{-1}; int num{0}; @@ -380,7 +380,6 @@ struct MeanGradParam { struct FillConstantParam { int dtype{static_cast(VarDescAPI::VarDataType::FP32)}; std::vector shape{}; - float value{0.0f}; // useless for x86, keep it for compatibility bool force_cpu{false}; diff --git a/lite/operators/split_op.cc b/lite/operators/split_op.cc index 36ab7dff9a6026ad5b6423b1395f3d502c9bb4d6..ec98a0d6c3ba3b1e5cd1c7992b58e96917d21057 100644 --- a/lite/operators/split_op.cc +++ b/lite/operators/split_op.cc @@ -39,15 +39,13 @@ bool SplitOp::InferShape() const { const int outs_number = outs.size(); std::vector outs_dims; outs_dims.reserve(outs_number); - std::vector *sections_tensor_list_ = + std::vector sections_tensor_list_ = param_.sections_tensor_list; - if (sections.size() > 0 && sections_tensor_list_->size() > 0) { + if (sections.size() > 0 && sections_tensor_list_.size() > 0) { std::vector vec_sections; - for (size_t i = 0; i < sections_tensor_list_->size(); ++i) { + for (size_t i = 0; i < sections_tensor_list_.size(); ++i) { auto dim = in_dims; - // lite::TensorLite aa = sections_tensor_list_[i]; - dim[axis] = (*sections_tensor_list_)[i].data()[0]; - // final_axes.push_back(axes_tensor_vct[i].data()[0]); + dim[axis] = sections_tensor_list_[i]->data()[0]; outs_dims.push_back(dim); } } else if (num > 0) { @@ -87,15 +85,20 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { for (auto var : outs) { param_.output.push_back(scope->FindVar(var)->GetMutable()); } - if (opdesc.HasAttr("AxisTensor")) { + std::vector input_arg_names = opdesc.InputArgumentNames(); + if (std::find(input_arg_names.begin(), input_arg_names.end(), "AxisTensor") != + input_arg_names.end()) { auto args = opdesc.Input("AxisTensor"); auto *var = scope->FindVar(args.front()); param_.axis_tensor = var->GetMutable(); } - if (opdesc.HasAttr("SectionsTensorList")) { + if (std::find(input_arg_names.begin(), + input_arg_names.end(), + "SectionsTensorList") != input_arg_names.end()) { auto args = opdesc.Input("SectionsTensorList"); auto *var = scope->FindVar(args.front()); - param_.sections_tensor_list = var->GetMutable>(); + param_.sections_tensor_list = + *(var->GetMutable>()); } return true; }