From 4c191812105f46aa1d2ff4cdfa422f3548d0b72e Mon Sep 17 00:00:00 2001 From: xiebaiyuan Date: Fri, 7 Sep 2018 15:04:19 +0800 Subject: [PATCH] shape op --- src/io/executor.cpp | 4 ++++ src/operators/kernel/central-arm-func/shape_arm_func.h | 10 +++++++++- src/operators/shape_op.cpp | 4 ++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/io/executor.cpp b/src/io/executor.cpp index 4b92ce0829..47936921a6 100644 --- a/src/io/executor.cpp +++ b/src/io/executor.cpp @@ -176,6 +176,7 @@ void Executor::LoadMemory(const framework::VarDesc var_desc, type_size = 8; break; case framework::VARTYPE_TYPE_INT32: + memory = tensor->mutable_data(); type_size = 4; break; case framework::VARTYPE_TYPE_INT64: @@ -308,6 +309,9 @@ bool Executor::varInputMemory( } case framework::VARTYPE_TYPE_INT32: { + tensor = var->template GetMutable(); + tensor->template mutable_data(); + is_mute_match = true; break; } diff --git a/src/operators/kernel/central-arm-func/shape_arm_func.h b/src/operators/kernel/central-arm-func/shape_arm_func.h index 846f4c474a..895877efd2 100644 --- a/src/operators/kernel/central-arm-func/shape_arm_func.h +++ b/src/operators/kernel/central-arm-func/shape_arm_func.h @@ -22,7 +22,15 @@ namespace paddle_mobile { namespace operators { template -void ShapeCompute(const ShapeParam& param) {} +void ShapeCompute(const ShapeParam& param) { + auto* in_t = param.InputX(); + auto* out_t = param.Out(); + auto out_data = out_t->mutable_data(); + auto in_dims = in_t->dims(); + for (int i = 0; i < in_dims.size(); ++i) { + out_data[i] = static_cast(in_dims[i]); + } +} } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/shape_op.cpp b/src/operators/shape_op.cpp index ed91fd5d66..55fbc80f57 100644 --- a/src/operators/shape_op.cpp +++ b/src/operators/shape_op.cpp @@ -20,6 +20,10 @@ namespace paddle_mobile { namespace operators { template void ShapeOp::InferShape() const { + PADDLE_MOBILE_ENFORCE(this->param_.InputX() != nullptr, + "Input (Input) of get_shape op should not be null."); + PADDLE_MOBILE_ENFORCE(this->param_.Out() != nullptr, + "Output (Out) of get_shape op should not be null."); this->param_.Out()->Resize(this->param_.InputX()->dims()); } -- GitLab