提交 98565d8b 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4768 [MS][LITE][Develop]kernels support resize

Merge pull request !4768 from chenjianping/lite_dev2
......@@ -64,16 +64,12 @@ class MS_API Context {
/// \brief Destructor of MindSpore Lite Context.
virtual ~Context();
void InferShapeInterrupt() { infer_shape_interrupt_ = true; }
public:
bool float16_priority = true; /**< allow priority select float16 kernel */
DeviceContext device_ctx_{DT_CPU};
int thread_num_ = 2; /**< thread number config for thread pool */
std::shared_ptr<Allocator> allocator = nullptr;
CpuBindMode cpu_bind_mode_ = MID_CPU;
bool infer_shape_interrupt_ = false;
bool running_ = false;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_INCLUDE_CONTEXT_H_
......@@ -82,9 +82,7 @@ class LiteKernel {
virtual int Prepare() {
if (!InferShapeDone()) {
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
if (need_reinit_) {
Init();
}
ReSize();
}
auto &outputs = this->out_tensors();
......@@ -152,8 +150,6 @@ class LiteKernel {
void set_desc(const KernelKey kernel_key) { desc_ = kernel_key; }
void set_need_reinit() { need_reinit_ = true; }
const mindspore::lite::PrimitiveC *GetPrimitive() const { return primitive_; }
protected:
......@@ -170,7 +166,6 @@ class LiteKernel {
std::vector<LiteKernel *> in_kernels_;
std::vector<LiteKernel *> out_kernels_;
bool train_mode_ = false;
bool need_reinit_ = false;
bool is_model_output_ = false;
};
......
......@@ -57,7 +57,7 @@ class PrimitiveC {
void SetInferFlag(bool flag);
virtual int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_);
virtual int InferShape(std::vector<lite::tensor::Tensor *> inputs, std::vector<lite::tensor::Tensor *> outputs);
int Type() const;
......
......@@ -40,16 +40,17 @@ int PriorBoxCPUKernel::Init() {
return RET_NULL_PTR;
}
if (context_->infer_shape_interrupt_ && !context_->running_) {
set_need_reinit();
return RET_OK;
}
MS_ASSERT(in_tensors_.size() == kInputNum);
MS_ASSERT(out_tensors_.size() == kOutputNum);
auto ret = GeneratePriorBox();
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
return ret;
int PriorBoxCPUKernel::ReSize() {
return GeneratePriorBox();
}
int PriorBoxCPUKernel::GeneratePriorBox() {
......@@ -158,6 +159,11 @@ int RunPriorBox(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
}
int PriorBoxCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail! Ret error code[" << prepare_ret << "]";
return prepare_ret;
}
int error_code = LiteBackendParallelLaunch(RunPriorBox, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "PriorBox run error, error_code[" << error_code << "]";
......@@ -168,18 +174,18 @@ int PriorBoxCPUKernel::Run() {
kernel::LiteKernel *CpuPriorBoxKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const Context *ctx,
OpParameter *op_parameter, const Context *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
if (op_parameter == nullptr) {
MS_LOG(ERROR) << "Input op_parameter is nullptr!";
return nullptr;
}
if (desc.type != schema::PrimitiveType_PriorBox) {
MS_LOG(ERROR) << "PriorBox invalid desc type " << desc.type;
return nullptr;
}
auto *kernel = new (std::nothrow) PriorBoxCPUKernel(opParameter, inputs, outputs, ctx, primitive);
auto *kernel = new (std::nothrow) PriorBoxCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new PriorBoxCPUKernel fail!";
return nullptr;
......@@ -187,8 +193,8 @@ kernel::LiteKernel *CpuPriorBoxKernelCreator(const std::vector<lite::tensor::Ten
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
return nullptr;
}
return kernel;
......
......@@ -36,7 +36,7 @@ class PriorBoxCPUKernel : public LiteKernel {
~PriorBoxCPUKernel() = default;
int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override;
int PriorBoxImpl(int task_id);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册