diff --git a/src/io/executor.cpp b/src/io/executor.cpp index 4b92ce0829ce23426406b88783f6e232e33445b0..47936921a6984c61cc02c222461346081b5bccdf 100644 --- a/src/io/executor.cpp +++ b/src/io/executor.cpp @@ -176,6 +176,7 @@ void Executor::LoadMemory(const framework::VarDesc var_desc, type_size = 8; break; case framework::VARTYPE_TYPE_INT32: + memory = tensor->mutable_data(); type_size = 4; break; case framework::VARTYPE_TYPE_INT64: @@ -308,6 +309,9 @@ bool Executor::varInputMemory( } case framework::VARTYPE_TYPE_INT32: { + tensor = var->template GetMutable(); + tensor->template mutable_data(); + is_mute_match = true; break; } diff --git a/src/operators/kernel/central-arm-func/shape_arm_func.h b/src/operators/kernel/central-arm-func/shape_arm_func.h index 846f4c474ab067f02cfd809a876abc7b62ab784f..895877efd27cfe3e00ad62ebd3e0584eacc47ed7 100644 --- a/src/operators/kernel/central-arm-func/shape_arm_func.h +++ b/src/operators/kernel/central-arm-func/shape_arm_func.h @@ -22,7 +22,15 @@ namespace paddle_mobile { namespace operators { template -void ShapeCompute(const ShapeParam& param) {} +void ShapeCompute(const ShapeParam& param) { + auto* in_t = param.InputX(); + auto* out_t = param.Out(); + auto out_data = out_t->mutable_data(); + auto in_dims = in_t->dims(); + for (int i = 0; i < in_dims.size(); ++i) { + out_data[i] = static_cast(in_dims[i]); + } +} } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/shape_op.cpp b/src/operators/shape_op.cpp index ed91fd5d663ebb734051e3636aa5f95e3623e4d9..55fbc80f5795e303605f645d8caaa6edc577c25c 100644 --- a/src/operators/shape_op.cpp +++ b/src/operators/shape_op.cpp @@ -20,6 +20,10 @@ namespace paddle_mobile { namespace operators { template void ShapeOp::InferShape() const { + PADDLE_MOBILE_ENFORCE(this->param_.InputX() != nullptr, + "Input (Input) of get_shape op should not be null."); + PADDLE_MOBILE_ENFORCE(this->param_.Out() != nullptr, + "Output (Out) of get_shape op should not be null."); this->param_.Out()->Resize(this->param_.InputX()->dims()); }