Commit e43a96b5 authored by chenjianping

support resize when initializing kernel

Parent cea3ed8c
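The pattern repeated across the kernels in this diff: shape-dependent setup moves out of Init() into ReSize(), guarded by the new InferShapeDone() helper, so a kernel can be (re)sized after a deferred InferShape. Below is a minimal sketch of that pattern, not code from the commit; ExampleKernel and the include paths are assumptions, while LiteKernel, InferShapeDone(), ReSize(), and RET_OK are the names used in the changes that follow.

#include "src/lite_kernel.h"    // assumed include path for LiteKernel
#include "include/errorcode.h"  // assumed include path for RET_OK

namespace mindspore::kernel {
using mindspore::lite::RET_OK;

// Hypothetical kernel illustrating the Init()/ReSize() split introduced by this commit.
class ExampleKernel : public LiteKernel {
 public:
  using LiteKernel::LiteKernel;

  int Init() override {
    // Shape-independent setup (quant params, data type, ...) stays in Init().
    if (!InferShapeDone()) {
      // Shapes are not known yet; Prepare() will run InferShape and trigger Init() again.
      return RET_OK;
    }
    // Shapes are already known, so shape-dependent setup can run immediately.
    return ReSize();
  }

  int ReSize() override {
    // Recompute anything derived from inputs_.at(0)->shape() here.
    return RET_OK;
  }

  int Run() override { return RET_OK; }
};
}  // namespace mindspore::kernel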
@@ -62,6 +62,7 @@ class LiteKernel {
const lite::Primitive *primitive)
: opParameter(parameter), inputs_(inputs), outputs_(outputs), train_mode(false), primitive_(primitive),
context_(ctx) {
opParameter->thread_num_ = ctx->thread_num_;
this->in_kernel_.clear();
this->out_kernel_.clear();
}
@@ -69,12 +70,13 @@ class LiteKernel {
virtual ~LiteKernel() { delete opParameter; }
virtual int Prepare() {
if (primitive_ != nullptr && !primitive_->GetInferFlag()) {
if (!InferShapeDone()) {
(const_cast<lite::Primitive *>(primitive_))->InferShape(inputs_, outputs_);
if (need_reinit) {
Init();
}
}
if (need_reinit) {
Init();
}
auto &outputs = this->GetOutputs();
for (auto *output : outputs) {
MS_ASSERT(output != nullptr);
@@ -126,6 +128,13 @@ class LiteKernel {
}
protected:
bool InferShapeDone() {
if (primitive_ != nullptr && !primitive_->GetInferFlag()) {
return false;
}
return true;
}
KernelKey desc;
std::string name;
OpParameter *opParameter = nullptr;
......
@@ -32,10 +32,6 @@ using mindspore::schema::PrimitiveType_ArgMin;
namespace mindspore::kernel {
int ArgMinMaxBaseCPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
SetNeedReInit();
return RET_OK;
}
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
switch (opParameter->type_) {
case PrimitiveType_ArgMax:
@@ -49,8 +45,13 @@ int ArgMinMaxBaseCPUKernel::Init() {
return RET_ERROR;
}
return RET_OK;
}
int ArgMinMaxBaseCPUKernel::ReSize() {
auto in_shape = inputs_.at(0)->shape();
auto dims_size = in_shape.size();
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
int axis = param->axis_ < 0 ? param->axis_ + dims_size : param->axis_;
param->axis_ = axis;
param->dims_size_ = dims_size;
......
@@ -26,15 +26,13 @@ class ArgMinMaxBaseCPUKernel : public LiteKernel {
ArgMinMaxBaseCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const lite::Primitive *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), data_from_allocator_(false) {
opParameter->thread_num_ = ctx->thread_num_;
}
: LiteKernel(parameter, inputs, outputs, ctx, primitive), data_from_allocator_(false) {}
virtual ~ArgMinMaxBaseCPUKernel() { FreeTmpMemory(); }
int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override;
......
@@ -30,10 +30,6 @@ using mindspore::schema::PrimitiveType_BatchToSpace;
namespace mindspore::kernel {
int BatchToSpaceBaseCPUKernel::Init() {
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
return RET_FORMAT_ERR;
}
BatchToSpaceParameter *param = reinterpret_cast<BatchToSpaceParameter *>(this->opParameter);
for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) {
if (param->crops_[i] != 0) {
@@ -43,6 +39,14 @@ int BatchToSpaceBaseCPUKernel::Init() {
return RET_OK;
}
int BatchToSpaceBaseCPUKernel::ReSize() {
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "batch_to_space only support NHWC now!";
return RET_FORMAT_ERR;
}
return RET_OK;
}
kernel::LiteKernel *CpuBatchToSpaceInt8KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *op_parameter, const lite::Context *ctx,
......
@@ -35,7 +35,7 @@ class BatchToSpaceBaseCPUKernel : public LiteKernel {
int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override { return 0; }
......
@@ -31,11 +31,9 @@ using mindspore::lite::RET_PARAM_INVALID;
using mindspore::schema::PrimitiveType_DepthToSpace;
namespace mindspore::kernel {
int DepthToSpaceBaseCPUKernel::Init() {
if (context_->infer_shape_interrupt_ && !context_->running_) {
SetNeedReInit();
return RET_OK;
}
int DepthToSpaceBaseCPUKernel::Init() { return RET_OK; }
int DepthToSpaceBaseCPUKernel::ReSize() {
if (inputs_[0]->GetFormat() != schema::Format_NHWC) {
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
return RET_FORMAT_ERR;
......
@@ -35,7 +35,7 @@ class DepthToSpaceBaseCPUKernel : public LiteKernel {
int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override { return 0; }
};
......
@@ -36,7 +36,15 @@ int ArgMinMaxCPUKernel::Init() {
}
auto param = reinterpret_cast<ArgMinMaxParameter *>(opParameter);
param->data_type_ = kNumberTypeFloat32;
return RET_OK;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int ArgMinMaxCPUKernel::ReSize() {
ArgMinMaxBaseCPUKernel::FreeTmpMemory();
return ArgMinMaxBaseCPUKernel::ReSize();
}
int ArgMinMaxCPUKernel::Run() {
......
@@ -30,7 +30,7 @@ class ArgMinMaxCPUKernel : public ArgMinMaxBaseCPUKernel {
~ArgMinMaxCPUKernel() = default;
int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override;
};
} // namespace mindspore::kernel
......
@@ -24,7 +24,19 @@ using mindspore::lite::RET_OK;
namespace mindspore::kernel {
int BatchToSpaceCPUKernel::Init() {
return BatchToSpaceBaseCPUKernel::Init();
auto ret = BatchToSpaceBaseCPUKernel::Init();
if (ret != RET_OK) {
return ret;
}
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int BatchToSpaceCPUKernel::ReSize() {
return BatchToSpaceBaseCPUKernel::ReSize();
}
int BatchToSpaceCPUKernel::Run() {
......
@@ -29,7 +29,7 @@ class BatchToSpaceCPUKernel : public BatchToSpaceBaseCPUKernel {
~BatchToSpaceCPUKernel() = default;
int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override;
};
} // namespace mindspore::kernel
......
@@ -37,7 +37,15 @@ int DepthToSpaceCPUKernel::Init() {
}
DepthToSpaceParameter *param = reinterpret_cast<DepthToSpaceParameter *>(opParameter);
param->data_type_size_ = sizeof(float);
return RET_OK;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int DepthToSpaceCPUKernel::ReSize() {
return DepthToSpaceBaseCPUKernel::ReSize();
}
int DepthToSpaceCPUKernel::Run() {
......
@@ -29,7 +29,7 @@ class DepthToSpaceCPUKernel : public DepthToSpaceBaseCPUKernel {
~DepthToSpaceCPUKernel() = default;
int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override;
};
} // namespace mindspore::kernel
......
@@ -40,14 +40,21 @@ int ArgMinMaxInt8CPUKernel::Init() {
auto out_quant_args = out_tensor->GetQuantParams();
out_quant_arg_.scale_ = out_quant_args.front().scale;
out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
return RET_OK;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int ArgMinMaxInt8CPUKernel::ReSize() {
return ArgMinMaxBaseCPUKernel::ReSize();
}
int ArgMinMaxInt8CPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return ret;
}
auto input = inputs_.at(0);
......
@@ -31,7 +31,7 @@ class ArgMinMaxInt8CPUKernel : public ArgMinMaxBaseCPUKernel {
~ArgMinMaxInt8CPUKernel() = default;
int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override;
private:
QuantArg in_quant_arg_;
......
@@ -38,7 +38,14 @@ int BatchToSpaceInt8CPUKernel::Init() {
auto out_quant_args = out_tensor->GetQuantParams();
out_quant_arg_.scale_ = out_quant_args.front().scale;
out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
return RET_OK;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int BatchToSpaceInt8CPUKernel::ReSize() {
return BatchToSpaceBaseCPUKernel::ReSize();
}
int BatchToSpaceInt8CPUKernel::Run() {
......
@@ -30,7 +30,7 @@ class BatchToSpaceInt8CPUKernel : public BatchToSpaceBaseCPUKernel {
~BatchToSpaceInt8CPUKernel() = default;
int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override;
private:
QuantArg in_quant_arg_;
......
@@ -42,7 +42,14 @@ int DepthToSpaceInt8CPUKernel::Init() {
auto out_quant_args = out_tensor->GetQuantParams();
out_quant_arg_.scale_ = out_quant_args.front().scale;
out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
return RET_OK;
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int DepthToSpaceInt8CPUKernel::ReSize() {
return DepthToSpaceBaseCPUKernel::ReSize();
}
int DepthToSpaceInt8CPUKernel::Run() {
......
@@ -30,7 +30,7 @@ class DepthToSpaceInt8CPUKernel : public DepthToSpaceBaseCPUKernel {
~DepthToSpaceInt8CPUKernel() = default;
int Init() override;
int ReSize() override { return 0; }
int ReSize() override;
int Run() override;
private:
QuantArg in_quant_arg_;
......