diff --git a/mindspore/lite/include/context.h b/mindspore/lite/include/context.h index d709dba358b791ba2165dcfd7c245f47df33a221..f7e5472704889ee2d1fbf838199440edb366d194 100644 --- a/mindspore/lite/include/context.h +++ b/mindspore/lite/include/context.h @@ -64,16 +64,12 @@ class MS_API Context { /// \brief Destructor of MindSpore Lite Context. virtual ~Context(); - void InferShapeInterrupt() { infer_shape_interrupt_ = true; } - public: bool float16_priority = true; /**< allow priority select float16 kernel */ DeviceContext device_ctx_{DT_CPU}; int thread_num_ = 2; /**< thread number config for thread pool */ std::shared_ptr allocator = nullptr; CpuBindMode cpu_bind_mode_ = MID_CPU; - bool infer_shape_interrupt_ = false; - bool running_ = false; }; } // namespace mindspore::lite #endif // MINDSPORE_LITE_INCLUDE_CONTEXT_H_ diff --git a/mindspore/lite/src/lite_kernel.h b/mindspore/lite/src/lite_kernel.h index 6fc66081778d756ab45c00c7a1febe6b91b65e0a..6a4a42bc3d42b83e286fb334369a3084dee209e5 100644 --- a/mindspore/lite/src/lite_kernel.h +++ b/mindspore/lite/src/lite_kernel.h @@ -82,9 +82,7 @@ class LiteKernel { virtual int Prepare() { if (!InferShapeDone()) { (const_cast(primitive_))->InferShape(in_tensors_, out_tensors_); - if (need_reinit_) { - Init(); - } + ReSize(); } auto &outputs = this->out_tensors(); @@ -152,8 +150,6 @@ class LiteKernel { void set_desc(const KernelKey kernel_key) { desc_ = kernel_key; } - void set_need_reinit() { need_reinit_ = true; } - const mindspore::lite::PrimitiveC *GetPrimitive() const { return primitive_; } protected: @@ -170,7 +166,6 @@ class LiteKernel { std::vector in_kernels_; std::vector out_kernels_; bool train_mode_ = false; - bool need_reinit_ = false; bool is_model_output_ = false; }; diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index 80c3e0d92de81daf9d505c4f36968ffb01d4a47d..0ddef8a4c52e1cf2916e82f9c1e4111ba330ec9e 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -57,7 +57,7 @@ class PrimitiveC { void SetInferFlag(bool flag); - virtual int InferShape(std::vector inputs_, std::vector outputs_); + virtual int InferShape(std::vector inputs, std::vector outputs); int Type() const; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc index 5329feae5ae5dc88e4da8bb0fb054b319441b9ff..e354183e5b82068de5700412bae24ce169187efc 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.cc @@ -40,16 +40,17 @@ int PriorBoxCPUKernel::Init() { return RET_NULL_PTR; } - if (context_->infer_shape_interrupt_ && !context_->running_) { - set_need_reinit(); - return RET_OK; - } MS_ASSERT(in_tensors_.size() == kInputNum); MS_ASSERT(out_tensors_.size() == kOutputNum); - auto ret = GeneratePriorBox(); + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} - return ret; +int PriorBoxCPUKernel::ReSize() { + return GeneratePriorBox(); } int PriorBoxCPUKernel::GeneratePriorBox() { @@ -158,6 +159,11 @@ int RunPriorBox(int task_id, LiteParallelGroupEnv *penv, void *cdata) { } int PriorBoxCPUKernel::Run() { + auto prepare_ret = Prepare(); + if (prepare_ret != RET_OK) { + MS_LOG(ERROR) << "Prepare fail! Ret error code[" << prepare_ret << "]"; + return prepare_ret; + } int error_code = LiteBackendParallelLaunch(RunPriorBox, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "PriorBox run error, error_code[" << error_code << "]"; @@ -168,18 +174,18 @@ int PriorBoxCPUKernel::Run() { kernel::LiteKernel *CpuPriorBoxKernelCreator(const std::vector &inputs, const std::vector &outputs, - OpParameter *opParameter, const Context *ctx, + OpParameter *op_parameter, const Context *ctx, const kernel::KernelKey &desc, const mindspore::lite::PrimitiveC *primitive) { - if (opParameter == nullptr) { - MS_LOG(ERROR) << "Input opParameter is nullptr!"; + if (op_parameter == nullptr) { + MS_LOG(ERROR) << "Input op_parameter is nullptr!"; return nullptr; } if (desc.type != schema::PrimitiveType_PriorBox) { MS_LOG(ERROR) << "PriorBox invalid desc type " << desc.type; return nullptr; } - auto *kernel = new (std::nothrow) PriorBoxCPUKernel(opParameter, inputs, outputs, ctx, primitive); + auto *kernel = new (std::nothrow) PriorBoxCPUKernel(op_parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "new PriorBoxCPUKernel fail!"; return nullptr; @@ -187,8 +193,8 @@ kernel::LiteKernel *CpuPriorBoxKernelCreator(const std::vectorInit(); if (ret != RET_OK) { delete kernel; - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(op_parameter->type_)); return nullptr; } return kernel; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/prior_box.h b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.h index 54a04cc4ea4309340a323d1be29e2929f70024e1..ace70e6ae73923cbbbe1046019ac4e5dbf5eff79 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/prior_box.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/prior_box.h @@ -36,7 +36,7 @@ class PriorBoxCPUKernel : public LiteKernel { ~PriorBoxCPUKernel() = default; int Init() override; - int ReSize() override { return 0; } + int ReSize() override; int Run() override; int PriorBoxImpl(int task_id);