diff --git a/lite/core/mir/memory_optimize_pass.cc b/lite/core/mir/memory_optimize_pass.cc
index efcc7cef992e8c26b746357cdddb90a92f072aa3..ee78fac9a88aa339514778dcc03e2c907487fb39 100644
--- a/lite/core/mir/memory_optimize_pass.cc
+++ b/lite/core/mir/memory_optimize_pass.cc
@@ -78,6 +78,7 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
   // Collect the invalid input and output variables that will not be reused.
   std::unordered_set<std::string> invalid_var_names;
   for (auto& op_node : graph->StmtTopologicalOrder()) {
+    // variables of invalid_op_nodes will not be reused
     if (!op_node->IsStmt()) continue;
     auto op_info = op_node->AsStmt().op_info();
     auto op_type = op_info->Type();
@@ -120,6 +121,13 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
     }
   }
 
+  // non-tensor(like tensor_array) variables will not be reused
+  for (auto& node : graph->nodes()) {
+    if (node.IsArg() && !node.arg()->type->IsTensor()) {
+      invalid_var_names.insert(node.arg()->name);
+    }
+  }
+
   for (auto& op_node : graph->StmtTopologicalOrder()) {
     if (op_node->IsStmt()) {
       std::vector<Node*> var_nodes(op_node->inlinks.begin(),
diff --git a/lite/kernels/arm/elementwise_compute.cc b/lite/kernels/arm/elementwise_compute.cc
index f824624cad3e97881a0d624e359b5e0ac7924c34..8115700f5950ddfcb71df49e6a21528563f23d95 100644
--- a/lite/kernels/arm/elementwise_compute.cc
+++ b/lite/kernels/arm/elementwise_compute.cc
@@ -182,32 +182,32 @@ void ElementwiseSubActivationCompute::Run() {
 template <typename T, PrecisionType PType>
 void ElementwiseMulCompute<T, PType>::Run() {
   auto& param = this->template Param<operators::ElementwiseParam>();
-  if (param.X->precision() == PRECISION(kFloat)) {
-    auto* x_data = param.X->template data<float>();
-    auto* y_data = param.Y->template data<float>();
-    auto* out_data = param.Out->template mutable_data<float>();
-    int axis = param.axis;
-    auto x_dims = param.X->dims();
-    auto y_dims = param.Y->dims();
-    int pre, n, post;
-    if (x_dims.size() < y_dims.size() &&
-        is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
-      lite::arm::math::elementwise_mul_broadcast<float>(
-          y_data, x_data, out_data, pre, n, post);
-    } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
-      lite::arm::math::elementwise_mul_broadcast<float>(
-          x_data, y_data, out_data, pre, n, post);
-    } else {
-      lite::arm::math::elementwise_mul<float>(
-          x_data, y_data, out_data, x_dims.production());
-    }
-  } else if (param.X->precision() == PRECISION(kInt64)) {
-    lite::arm::math::elementwise_compute_basic<int64_t>(param, "mul", "");
+  auto* x_data = param.X->template data<T>();
+  auto* y_data = param.Y->template data<T>();
+  auto* out_data = param.Out->template mutable_data<T>();
+  int axis = param.axis;
+  auto x_dims = param.X->dims();
+  auto y_dims = param.Y->dims();
+  int pre, n, post;
+  if (x_dims.size() < y_dims.size() &&
+      is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
+    lite::arm::math::elementwise_mul_broadcast<T>(
+        y_data, x_data, out_data, pre, n, post);
+  } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
+    lite::arm::math::elementwise_mul_broadcast<T>(
+        x_data, y_data, out_data, pre, n, post);
   } else {
-    LOG(FATAL) << "unsupport input type";
+    lite::arm::math::elementwise_mul<T>(
+        x_data, y_data, out_data, x_dims.production());
   }
 }
 
+template <>
+void ElementwiseMulCompute<int64_t, PRECISION(kInt64)>::Run() {
+  auto& param = this->template Param<operators::ElementwiseParam>();
+  lite::arm::math::elementwise_compute_basic<int64_t>(param, "mul", "");
+}
+
 void ElementwiseMulActivationCompute::Run() {
   auto& param = Param<operators::FusionElementwiseActivationParam>();
   const float* x_data = param.X->data<float>();
@@ -420,6 +420,16 @@ REGISTER_LITE_KERNEL(
     .BindOutput("Out",
                 {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
     .Finalize();
+using elementwise_mul_int64 =
+    paddle::lite::kernels::arm::ElementwiseMulCompute<int64_t, PRECISION(kInt64)>;
+REGISTER_LITE_KERNEL(
+    elementwise_mul, kARM, kInt64, kNCHW, elementwise_mul_int64, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
+    .Finalize();
+
 REGISTER_LITE_KERNEL(
     fusion_elementwise_mul_activation,
     kARM,
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index ee38fd6a1bdecec290942def401544ddb32f1ce5..25bbcc7687dd3000301f95c7ab365d2157f196dd 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -730,15 +730,15 @@ struct IncrementParam {
 };
 
 struct WriteToArrayParam {
-  const lite::Tensor* X{};
-  const lite::Tensor* I{};
-  std::vector<lite::Tensor>* Out{};
+  const lite::Tensor* X{nullptr};
+  const lite::Tensor* I{nullptr};
+  std::vector<lite::Tensor>* Out{nullptr};
 };
 
 struct ReadFromArrayParam {
-  const std::vector<lite::Tensor>* X{};
-  const lite::Tensor* I{};
-  lite::Tensor* Out{};
+  const std::vector<lite::Tensor>* X{nullptr};
+  const lite::Tensor* I{nullptr};
+  lite::Tensor* Out{nullptr};
 };
 
 struct BeamSearchParam {