提交 e43a96b5 编写于 作者: C chenjianping

support resize when init kernel

上级 cea3ed8c
...@@ -62,6 +62,7 @@ class LiteKernel { ...@@ -62,6 +62,7 @@ class LiteKernel {
const lite::Primitive *primitive) const lite::Primitive *primitive)
: opParameter(parameter), inputs_(inputs), outputs_(outputs), train_mode(false), primitive_(primitive), : opParameter(parameter), inputs_(inputs), outputs_(outputs), train_mode(false), primitive_(primitive),
context_(ctx) { context_(ctx) {
opParameter->thread_num_ = ctx->thread_num_;
this->in_kernel_.clear(); this->in_kernel_.clear();
this->out_kernel_.clear(); this->out_kernel_.clear();
} }
...@@ -69,12 +70,13 @@ class LiteKernel { ...@@ -69,12 +70,13 @@ class LiteKernel {
virtual ~LiteKernel() { delete opParameter; } virtual ~LiteKernel() { delete opParameter; }
virtual int Prepare() { virtual int Prepare() {
if (primitive_ != nullptr && !primitive_->GetInferFlag()) { if (!InferShapeDone()) {
(const_cast<lite::Primitive *>(primitive_))->InferShape(inputs_, outputs_); (const_cast<lite::Primitive *>(primitive_))->InferShape(inputs_, outputs_);
}
if (need_reinit) { if (need_reinit) {
Init(); Init();
} }
}
auto &outputs = this->GetOutputs(); auto &outputs = this->GetOutputs();
for (auto *output : outputs) { for (auto *output : outputs) {
MS_ASSERT(output != nullptr); MS_ASSERT(output != nullptr);
...@@ -126,6 +128,13 @@ class LiteKernel { ...@@ -126,6 +128,13 @@ class LiteKernel {
} }
protected: protected:
bool InferShapeDone() {
if (primitive_ != nullptr && !primitive_->GetInferFlag()) {
return false;
}
return true;
}
KernelKey desc; KernelKey desc;
std::string name; std::string name;
OpParameter *opParameter = nullptr; OpParameter *opParameter = nullptr;
......
...@@ -32,10 +32,6 @@ using mindspore::schema::PrimitiveType_ArgMin; ...@@ -32,10 +32,6 @@ using mindspore::schema::PrimitiveType_ArgMin;
namespace mindspore::kernel { namespace mindspore::kernel {
int ArgMinMaxBaseCPUKernel::Init() { int ArgMinMaxBaseCPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
SetNeedReInit();
return RET_OK;
}
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter); auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
switch (opParameter->type_) { switch (opParameter->type_) {
case PrimitiveType_ArgMax: case PrimitiveType_ArgMax:
...@@ -49,8 +45,13 @@ int ArgMinMaxBaseCPUKernel::Init() { ...@@ -49,8 +45,13 @@ int ArgMinMaxBaseCPUKernel::Init() {
return RET_ERROR; return RET_ERROR;
} }
return RET_OK;
}
int ArgMinMaxBaseCPUKernel::ReSize() {
auto in_shape = inputs_.at(0)->shape(); auto in_shape = inputs_.at(0)->shape();
auto dims_size = in_shape.size(); auto dims_size = in_shape.size();
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
int axis = param->axis_ < 0 ? param->axis_ + dims_size : param->axis_; int axis = param->axis_ < 0 ? param->axis_ + dims_size : param->axis_;
param->axis_ = axis; param->axis_ = axis;
param->dims_size_ = dims_size; param->dims_size_ = dims_size;
......
...@@ -26,15 +26,13 @@ class ArgMinMaxBaseCPUKernel : public LiteKernel { ...@@ -26,15 +26,13 @@ class ArgMinMaxBaseCPUKernel : public LiteKernel {
ArgMinMaxBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, ArgMinMaxBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx, const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive) const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), data_from_allocator_(false) { : LiteKernel(parameter, inputs, outputs, ctx, primitive), data_from_allocator_(false) {}
opParameter->thread_num_ = ctx->thread_num_;
}
virtual ~ArgMinMaxBaseCPUKernel() { FreeTmpMemory(); } virtual ~ArgMinMaxBaseCPUKernel() { FreeTmpMemory(); }
int Init() override; int Init() override;
int ReSize() override { return 0; } int ReSize() override;
int Run() override; int Run() override;
......
...@@ -30,10 +30,6 @@ using mindspore::schema::PrimitiveType_BatchToSpace; ...@@ -30,10 +30,6 @@ using mindspore::schema::PrimitiveType_BatchToSpace;
namespace mindspore::kernel { namespace mindspore::kernel {
int BatchToSpaceBaseCPUKernel::Init() { int BatchToSpaceBaseCPUKernel::Init() {
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
return RET_FORMAT_ERR;
}
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->opParameter); BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->opParameter);
for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) { for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) {
if (param->crops_[i] != 0) { if (param->crops_[i] != 0) {
...@@ -43,6 +39,14 @@ int BatchToSpaceBaseCPUKernel::Init() { ...@@ -43,6 +39,14 @@ int BatchToSpaceBaseCPUKernel::Init() {
return RET_OK; return RET_OK;
} }
int BatchToSpaceBaseCPUKernel::ReSize() {
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
return RET_FORMAT_ERR;
}
return RET_OK;
}
kernel::LiteKernel *CpuBatchToSpaceInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs, kernel::LiteKernel *CpuBatchToSpaceInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *op_parameter, const lite::Context *ctx, OpParameter *op_parameter, const lite::Context *ctx,
......
...@@ -35,7 +35,7 @@ class BatchToSpaceBaseCPUKernel : public LiteKernel { ...@@ -35,7 +35,7 @@ class BatchToSpaceBaseCPUKernel : public LiteKernel {
int Init() override; int Init() override;
int ReSize() override { return 0; } int ReSize() override;
int Run() override { return 0; } int Run() override { return 0; }
......
...@@ -31,11 +31,9 @@ using mindspore::lite::RET_PARAM_INVALID; ...@@ -31,11 +31,9 @@ using mindspore::lite::RET_PARAM_INVALID;
using mindspore::schema::PrimitiveType_DepthToSpace; using mindspore::schema::PrimitiveType_DepthToSpace;
namespace mindspore::kernel { namespace mindspore::kernel {
int DepthToSpaceBaseCPUKernel::Init() { int DepthToSpaceBaseCPUKernel::Init() { return RET_OK; }
if (context_->infer_shape_interrupt_ && !context_->running_) {
SetNeedReInit(); int DepthToSpaceBaseCPUKernel::ReSize() {
return RET_OK;
}
if (inputs_[0]->GetFormat() != schema::Format_NHWC) { if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "depth_to_space only support NHWC now!"; MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
return RET_FORMAT_ERR; return RET_FORMAT_ERR;
......
...@@ -35,7 +35,7 @@ class DepthToSpaceBaseCPUKernel : public LiteKernel { ...@@ -35,7 +35,7 @@ class DepthToSpaceBaseCPUKernel : public LiteKernel {
int Init() override; int Init() override;
int ReSize() override { return 0; } int ReSize() override;
int Run() override { return 0; } int Run() override { return 0; }
}; };
......
...@@ -36,7 +36,15 @@ int ArgMinMaxCPUKernel::Init() { ...@@ -36,7 +36,15 @@ int ArgMinMaxCPUKernel::Init() {
} }
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter); auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
param->data_type_ = kNumberTypeFloat32; param->data_type_ = kNumberTypeFloat32;
if (!InferShapeDone()) {
return RET_OK; return RET_OK;
}
return ReSize();
}
int ArgMinMaxCPUKernel::ReSize() {
ArgMinMaxBaseCPUKernel::FreeTmpMemory();
return ArgMinMaxBaseCPUKernel::ReSize();
} }
int ArgMinMaxCPUKernel::Run() { int ArgMinMaxCPUKernel::Run() {
......
...@@ -30,7 +30,7 @@ class ArgMinMaxCPUKernel : public ArgMinMaxBaseCPUKernel { ...@@ -30,7 +30,7 @@ class ArgMinMaxCPUKernel : public ArgMinMaxBaseCPUKernel {
~ArgMinMaxCPUKernel() = default; ~ArgMinMaxCPUKernel() = default;
int Init() override; int Init() override;
int ReSize() override { return 0; } int ReSize() override;
int Run() override; int Run() override;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
......
...@@ -24,7 +24,19 @@ using mindspore::lite::RET_OK; ...@@ -24,7 +24,19 @@ using mindspore::lite::RET_OK;
namespace mindspore::kernel { namespace mindspore::kernel {
int BatchToSpaceCPUKernel::Init() { int BatchToSpaceCPUKernel::Init() {
return BatchToSpaceBaseCPUKernel::Init(); auto ret = BatchToSpaceBaseCPUKernel::Init();
if (ret != RET_OK) {
return ret;
}
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int BatchToSpaceCPUKernel::ReSize() {
return BatchToSpaceBaseCPUKernel::ReSize();
} }
int BatchToSpaceCPUKernel::Run() { int BatchToSpaceCPUKernel::Run() {
......
...@@ -29,7 +29,7 @@ class BatchToSpaceCPUKernel : public BatchToSpaceBaseCPUKernel { ...@@ -29,7 +29,7 @@ class BatchToSpaceCPUKernel : public BatchToSpaceBaseCPUKernel {
~BatchToSpaceCPUKernel() = default; ~BatchToSpaceCPUKernel() = default;
int Init() override; int Init() override;
int ReSize() override { return 0; } int ReSize() override;
int Run() override; int Run() override;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
......
...@@ -37,7 +37,15 @@ int DepthToSpaceCPUKernel::Init() { ...@@ -37,7 +37,15 @@ int DepthToSpaceCPUKernel::Init() {
} }
DepthToSpaceParameter *param = reinterpret_cast<DepthToSpaceParameter *>(opParameter); DepthToSpaceParameter *param = reinterpret_cast<DepthToSpaceParameter *>(opParameter);
param->data_type_size_ = sizeof(float); param->data_type_size_ = sizeof(float);
if (!InferShapeDone()) {
return RET_OK; return RET_OK;
}
return ReSize();
}
int DepthToSpaceCPUKernel::ReSize() {
return DepthToSpaceBaseCPUKernel::ReSize();
} }
int DepthToSpaceCPUKernel::Run() { int DepthToSpaceCPUKernel::Run() {
......
...@@ -29,7 +29,7 @@ class DepthToSpaceCPUKernel : public DepthToSpaceBaseCPUKernel { ...@@ -29,7 +29,7 @@ class DepthToSpaceCPUKernel : public DepthToSpaceBaseCPUKernel {
~DepthToSpaceCPUKernel() = default; ~DepthToSpaceCPUKernel() = default;
int Init() override; int Init() override;
int ReSize() override { return 0; } int ReSize() override;
int Run() override; int Run() override;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
......
...@@ -40,14 +40,21 @@ int ArgMinMaxInt8CPUKernel::Init() { ...@@ -40,14 +40,21 @@ int ArgMinMaxInt8CPUKernel::Init() {
auto out_quant_args = out_tensor->GetQuantParams(); auto out_quant_args = out_tensor->GetQuantParams();
out_quant_arg_.scale_ = out_quant_args.front().scale; out_quant_arg_.scale_ = out_quant_args.front().scale;
out_quant_arg_.zp_ = out_quant_args.front().zeroPoint; out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
if (!InferShapeDone()) {
return RET_OK; return RET_OK;
}
return ReSize();
}
int ArgMinMaxInt8CPUKernel::ReSize() {
return ArgMinMaxBaseCPUKernel::ReSize();
} }
int ArgMinMaxInt8CPUKernel::Run() { int ArgMinMaxInt8CPUKernel::Run() {
auto ret = Prepare(); auto ret = Prepare();
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed."; MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return RET_ERROR; return ret;
} }
auto input = inputs_.at(0); auto input = inputs_.at(0);
......
...@@ -31,7 +31,7 @@ class ArgMinMaxInt8CPUKernel : public ArgMinMaxBaseCPUKernel { ...@@ -31,7 +31,7 @@ class ArgMinMaxInt8CPUKernel : public ArgMinMaxBaseCPUKernel {
~ArgMinMaxInt8CPUKernel() = default; ~ArgMinMaxInt8CPUKernel() = default;
int Init() override; int Init() override;
int ReSize() override { return 0; } int ReSize() override;
int Run() override; int Run() override;
private: private:
QuantArg in_quant_arg_; QuantArg in_quant_arg_;
......
...@@ -38,7 +38,14 @@ int BatchToSpaceInt8CPUKernel::Init() { ...@@ -38,7 +38,14 @@ int BatchToSpaceInt8CPUKernel::Init() {
auto out_quant_args = out_tensor->GetQuantParams(); auto out_quant_args = out_tensor->GetQuantParams();
out_quant_arg_.scale_ = out_quant_args.front().scale; out_quant_arg_.scale_ = out_quant_args.front().scale;
out_quant_arg_.zp_ = out_quant_args.front().zeroPoint; out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
if (!InferShapeDone()) {
return RET_OK; return RET_OK;
}
return ReSize();
}
int BatchToSpaceInt8CPUKernel::ReSize() {
return BatchToSpaceBaseCPUKernel::ReSize();
} }
int BatchToSpaceInt8CPUKernel::Run() { int BatchToSpaceInt8CPUKernel::Run() {
......
...@@ -30,7 +30,7 @@ class BatchToSpaceInt8CPUKernel : public BatchToSpaceBaseCPUKernel { ...@@ -30,7 +30,7 @@ class BatchToSpaceInt8CPUKernel : public BatchToSpaceBaseCPUKernel {
~BatchToSpaceInt8CPUKernel() = default; ~BatchToSpaceInt8CPUKernel() = default;
int Init() override; int Init() override;
int ReSize() override { return 0; } int ReSize() override;
int Run() override; int Run() override;
private: private:
QuantArg in_quant_arg_; QuantArg in_quant_arg_;
......
...@@ -42,7 +42,14 @@ int DepthToSpaceInt8CPUKernel::Init() { ...@@ -42,7 +42,14 @@ int DepthToSpaceInt8CPUKernel::Init() {
auto out_quant_args = out_tensor->GetQuantParams(); auto out_quant_args = out_tensor->GetQuantParams();
out_quant_arg_.scale_ = out_quant_args.front().scale; out_quant_arg_.scale_ = out_quant_args.front().scale;
out_quant_arg_.zp_ = out_quant_args.front().zeroPoint; out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
if (!InferShapeDone()) {
return RET_OK; return RET_OK;
}
return ReSize();
}
int DepthToSpaceInt8CPUKernel::ReSize() {
return DepthToSpaceBaseCPUKernel::ReSize();
} }
int DepthToSpaceInt8CPUKernel::Run() { int DepthToSpaceInt8CPUKernel::Run() {
......
...@@ -30,7 +30,7 @@ class DepthToSpaceInt8CPUKernel : public DepthToSpaceBaseCPUKernel { ...@@ -30,7 +30,7 @@ class DepthToSpaceInt8CPUKernel : public DepthToSpaceBaseCPUKernel {
~DepthToSpaceInt8CPUKernel() = default; ~DepthToSpaceInt8CPUKernel() = default;
int Init() override; int Init() override;
int ReSize() override { return 0; } int ReSize() override;
int Run() override; int Run() override;
private: private:
QuantArg in_quant_arg_; QuantArg in_quant_arg_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册