提交 ecb87385 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3847 modify populate op parameter

Merge pull request !3847 from yangruoqi713/lite
...@@ -69,320 +69,339 @@ ...@@ -69,320 +69,339 @@
#include "src/runtime/kernel/arm/opclib/fp32/quantize.h" #include "src/runtime/kernel/arm/opclib/fp32/quantize.h"
namespace mindspore::kernel { namespace mindspore::kernel {
FillParameter *PopulateFillParam(const lite::Primitive *primitive) { OpParameter *PopulateFillParameter(const lite::Primitive *primitive) {
auto param = primitive->Value()->value_as_Fill(); auto param = primitive->Value()->value_as_Fill();
FillParameter *parameter = new (std::nothrow) FillParameter(); FillParameter *fill_param = new (std::nothrow) FillParameter();
if (parameter == nullptr) { if (fill_param == nullptr) {
MS_LOG(ERROR) << "new FillParameter failed."; MS_LOG(ERROR) << "new FillParameter failed.";
return nullptr; return nullptr;
} }
fill_param->op_parameter_.type_ = primitive->Type();
auto flatDims = param->dims(); auto flatDims = param->dims();
parameter->num_dims_ = flatDims->size(); fill_param->num_dims_ = flatDims->size();
int i = 0; int i = 0;
for (auto iter = flatDims->begin(); iter != flatDims->end(); iter++) { for (auto iter = flatDims->begin(); iter != flatDims->end(); iter++) {
parameter->dims_[i++] = *iter; fill_param->dims_[i++] = *iter;
} }
return parameter; return reinterpret_cast<OpParameter *>(fill_param);
} }
ExpandDimsParameter *PopulateExpandDimsParam(const lite::Primitive *primitive) { OpParameter *PopulateExpandDimsParameter(const lite::Primitive *primitive) {
auto param = primitive->Value()->value_as_ExpandDims(); auto param = primitive->Value()->value_as_ExpandDims();
ExpandDimsParameter *parameter = new (std::nothrow) ExpandDimsParameter(); ExpandDimsParameter *expand_dims_param = new (std::nothrow) ExpandDimsParameter();
if (parameter == nullptr) { if (expand_dims_param == nullptr) {
MS_LOG(ERROR) << "new ExpandDimsParameter failed."; MS_LOG(ERROR) << "new ExpandDimsParameter failed.";
return nullptr; return nullptr;
} }
parameter->dim_ = param->dim(); expand_dims_param->op_parameter_.type_ = primitive->Type();
return parameter; expand_dims_param->dim_ = param->dim();
return reinterpret_cast<OpParameter *>(expand_dims_param);
} }
PoolingParameter *PopulatePoolingParam(const lite::Primitive *primitive) { OpParameter *PopulatePoolingParameter(const lite::Primitive *primitive) {
auto pooling_primitive = primitive->Value()->value_as_Pooling(); auto pooling_primitive = primitive->Value()->value_as_Pooling();
// todo use malloc instead // todo use malloc instead
PoolingParameter *parameter = new (std::nothrow) PoolingParameter(); PoolingParameter *pooling_param = new (std::nothrow) PoolingParameter();
if (parameter == nullptr) { if (pooling_param == nullptr) {
MS_LOG(ERROR) << "new PoolingParameter failed."; MS_LOG(ERROR) << "new PoolingParameter failed.";
return nullptr; return nullptr;
} }
parameter->global_ = pooling_primitive->global(); pooling_param->op_parameter_.type_ = primitive->Type();
parameter->window_w_ = pooling_primitive->windowW(); pooling_param->global_ = pooling_primitive->global();
parameter->window_h_ = pooling_primitive->windowH(); pooling_param->window_w_ = pooling_primitive->windowW();
pooling_param->window_h_ = pooling_primitive->windowH();
// todo format // todo format
auto pooling_lite_primitive = (lite::Pooling *)primitive; auto pooling_lite_primitive = (lite::Pooling *)primitive;
MS_ASSERT(nullptr != pooling_lite_primitive); MS_ASSERT(nullptr != pooling_lite_primitive);
parameter->pad_u_ = pooling_lite_primitive->PadUp(); pooling_param->pad_u_ = pooling_lite_primitive->PadUp();
parameter->pad_d_ = pooling_lite_primitive->PadDown(); pooling_param->pad_d_ = pooling_lite_primitive->PadDown();
parameter->pad_l_ = pooling_lite_primitive->PadLeft(); pooling_param->pad_l_ = pooling_lite_primitive->PadLeft();
parameter->pad_r_ = pooling_lite_primitive->PadRight(); pooling_param->pad_r_ = pooling_lite_primitive->PadRight();
parameter->stride_w_ = pooling_primitive->strideW(); pooling_param->stride_w_ = pooling_primitive->strideW();
parameter->stride_h_ = pooling_primitive->strideH(); pooling_param->stride_h_ = pooling_primitive->strideH();
auto pool_mode = pooling_primitive->poolingMode(); auto pool_mode = pooling_primitive->poolingMode();
switch (pool_mode) { switch (pool_mode) {
case schema::PoolMode_MAX_POOLING: case schema::PoolMode_MAX_POOLING:
parameter->max_pooling_ = true; pooling_param->max_pooling_ = true;
parameter->avg_pooling_ = false; pooling_param->avg_pooling_ = false;
break; break;
case schema::PoolMode_MEAN_POOLING: case schema::PoolMode_MEAN_POOLING:
parameter->max_pooling_ = false; pooling_param->max_pooling_ = false;
parameter->avg_pooling_ = true; pooling_param->avg_pooling_ = true;
break; break;
default: default:
parameter->max_pooling_ = false; pooling_param->max_pooling_ = false;
parameter->avg_pooling_ = false; pooling_param->avg_pooling_ = false;
break; break;
} }
auto round_mode = pooling_primitive->roundMode(); auto round_mode = pooling_primitive->roundMode();
switch (round_mode) { switch (round_mode) {
case schema::RoundMode_FLOOR: case schema::RoundMode_FLOOR:
parameter->round_floor_ = true; pooling_param->round_floor_ = true;
parameter->round_ceil_ = false; pooling_param->round_ceil_ = false;
break; break;
case schema::RoundMode_CEIL: case schema::RoundMode_CEIL:
parameter->round_floor_ = false; pooling_param->round_floor_ = false;
parameter->round_ceil_ = true; pooling_param->round_ceil_ = true;
break; break;
default: default:
parameter->round_floor_ = false; pooling_param->round_floor_ = false;
parameter->round_ceil_ = false; pooling_param->round_ceil_ = false;
break; break;
} }
return parameter; return reinterpret_cast<OpParameter *>(pooling_param);
} }
MatMulParameter *PopulateFullconnectionParameter(const lite::Primitive *primitive) { OpParameter *PopulateFullconnectionParameter(const lite::Primitive *primitive) {
auto param = primitive->Value()->value_as_FullConnection(); auto param = primitive->Value()->value_as_FullConnection();
MatMulParameter *parameter = new (std::nothrow) MatMulParameter(); MatMulParameter *matmul_param = new (std::nothrow) MatMulParameter();
if (parameter == nullptr) { if (matmul_param == nullptr) {
MS_LOG(ERROR) << "new FullconnectionParameter failed."; MS_LOG(ERROR) << "new FullconnectionParameter failed.";
return nullptr; return nullptr;
} }
parameter->b_transpose_ = true; matmul_param->op_parameter_.type_ = primitive->Type();
parameter->a_transpose_ = false; matmul_param->b_transpose_ = true;
parameter->has_bias_ = param->hasBias(); matmul_param->a_transpose_ = false;
parameter->minf_ = -FLT_MAX; matmul_param->has_bias_ = param->hasBias();
parameter->maxf_ = FLT_MAX; matmul_param->minf_ = -FLT_MAX;
return parameter; matmul_param->maxf_ = FLT_MAX;
return reinterpret_cast<OpParameter *>(matmul_param);
} }
MatMulParameter *PopulateMatMulParameter(const lite::Primitive *primitive) { OpParameter *PopulateMatMulParameter(const lite::Primitive *primitive) {
auto param = primitive->Value()->value_as_MatMul(); auto param = primitive->Value()->value_as_MatMul();
MatMulParameter *parameter = new (std::nothrow) MatMulParameter(); MatMulParameter *matmul_param = new (std::nothrow) MatMulParameter();
if (parameter == nullptr) { if (matmul_param == nullptr) {
MS_LOG(ERROR) << "new FullconnectionParameter failed."; MS_LOG(ERROR) << "new FullconnectionParameter failed.";
return nullptr; return nullptr;
} }
parameter->b_transpose_ = param->transposeB(); matmul_param->op_parameter_.type_ = primitive->Type();
parameter->a_transpose_ = param->transposeA(); matmul_param->b_transpose_ = param->transposeB();
parameter->has_bias_ = false; matmul_param->a_transpose_ = param->transposeA();
parameter->minf_ = -FLT_MAX; matmul_param->has_bias_ = false;
parameter->maxf_ = FLT_MAX; matmul_param->minf_ = -FLT_MAX;
return parameter; matmul_param->maxf_ = FLT_MAX;
return reinterpret_cast<OpParameter *>(matmul_param);
} }
ConvParameter *PopulateConvParameter(const lite::Primitive *primitive) { OpParameter *PopulateConvParameter(const lite::Primitive *primitive) {
ConvParameter *parameter = new (std::nothrow) ConvParameter(); ConvParameter *conv_param = new (std::nothrow) ConvParameter();
if (parameter == nullptr) { if (conv_param == nullptr) {
MS_LOG(ERROR) << "new ConvParameter failed."; MS_LOG(ERROR) << "new ConvParameter failed.";
return nullptr; return nullptr;
} }
conv_param->op_parameter_.type_ = primitive->Type();
auto conv_primitive = primitive->Value()->value_as_Conv2D(); auto conv_primitive = primitive->Value()->value_as_Conv2D();
parameter->kernel_h_ = conv_primitive->kernelH(); conv_param->kernel_h_ = conv_primitive->kernelH();
parameter->kernel_w_ = conv_primitive->kernelW(); conv_param->kernel_w_ = conv_primitive->kernelW();
// todo format // todo format
parameter->group_ = conv_primitive->group(); conv_param->group_ = conv_primitive->group();
parameter->stride_h_ = conv_primitive->strideH(); conv_param->stride_h_ = conv_primitive->strideH();
parameter->stride_w_ = conv_primitive->strideW(); conv_param->stride_w_ = conv_primitive->strideW();
auto conv2d_lite_primitive = (lite::Conv2D *)primitive; auto conv2d_lite_primitive = (lite::Conv2D *)primitive;
MS_ASSERT(nullptr != conv2d_lite_primitive); MS_ASSERT(nullptr != conv2d_lite_primitive);
parameter->pad_u_ = conv2d_lite_primitive->PadUp(); conv_param->pad_u_ = conv2d_lite_primitive->PadUp();
parameter->pad_d_ = conv2d_lite_primitive->PadDown(); conv_param->pad_d_ = conv2d_lite_primitive->PadDown();
parameter->pad_l_ = conv2d_lite_primitive->PadLeft(); conv_param->pad_l_ = conv2d_lite_primitive->PadLeft();
parameter->pad_r_ = conv2d_lite_primitive->PadRight(); conv_param->pad_r_ = conv2d_lite_primitive->PadRight();
parameter->pad_h_ = conv2d_lite_primitive->PadUp(); conv_param->pad_h_ = conv2d_lite_primitive->PadUp();
parameter->pad_w_ = conv2d_lite_primitive->PadLeft(); conv_param->pad_w_ = conv2d_lite_primitive->PadLeft();
parameter->dilation_h_ = conv_primitive->dilateH(); conv_param->dilation_h_ = conv_primitive->dilateH();
parameter->dilation_w_ = conv_primitive->dilateW(); conv_param->dilation_w_ = conv_primitive->dilateW();
parameter->input_channel_ = conv_primitive->channelIn(); conv_param->input_channel_ = conv_primitive->channelIn();
parameter->output_channel_ = conv_primitive->channelOut(); conv_param->output_channel_ = conv_primitive->channelOut();
parameter->group_ = conv_primitive->group(); conv_param->group_ = conv_primitive->group();
auto act_type = conv_primitive->activationType(); auto act_type = conv_primitive->activationType();
switch (act_type) { switch (act_type) {
case schema::ActivationType_RELU: case schema::ActivationType_RELU:
parameter->is_relu_ = true; conv_param->is_relu_ = true;
parameter->is_relu6_ = false; conv_param->is_relu6_ = false;
break; break;
case schema::ActivationType_RELU6: case schema::ActivationType_RELU6:
parameter->is_relu_ = false; conv_param->is_relu_ = false;
parameter->is_relu6_ = true; conv_param->is_relu6_ = true;
break; break;
default: default:
parameter->is_relu_ = false; conv_param->is_relu_ = false;
parameter->is_relu6_ = false; conv_param->is_relu6_ = false;
break; break;
} }
return parameter; return reinterpret_cast<OpParameter *>(conv_param);
} }
ConvParameter *PopulateConvDwParameter(const lite::Primitive *primitive) { OpParameter *PopulateConvDwParameter(const lite::Primitive *primitive) {
ConvParameter *parameter = new (std::nothrow) ConvParameter(); ConvParameter *conv_param = new (std::nothrow) ConvParameter();
if (parameter == nullptr) { if (conv_param == nullptr) {
MS_LOG(ERROR) << "new ConvParameter failed."; MS_LOG(ERROR) << "new ConvParameter failed.";
return nullptr; return nullptr;
} }
conv_param->op_parameter_.type_ = primitive->Type();
auto conv_primitive = primitive->Value()->value_as_DepthwiseConv2D(); auto conv_primitive = primitive->Value()->value_as_DepthwiseConv2D();
parameter->kernel_h_ = conv_primitive->kernelH(); conv_param->kernel_h_ = conv_primitive->kernelH();
parameter->kernel_w_ = conv_primitive->kernelW(); conv_param->kernel_w_ = conv_primitive->kernelW();
// todo format, group // todo format, group
parameter->stride_h_ = conv_primitive->strideH(); conv_param->stride_h_ = conv_primitive->strideH();
parameter->stride_w_ = conv_primitive->strideW(); conv_param->stride_w_ = conv_primitive->strideW();
auto pad_mode = conv_primitive->padMode(); auto pad_mode = conv_primitive->padMode();
auto convdw_lite_primitive = (lite::DepthwiseConv2D *)primitive; auto convdw_lite_primitive = (lite::DepthwiseConv2D *)primitive;
MS_ASSERT(nullptr != convdw_lite_primitive); MS_ASSERT(nullptr != convdw_lite_primitive);
parameter->pad_u_ = convdw_lite_primitive->PadUp(); conv_param->pad_u_ = convdw_lite_primitive->PadUp();
parameter->pad_d_ = convdw_lite_primitive->PadDown(); conv_param->pad_d_ = convdw_lite_primitive->PadDown();
parameter->pad_l_ = convdw_lite_primitive->PadLeft(); conv_param->pad_l_ = convdw_lite_primitive->PadLeft();
parameter->pad_r_ = convdw_lite_primitive->PadRight(); conv_param->pad_r_ = convdw_lite_primitive->PadRight();
parameter->pad_h_ = convdw_lite_primitive->PadUp(); conv_param->pad_h_ = convdw_lite_primitive->PadUp();
parameter->pad_w_ = convdw_lite_primitive->PadLeft(); conv_param->pad_w_ = convdw_lite_primitive->PadLeft();
parameter->dilation_h_ = conv_primitive->dilateH(); conv_param->dilation_h_ = conv_primitive->dilateH();
parameter->dilation_w_ = conv_primitive->dilateW(); conv_param->dilation_w_ = conv_primitive->dilateW();
auto act_type = conv_primitive->activationType(); auto act_type = conv_primitive->activationType();
switch (act_type) { switch (act_type) {
case schema::ActivationType_RELU: case schema::ActivationType_RELU:
parameter->is_relu_ = true; conv_param->is_relu_ = true;
parameter->is_relu6_ = false; conv_param->is_relu6_ = false;
break; break;
case schema::ActivationType_RELU6: case schema::ActivationType_RELU6:
parameter->is_relu_ = false; conv_param->is_relu_ = false;
parameter->is_relu6_ = true; conv_param->is_relu6_ = true;
break; break;
default: default:
parameter->is_relu_ = false; conv_param->is_relu_ = false;
parameter->is_relu6_ = false; conv_param->is_relu6_ = false;
break; break;
} }
return parameter; return reinterpret_cast<OpParameter *>(conv_param);
} }
ConvParameter *PopulateDeconvDwParameter(const lite::Primitive *primitive) { OpParameter *PopulateDeconvDwParameter(const lite::Primitive *primitive) {
ConvParameter *parameter = new ConvParameter(); ConvParameter *conv_param = new ConvParameter();
if (conv_param == nullptr) {
MS_LOG(ERROR) << "new ConvParameter failed.";
return nullptr;
}
conv_param->op_parameter_.type_ = primitive->Type();
auto conv_primitive = primitive->Value()->value_as_DeDepthwiseConv2D(); auto conv_primitive = primitive->Value()->value_as_DeDepthwiseConv2D();
parameter->kernel_h_ = conv_primitive->kernelH(); conv_param->kernel_h_ = conv_primitive->kernelH();
parameter->kernel_w_ = conv_primitive->kernelW(); conv_param->kernel_w_ = conv_primitive->kernelW();
// todo format, group // todo format, group
parameter->stride_h_ = conv_primitive->strideH(); conv_param->stride_h_ = conv_primitive->strideH();
parameter->stride_w_ = conv_primitive->strideW(); conv_param->stride_w_ = conv_primitive->strideW();
auto deconvdw_lite_primitive = (lite::DeconvDepthwiseConv2D *)primitive; auto deconvdw_lite_primitive = (lite::DeconvDepthwiseConv2D *)primitive;
MS_ASSERT(nullptr != deconvdw_lite_primitive); MS_ASSERT(nullptr != deconvdw_lite_primitive);
parameter->pad_u_ = deconvdw_lite_primitive->PadUp(); conv_param->pad_u_ = deconvdw_lite_primitive->PadUp();
parameter->pad_d_ = deconvdw_lite_primitive->PadDown(); conv_param->pad_d_ = deconvdw_lite_primitive->PadDown();
parameter->pad_l_ = deconvdw_lite_primitive->PadLeft(); conv_param->pad_l_ = deconvdw_lite_primitive->PadLeft();
parameter->pad_r_ = deconvdw_lite_primitive->PadRight(); conv_param->pad_r_ = deconvdw_lite_primitive->PadRight();
parameter->pad_h_ = deconvdw_lite_primitive->PadUp(); conv_param->pad_h_ = deconvdw_lite_primitive->PadUp();
parameter->pad_w_ = deconvdw_lite_primitive->PadLeft(); conv_param->pad_w_ = deconvdw_lite_primitive->PadLeft();
parameter->dilation_h_ = conv_primitive->dilateH(); conv_param->dilation_h_ = conv_primitive->dilateH();
parameter->dilation_w_ = conv_primitive->dilateW(); conv_param->dilation_w_ = conv_primitive->dilateW();
auto act_type = conv_primitive->activationType(); auto act_type = conv_primitive->activationType();
switch (act_type) { switch (act_type) {
case schema::ActivationType_RELU: case schema::ActivationType_RELU:
parameter->is_relu_ = true; conv_param->is_relu_ = true;
parameter->is_relu6_ = false; conv_param->is_relu6_ = false;
break; break;
case schema::ActivationType_RELU6: case schema::ActivationType_RELU6:
parameter->is_relu_ = false; conv_param->is_relu_ = false;
parameter->is_relu6_ = true; conv_param->is_relu6_ = true;
break; break;
default: default:
parameter->is_relu_ = false; conv_param->is_relu_ = false;
parameter->is_relu6_ = false; conv_param->is_relu6_ = false;
break; break;
} }
return parameter; return reinterpret_cast<OpParameter *>(conv_param);
} }
ConvParameter *PopulateDeconvParameter(const lite::Primitive *primitive) { OpParameter *PopulateDeconvParameter(const lite::Primitive *primitive) {
ConvParameter *parameter = new ConvParameter(); ConvParameter *conv_param = new ConvParameter();
if (conv_param == nullptr) {
MS_LOG(ERROR) << "new ConvParameter failed.";
return nullptr;
}
conv_param->op_parameter_.type_ = primitive->Type();
auto conv_primitive = primitive->Value()->value_as_DeConv2D(); auto conv_primitive = primitive->Value()->value_as_DeConv2D();
parameter->kernel_h_ = conv_primitive->kernelH(); conv_param->kernel_h_ = conv_primitive->kernelH();
parameter->kernel_w_ = conv_primitive->kernelW(); conv_param->kernel_w_ = conv_primitive->kernelW();
parameter->stride_h_ = conv_primitive->strideH(); conv_param->stride_h_ = conv_primitive->strideH();
parameter->stride_w_ = conv_primitive->strideW(); conv_param->stride_w_ = conv_primitive->strideW();
auto deconv_lite_primitive = (lite::DeConv2D *)primitive; auto deconv_lite_primitive = (lite::DeConv2D *)primitive;
MS_ASSERT(nullptr != deconvdw_lite_primitive); MS_ASSERT(nullptr != deconvdw_lite_primitive);
parameter->pad_u_ = deconv_lite_primitive->PadUp(); conv_param->pad_u_ = deconv_lite_primitive->PadUp();
parameter->pad_d_ = deconv_lite_primitive->PadDown(); conv_param->pad_d_ = deconv_lite_primitive->PadDown();
parameter->pad_l_ = deconv_lite_primitive->PadLeft(); conv_param->pad_l_ = deconv_lite_primitive->PadLeft();
parameter->pad_r_ = deconv_lite_primitive->PadRight(); conv_param->pad_r_ = deconv_lite_primitive->PadRight();
parameter->pad_h_ = deconv_lite_primitive->PadUp(); conv_param->pad_h_ = deconv_lite_primitive->PadUp();
parameter->pad_w_ = deconv_lite_primitive->PadLeft(); conv_param->pad_w_ = deconv_lite_primitive->PadLeft();
parameter->dilation_h_ = conv_primitive->dilateH(); conv_param->dilation_h_ = conv_primitive->dilateH();
parameter->dilation_w_ = conv_primitive->dilateW(); conv_param->dilation_w_ = conv_primitive->dilateW();
auto act_type = conv_primitive->activationType(); auto act_type = conv_primitive->activationType();
switch (act_type) { switch (act_type) {
case schema::ActivationType_RELU: case schema::ActivationType_RELU:
parameter->is_relu_ = true; conv_param->is_relu_ = true;
parameter->is_relu6_ = false; conv_param->is_relu6_ = false;
break; break;
case schema::ActivationType_RELU6: case schema::ActivationType_RELU6:
parameter->is_relu_ = false; conv_param->is_relu_ = false;
parameter->is_relu6_ = true; conv_param->is_relu6_ = true;
break; break;
default: default:
parameter->is_relu_ = false; conv_param->is_relu_ = false;
parameter->is_relu6_ = false; conv_param->is_relu6_ = false;
break; break;
} }
return parameter; return reinterpret_cast<OpParameter *>(conv_param);
} }
SoftmaxParameter *PopulateSoftmaxParameter(const lite::Primitive *primitive) { OpParameter *PopulateSoftmaxParameter(const lite::Primitive *primitive) {
auto softmax_primitive = primitive->Value()->value_as_SoftMax(); auto softmax_primitive = primitive->Value()->value_as_SoftMax();
SoftmaxParameter *parameter = new (std::nothrow) SoftmaxParameter(); SoftmaxParameter *softmax_param = new (std::nothrow) SoftmaxParameter();
if (parameter == nullptr) { if (softmax_param == nullptr) {
MS_LOG(ERROR) << "new SoftmaxParameter failed."; MS_LOG(ERROR) << "new SoftmaxParameter failed.";
return nullptr; return nullptr;
} }
parameter->axis_ = softmax_primitive->axis(); softmax_param->op_parameter_.type_ = primitive->Type();
return parameter; softmax_param->axis_ = softmax_primitive->axis();
return reinterpret_cast<OpParameter *>(softmax_param);
} }
ReduceParameter *PopulateReduceParameter(const lite::Primitive *primitive) { OpParameter *PopulateReduceParameter(const lite::Primitive *primitive) {
ReduceParameter *parameter = new (std::nothrow) ReduceParameter(); ReduceParameter *reduce_param = new (std::nothrow) ReduceParameter();
if (parameter == nullptr) { if (reduce_param == nullptr) {
MS_LOG(ERROR) << "new ReduceParameter failed."; MS_LOG(ERROR) << "new ReduceParameter failed.";
return nullptr; return nullptr;
} }
reduce_param->op_parameter_.type_ = primitive->Type();
auto reduce = primitive->Value()->value_as_Reduce(); auto reduce = primitive->Value()->value_as_Reduce();
parameter->keep_dims_ = reduce->keepDims(); reduce_param->keep_dims_ = reduce->keepDims();
auto axisVector = reduce->axes(); auto axisVector = reduce->axes();
if (axisVector->size() > REDUCE_MAX_AXES_NUM) { if (axisVector->size() > REDUCE_MAX_AXES_NUM) {
MS_LOG(ERROR) << "Reduce axes size " << axisVector->size() << " exceed limit " << REDUCE_MAX_AXES_NUM; MS_LOG(ERROR) << "Reduce axes size " << axisVector->size() << " exceed limit " << REDUCE_MAX_AXES_NUM;
delete (parameter); delete (reduce_param);
return nullptr; return nullptr;
} }
parameter->num_axes_ = static_cast<int>(axisVector->size()); reduce_param->num_axes_ = static_cast<int>(axisVector->size());
int i = 0; int i = 0;
for (auto iter = axisVector->begin(); iter != axisVector->end(); iter++) { for (auto iter = axisVector->begin(); iter != axisVector->end(); iter++) {
parameter->axes_[i++] = *iter; reduce_param->axes_[i++] = *iter;
} }
parameter->mode_ = static_cast<int>(reduce->mode()); reduce_param->mode_ = static_cast<int>(reduce->mode());
return parameter; return reinterpret_cast<OpParameter *>(reduce_param);
} }
PadParameter *PopulatePadParameter(const lite::Primitive *primitive) { OpParameter *PopulatePadParameter(const lite::Primitive *primitive) {
PadParameter *pad_param = new (std::nothrow) PadParameter(); PadParameter *pad_param = new (std::nothrow) PadParameter();
if (pad_param == nullptr) { if (pad_param == nullptr) {
MS_LOG(ERROR) << "new PadParameter failed."; MS_LOG(ERROR) << "new PadParameter failed.";
return nullptr; return nullptr;
} }
pad_param->op_parameter_.type_ = primitive->Type();
auto pad_node = primitive->Value()->value_as_Pad(); auto pad_node = primitive->Value()->value_as_Pad();
pad_param->pad_mode_ = pad_node->paddingMode(); pad_param->pad_mode_ = pad_node->paddingMode();
if (pad_param->pad_mode_ == schema::PaddingMode_CONSTANT) { if (pad_param->pad_mode_ == schema::PaddingMode_CONSTANT) {
pad_param->constant_value_ = pad_node->constantValue(); pad_param->constant_value_ = pad_node->constantValue();
...@@ -402,218 +421,212 @@ PadParameter *PopulatePadParameter(const lite::Primitive *primitive) { ...@@ -402,218 +421,212 @@ PadParameter *PopulatePadParameter(const lite::Primitive *primitive) {
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
pad_param->paddings_[MAX_PAD_SIZE - size + i] = (*(pad_node->paddings()))[i]; pad_param->paddings_[MAX_PAD_SIZE - size + i] = (*(pad_node->paddings()))[i];
} }
return pad_param; return reinterpret_cast<OpParameter *>(pad_param);
} }
ActivationParameter *PopulateActivationParameter(const lite::Primitive *primitive) { OpParameter *PopulateActivationParameter(const lite::Primitive *primitive) {
ActivationParameter *parameter = new (std::nothrow) ActivationParameter(); ActivationParameter *act_param = new (std::nothrow) ActivationParameter();
if (parameter == nullptr) { if (act_param == nullptr) {
MS_LOG(ERROR) << "new ActivationParameter failed."; MS_LOG(ERROR) << "new ActivationParameter failed.";
return nullptr; return nullptr;
} }
auto activation = primitive->Value()->value_as_Activation(); auto activation = primitive->Value()->value_as_Activation();
parameter->type_ = static_cast<int>(activation->type()); act_param->type_ = static_cast<int>(activation->type());
return parameter; return reinterpret_cast<OpParameter *>(act_param);
} }
FusedBatchNormParameter *PopulateFusedBatchNorm(const lite::Primitive *primitive) { OpParameter *PopulateFusedBatchNorm(const lite::Primitive *primitive) {
FusedBatchNormParameter *parameter = new (std::nothrow) FusedBatchNormParameter(); FusedBatchNormParameter *fuse_batch_norm_param = new (std::nothrow) FusedBatchNormParameter();
if (parameter == nullptr) { if (fuse_batch_norm_param == nullptr) {
MS_LOG(ERROR) << "new FusedBatchNormParameter failed."; MS_LOG(ERROR) << "new FusedBatchNormParameter failed.";
return nullptr; return nullptr;
} }
fuse_batch_norm_param->op_parameter_.type_ = primitive->Type();
auto param = primitive->Value()->value_as_FusedBatchNorm(); auto param = primitive->Value()->value_as_FusedBatchNorm();
parameter->epsilon_ = param->epsilon(); fuse_batch_norm_param->epsilon_ = param->epsilon();
return parameter; return reinterpret_cast<OpParameter *>(fuse_batch_norm_param);
} }
ArithmeticParameter *PopulateArithmetic(const lite::Primitive *primitive) { OpParameter *PopulateArithmetic(const lite::Primitive *primitive) {
ArithmeticParameter *parameter = new (std::nothrow) ArithmeticParameter(); ArithmeticParameter *arithmetic_param = new (std::nothrow) ArithmeticParameter();
if (parameter == nullptr) { if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "new ArithmeticParameter failed."; MS_LOG(ERROR) << "new ArithmeticParameter failed.";
return nullptr; return nullptr;
} }
parameter->op_parameter.type_ = primitive->Type(); arithmetic_param->op_parameter_.type_ = primitive->Type();
parameter->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
parameter->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
(void)memcpy(parameter->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); (void)memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
(void)memcpy(parameter->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); (void)memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
(void)memcpy(parameter->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); (void)memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return parameter; return reinterpret_cast<OpParameter *>(arithmetic_param);
} }
ArithmeticParameter *PopulateEltwiseParam(const lite::Primitive *primitive) { OpParameter *PopulateEltwiseParameter(const lite::Primitive *primitive) {
ArithmeticParameter *parameter = new (std::nothrow) ArithmeticParameter(); ArithmeticParameter *arithmetic_param = new (std::nothrow) ArithmeticParameter();
if (parameter == nullptr) { if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "new ArithmeticParameter failed."; MS_LOG(ERROR) << "new ArithmeticParameter failed.";
return nullptr; return nullptr;
} }
auto eltwise = primitive->Value()->value_as_Eltwise(); auto eltwise = primitive->Value()->value_as_Eltwise();
switch (eltwise->mode()) { switch (eltwise->mode()) {
case schema::EltwiseMode_PROD: case schema::EltwiseMode_PROD:
parameter->op_parameter.type_ = schema::PrimitiveType_Mul; arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Mul;
break; break;
case schema::EltwiseMode_SUM: case schema::EltwiseMode_SUM:
parameter->op_parameter.type_ = schema::PrimitiveType_Add; arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Add;
break; break;
case schema::EltwiseMode_MAXIMUM: case schema::EltwiseMode_MAXIMUM:
parameter->op_parameter.type_ = schema::PrimitiveType_Maximum; arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Maximum;
break; break;
default: default:
delete parameter; delete arithmetic_param;
return nullptr; return nullptr;
} }
return parameter; return reinterpret_cast<OpParameter *>(arithmetic_param);
} }
ArithmeticSelfParameter *PopulateArithmeticSelf(const lite::Primitive *primitive) { OpParameter *PopulateArithmeticSelf(const lite::Primitive *primitive) {
ArithmeticSelfParameter *parameter = new (std::nothrow) ArithmeticSelfParameter(); ArithmeticSelfParameter *arithmetic_self_param = new (std::nothrow) ArithmeticSelfParameter();
if (parameter == nullptr) { if (arithmetic_self_param == nullptr) {
MS_LOG(ERROR) << "new ArithmeticParameter failed."; MS_LOG(ERROR) << "new ArithmeticParameter failed.";
return nullptr; return nullptr;
} }
parameter->op_parameter_.type_ = primitive->Type(); arithmetic_self_param->op_parameter_.type_ = primitive->Type();
return parameter; return reinterpret_cast<OpParameter *>(arithmetic_self_param);
} }
PowerParameter *PopulatePowerParameter(const lite::Primitive *primitive) { OpParameter *PopulatePowerParameter(const lite::Primitive *primitive) {
PowerParameter *parameter = new (std::nothrow) PowerParameter(); PowerParameter *power_param = new (std::nothrow) PowerParameter();
if (parameter == nullptr) { if (power_param == nullptr) {
MS_LOG(ERROR) << "new PowerParameter failed."; MS_LOG(ERROR) << "new PowerParameter failed.";
return nullptr; return nullptr;
} }
power_param->op_parameter_.type_ = primitive->Type();
auto power = primitive->Value()->value_as_Power(); auto power = primitive->Value()->value_as_Power();
parameter->power_ = power->power(); power_param->power_ = power->power();
parameter->scale_ = power->scale(); power_param->scale_ = power->scale();
parameter->shift_ = power->shift(); power_param->shift_ = power->shift();
return parameter; return reinterpret_cast<OpParameter *>(power_param);
} }
ArgMinMaxParameter *PopulateArgMaxParam(const lite::Primitive *primitive) { OpParameter *PopulateArgMaxParameter(const lite::Primitive *primitive) {
ArgMinMaxParameter *parameter = new (std::nothrow) ArgMinMaxParameter(); ArgMinMaxParameter *arg_param = new (std::nothrow) ArgMinMaxParameter();
if (parameter == nullptr) { if (arg_param == nullptr) {
MS_LOG(ERROR) << "new ArgMinMaxParameter failed."; MS_LOG(ERROR) << "new ArgMinMaxParameter failed.";
return nullptr; return nullptr;
} }
arg_param->op_parameter_.type_ = primitive->Type();
auto param = primitive->Value()->value_as_ArgMax(); auto param = primitive->Value()->value_as_ArgMax();
parameter->op_parameter_.type_ = primitive->Type(); arg_param->axis_ = param->axis();
parameter->axis_ = param->axis(); arg_param->topk_ = param->topK();
parameter->topk_ = param->topK(); arg_param->axis_type_ = param->axisType();
parameter->axis_type_ = param->axisType(); arg_param->out_value_ = param->outMaxValue();
parameter->out_value_ = param->outMaxValue(); arg_param->keep_dims_ = param->keepDims();
parameter->keep_dims_ = param->keepDims(); return reinterpret_cast<OpParameter *>(arg_param);
return parameter;
} }
ArgMinMaxParameter *PopulateArgMinParam(const lite::Primitive *primitive) { OpParameter *PopulateArgMinParameter(const lite::Primitive *primitive) {
ArgMinMaxParameter *parameter = new (std::nothrow) ArgMinMaxParameter(); ArgMinMaxParameter *arg_param = new (std::nothrow) ArgMinMaxParameter();
if (parameter == nullptr) { if (arg_param == nullptr) {
MS_LOG(ERROR) << "new ArgMinMaxParameter failed."; MS_LOG(ERROR) << "new ArgMinMaxParameter failed.";
return nullptr; return nullptr;
} }
arg_param->op_parameter_.type_ = primitive->Type();
auto param = primitive->Value()->value_as_ArgMin(); auto param = primitive->Value()->value_as_ArgMin();
parameter->op_parameter_.type_ = primitive->Type(); arg_param->axis_ = param->axis();
parameter->axis_ = param->axis(); arg_param->topk_ = param->topK();
parameter->topk_ = param->topK(); arg_param->axis_type_ = param->axisType();
parameter->axis_type_ = param->axisType(); arg_param->out_value_ = param->outMaxValue();
parameter->out_value_ = param->outMaxValue(); arg_param->keep_dims_ = param->keepDims();
parameter->keep_dims_ = param->keepDims(); return reinterpret_cast<OpParameter *>(arg_param);
return parameter;
} }
CastParameter *PopulateCastParam(const lite::Primitive *primitive) { OpParameter *PopulateCastParameter(const lite::Primitive *primitive) {
CastParameter *parameter = new (std::nothrow) CastParameter(); CastParameter *cast_param = new (std::nothrow) CastParameter();
if (parameter == nullptr) { if (cast_param == nullptr) {
MS_LOG(ERROR) << "new CastParameter failed."; MS_LOG(ERROR) << "new CastParameter failed.";
return nullptr; return nullptr;
} }
cast_param->op_parameter_.type_ = primitive->Type();
auto param = primitive->Value()->value_as_Cast(); auto param = primitive->Value()->value_as_Cast();
parameter->op_parameter_.type_ = primitive->Type(); cast_param->src_type_ = param->srcT();
parameter->src_type_ = param->srcT(); cast_param->dst_type_ = param->dstT();
parameter->dst_type_ = param->dstT(); return reinterpret_cast<OpParameter *>(cast_param);
return parameter;
} }
LocalResponseNormParameter *PopulateLocalResponseNormParameter(const lite::Primitive *primitive) { OpParameter *PopulateLocalResponseNormParameter(const lite::Primitive *primitive) {
auto local_response_norm_attr = primitive->Value()->value_as_LocalResponseNormalization(); auto local_response_norm_attr = primitive->Value()->value_as_LocalResponseNormalization();
LocalResponseNormParameter *parameter = new (std::nothrow) LocalResponseNormParameter(); LocalResponseNormParameter *lrn_param = new (std::nothrow) LocalResponseNormParameter();
if (parameter == nullptr) { if (lrn_param == nullptr) {
MS_LOG(ERROR) << "new LocalResponseNormParameter failed."; MS_LOG(ERROR) << "new LocalResponseNormParameter failed.";
return nullptr; return nullptr;
} }
parameter->depth_radius_ = local_response_norm_attr->depth_radius(); lrn_param->op_parameter_.type_ = primitive->Type();
parameter->bias_ = local_response_norm_attr->bias(); lrn_param->depth_radius_ = local_response_norm_attr->depth_radius();
parameter->alpha_ = local_response_norm_attr->alpha(); lrn_param->bias_ = local_response_norm_attr->bias();
parameter->beta_ = local_response_norm_attr->beta(); lrn_param->alpha_ = local_response_norm_attr->alpha();
return parameter; lrn_param->beta_ = local_response_norm_attr->beta();
return reinterpret_cast<OpParameter *>(lrn_param);
} }
RangeParameter *PopulateRangeParameter(const lite::Primitive *primitive) { OpParameter *PopulateRangeParameter(const lite::Primitive *primitive) {
auto range_attr = primitive->Value()->value_as_Range(); auto range_attr = primitive->Value()->value_as_Range();
RangeParameter *parameter = new (std::nothrow) RangeParameter(); RangeParameter *range_param = new (std::nothrow) RangeParameter();
if (parameter == nullptr) { if (range_param == nullptr) {
MS_LOG(ERROR) << "new RangeParameter failed."; MS_LOG(ERROR) << "new RangeParameter failed.";
return nullptr; return nullptr;
} }
parameter->start_ = range_attr->start(); range_param->op_parameter_.type_ = primitive->Type();
parameter->limit_ = range_attr->limit(); range_param->start_ = range_attr->start();
parameter->delta_ = range_attr->delta(); range_param->limit_ = range_attr->limit();
parameter->dType_ = range_attr->dType(); range_param->delta_ = range_attr->delta();
return parameter; range_param->dType_ = range_attr->dType();
} return reinterpret_cast<OpParameter *>(range_param);
OpParameter *PopulateCeilParameter(const lite::Primitive *primitive) {
OpParameter *parameter = new (std::nothrow) OpParameter();
if (parameter == nullptr) {
MS_LOG(ERROR) << "new OpParameter failed.";
return nullptr;
}
parameter->type_ = primitive->Type();
return parameter;
} }
ConcatParameter *PopulateConcatParameter(const lite::Primitive *primitive) { OpParameter *PopulateConcatParameter(const lite::Primitive *primitive) {
ConcatParameter *parameter = new (std::nothrow) ConcatParameter(); ConcatParameter *concat_param = new (std::nothrow) ConcatParameter();
if (parameter == nullptr) { if (concat_param == nullptr) {
MS_LOG(ERROR) << "new ConcatParameter failed."; MS_LOG(ERROR) << "new ConcatParameter failed.";
return nullptr; return nullptr;
} }
parameter->op_parameter_.type_ = primitive->Type(); concat_param->op_parameter_.type_ = primitive->Type();
auto param = primitive->Value()->value_as_Concat(); auto param = primitive->Value()->value_as_Concat();
parameter->axis_ = param->axis(); concat_param->axis_ = param->axis();
return parameter; return reinterpret_cast<OpParameter *>(concat_param);
} }
TileParameter *PopulateTileParameter(const lite::Primitive *primitive) { OpParameter *PopulateTileParameter(const lite::Primitive *primitive) {
TileParameter *parameter = new (std::nothrow) TileParameter(); TileParameter *tile_param = new (std::nothrow) TileParameter();
if (parameter == nullptr) { if (tile_param == nullptr) {
MS_LOG(ERROR) << "new TileParameter failed."; MS_LOG(ERROR) << "new TileParameter failed.";
return nullptr; return nullptr;
} }
parameter->op_parameter_.type_ = primitive->Type(); tile_param->op_parameter_.type_ = primitive->Type();
auto param = primitive->Value()->value_as_Tile(); auto param = primitive->Value()->value_as_Tile();
auto multiples = param->multiples(); auto multiples = param->multiples();
parameter->in_dim_ = multiples->size(); tile_param->in_dim_ = multiples->size();
for (size_t i = 0; i < parameter->in_dim_; ++i) { for (size_t i = 0; i < tile_param->in_dim_; ++i) {
parameter->multiples_[i] = multiples->Get(i); tile_param->multiples_[i] = multiples->Get(i);
} }
return parameter; return reinterpret_cast<OpParameter *>(tile_param);
} }
TopkParameter *PopulateTopKParameter(const lite::Primitive *primitive) { OpParameter *PopulateTopKParameter(const lite::Primitive *primitive) {
TopkParameter *parameter = new (std::nothrow) TopkParameter(); TopkParameter *topk_param = new (std::nothrow) TopkParameter();
if (parameter == nullptr) { if (topk_param == nullptr) {
MS_LOG(ERROR) << "new TopkParameter failed."; MS_LOG(ERROR) << "new TopkParameter failed.";
return nullptr; return nullptr;
} }
parameter->op_parameter_.type_ = primitive->Type(); topk_param->op_parameter_.type_ = primitive->Type();
auto param = primitive->Value()->value_as_TopK(); auto param = primitive->Value()->value_as_TopK();
parameter->k_ = param->k(); topk_param->k_ = param->k();
parameter->sorted_ = param->sorted(); topk_param->sorted_ = param->sorted();
return parameter; return reinterpret_cast<OpParameter *>(topk_param);
} }
OpParameter *PopulateNhwc2NchwParameter(const lite::Primitive *primitive) { OpParameter *PopulateNhwc2NchwParameter(const lite::Primitive *primitive) {
...@@ -636,64 +649,64 @@ OpParameter *PopulateNchw2NhwcParameter(const lite::Primitive *primitive) { ...@@ -636,64 +649,64 @@ OpParameter *PopulateNchw2NhwcParameter(const lite::Primitive *primitive) {
return parameter; return parameter;
} }
TransposeParameter *PopulateTransposeParameter(const lite::Primitive *primitive) { OpParameter *PopulateTransposeParameter(const lite::Primitive *primitive) {
TransposeParameter *parameter = new (std::nothrow) TransposeParameter(); TransposeParameter *transpose_param = new (std::nothrow) TransposeParameter();
if (parameter == nullptr) { if (transpose_param == nullptr) {
MS_LOG(ERROR) << "new TransposeParameter failed."; MS_LOG(ERROR) << "new TransposeParameter failed.";
return nullptr; return nullptr;
} }
auto param = primitive->Value()->value_as_Transpose(); auto param = primitive->Value()->value_as_Transpose();
parameter->op_parameter_.type_ = primitive->Type(); transpose_param->op_parameter_.type_ = primitive->Type();
auto perm_vector_ = param->perm(); auto perm_vector_ = param->perm();
int i = 0; int i = 0;
for (auto iter = perm_vector_->begin(); iter != perm_vector_->end(); iter++) { for (auto iter = perm_vector_->begin(); iter != perm_vector_->end(); iter++) {
parameter->perm_[i++] = *iter; transpose_param->perm_[i++] = *iter;
} }
parameter->num_axes_ = i; transpose_param->num_axes_ = i;
parameter->conjugate_ = param->conjugate(); transpose_param->conjugate_ = param->conjugate();
return parameter; return reinterpret_cast<OpParameter *>(transpose_param);
} }
SplitParameter *PopulateSplitParameter(const lite::Primitive *primitive) { OpParameter *PopulateSplitParameter(const lite::Primitive *primitive) {
SplitParameter *parameter = new (std::nothrow) SplitParameter(); SplitParameter *split_param = new (std::nothrow) SplitParameter();
if (parameter == nullptr) { if (split_param == nullptr) {
MS_LOG(ERROR) << "new SplitParameter failed."; MS_LOG(ERROR) << "new SplitParameter failed.";
return nullptr; return nullptr;
} }
auto param = primitive->Value()->value_as_Split(); auto param = primitive->Value()->value_as_Split();
parameter->op_parameter_.type_ = primitive->Type(); split_param->op_parameter_.type_ = primitive->Type();
parameter->num_split_ = param->numberSplit(); split_param->num_split_ = param->numberSplit();
auto split_sizes_vector_ = param->sizeSplits(); auto split_sizes_vector_ = param->sizeSplits();
int i = 0; int i = 0;
for (auto iter = split_sizes_vector_->begin(); iter != split_sizes_vector_->end(); iter++) { for (auto iter = split_sizes_vector_->begin(); iter != split_sizes_vector_->end(); iter++) {
parameter->split_sizes_[i++] = *iter; split_param->split_sizes_[i++] = *iter;
} }
parameter->split_dim_ = param->splitDim(); split_param->split_dim_ = param->splitDim();
parameter->num_split_ = param->numberSplit(); split_param->num_split_ = param->numberSplit();
return parameter; return reinterpret_cast<OpParameter *>(split_param);
} }
SqueezeParameter *PopulateSqueezeParameter(const lite::Primitive *primitive) { OpParameter *PopulateSqueezeParameter(const lite::Primitive *primitive) {
SqueezeParameter *parameter = new (std::nothrow) SqueezeParameter(); SqueezeParameter *squeeze_param = new (std::nothrow) SqueezeParameter();
if (parameter == nullptr) { if (squeeze_param == nullptr) {
MS_LOG(ERROR) << "new SqueezeParameter failed."; MS_LOG(ERROR) << "new SqueezeParameter failed.";
return nullptr; return nullptr;
} }
parameter->op_parameter_.type_ = primitive->Type(); squeeze_param->op_parameter_.type_ = primitive->Type();
return parameter; return reinterpret_cast<OpParameter *>(squeeze_param);
} }
ScaleParameter *PopulateScaleParameter(const lite::Primitive *primitive) { OpParameter *PopulateScaleParameter(const lite::Primitive *primitive) {
if (primitive == nullptr) { if (primitive == nullptr) {
MS_LOG(ERROR) << "input primitive is nullptr"; MS_LOG(ERROR) << "input primitive is nullptr";
return nullptr; return nullptr;
} }
ScaleParameter *parameter = new (std::nothrow) ScaleParameter(); ScaleParameter *scale_param = new (std::nothrow) ScaleParameter();
if (parameter == nullptr) { if (scale_param == nullptr) {
MS_LOG(ERROR) << "new ScaleParameter failed."; MS_LOG(ERROR) << "new ScaleParameter failed.";
return nullptr; return nullptr;
} }
parameter->op_parameter_.type_ = primitive->Type(); scale_param->op_parameter_.type_ = primitive->Type();
auto param = primitive->Value()->value_as_Scale(); auto param = primitive->Value()->value_as_Scale();
if (param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "value_as_Scale return nullptr"; MS_LOG(ERROR) << "value_as_Scale return nullptr";
...@@ -701,219 +714,253 @@ ScaleParameter *PopulateScaleParameter(const lite::Primitive *primitive) { ...@@ -701,219 +714,253 @@ ScaleParameter *PopulateScaleParameter(const lite::Primitive *primitive) {
} }
// NCHW todo use enum // NCHW todo use enum
if (param->format() == schema::Format_NCHW) { if (param->format() == schema::Format_NCHW) {
parameter->axis_ = 1; scale_param->axis_ = 1;
parameter->num_axis_ = 1; scale_param->num_axis_ = 1;
} else if (param->format() == schema::Format_NHWC) { } else if (param->format() == schema::Format_NHWC) {
parameter->axis_ = 3; scale_param->axis_ = 3;
parameter->num_axis_ = 1; scale_param->num_axis_ = 1;
} }
return parameter; return reinterpret_cast<OpParameter *>(scale_param);
} }
GatherParameter *PopulateGatherParameter(const lite::Primitive *primitive) { OpParameter *PopulateGatherParameter(const lite::Primitive *primitive) {
auto gather_attr = primitive->Value()->value_as_Gather(); auto gather_attr = primitive->Value()->value_as_Gather();
GatherParameter *parameter = new (std::nothrow) GatherParameter(); GatherParameter *gather_param = new (std::nothrow) GatherParameter();
if (parameter == nullptr) { if (gather_param == nullptr) {
MS_LOG(ERROR) << "new GatherParameter failed."; MS_LOG(ERROR) << "new GatherParameter failed.";
return nullptr; return nullptr;
} }
parameter->axis_ = gather_attr->axis(); gather_param->op_parameter_.type_ = primitive->Type();
parameter->batchDims_ = gather_attr->batchDims(); gather_param->axis_ = gather_attr->axis();
return parameter; gather_param->batchDims_ = gather_attr->batchDims();
return reinterpret_cast<OpParameter *>(gather_param);
} }
GatherNdParameter *PopulateGatherNdParameter(const lite::Primitive *primitive) { OpParameter *PopulateGatherNdParameter(const lite::Primitive *primitive) {
GatherNdParameter *parameter = new (std::nothrow) GatherNdParameter(); GatherNdParameter *gather_nd_param = new (std::nothrow) GatherNdParameter();
MS_ASSERT(paramter != nullptr); if (gather_nd_param == nullptr) {
MS_LOG(ERROR) << "new GatherNDParameter failed.";
return nullptr;
}
gather_nd_param->op_parameter_.type_ = primitive->Type();
auto gatherNd_attr = primitive->Value()->value_as_GatherNd(); auto gatherNd_attr = primitive->Value()->value_as_GatherNd();
parameter->batchDims_ = gatherNd_attr->batchDims(); gather_nd_param->batchDims_ = gatherNd_attr->batchDims();
return parameter; return reinterpret_cast<OpParameter *>(gather_nd_param);
} }
ScatterNDParameter *PopulateScatterNDParameter(const lite::Primitive *primitive) { OpParameter *PopulateScatterNDParameter(const lite::Primitive *primitive) {
ScatterNDParameter *parameter = new (std::nothrow) ScatterNDParameter(); ScatterNDParameter *scatter_nd_param = new (std::nothrow) ScatterNDParameter();
if (scatter_nd_param == nullptr) {
MS_LOG(ERROR) << "new ScatterNDParameter failed.";
return nullptr;
}
scatter_nd_param->op_parameter_.type_ = primitive->Type();
MS_ASSERT(paramter != nullptr); MS_ASSERT(paramter != nullptr);
return parameter; return reinterpret_cast<OpParameter *>(scatter_nd_param);
} }
SliceParameter *PopulateSliceParam(const lite::Primitive *primitive) { OpParameter *PopulateSliceParameter(const lite::Primitive *primitive) {
SliceParameter *parameter = new (std::nothrow) SliceParameter(); SliceParameter *slice_param = new (std::nothrow) SliceParameter();
if (parameter == nullptr) { if (slice_param == nullptr) {
MS_LOG(ERROR) << "new SliceParameter failed."; MS_LOG(ERROR) << "new SliceParameter failed.";
return nullptr; return nullptr;
} }
auto param = primitive->Value()->value_as_Slice(); auto param = primitive->Value()->value_as_Slice();
parameter->op_parameter_.type_ = primitive->Type(); slice_param->op_parameter_.type_ = primitive->Type();
auto param_begin = param->begin(); auto param_begin = param->begin();
auto param_size = param->size(); auto param_size = param->size();
if (param_begin->size() != param_size->size()) { if (param_begin->size() != param_size->size()) {
delete parameter; delete slice_param;
return nullptr; return nullptr;
} }
parameter->param_length_ = static_cast<int32_t>(param_begin->size()); slice_param->param_length_ = static_cast<int32_t>(param_begin->size());
for (int32_t i = 0; i < parameter->param_length_; ++i) { for (int32_t i = 0; i < slice_param->param_length_; ++i) {
parameter->begin_[i] = param_begin->Get(i); slice_param->begin_[i] = param_begin->Get(i);
parameter->size_[i] = param_size->Get(i); slice_param->size_[i] = param_size->Get(i);
} }
return parameter; return reinterpret_cast<OpParameter *>(slice_param);
} }
BroadcastToParameter *PopulateBroadcastToParam(const lite::Primitive *primitive) { OpParameter *PopulateBroadcastToParameter(const lite::Primitive *primitive) {
BroadcastToParameter *parameter = new (std::nothrow) BroadcastToParameter(); BroadcastToParameter *broadcast_param = new (std::nothrow) BroadcastToParameter();
if (parameter == nullptr) { if (broadcast_param == nullptr) {
MS_LOG(ERROR) << "new BroadcastToParameter failed."; MS_LOG(ERROR) << "new BroadcastToParameter failed.";
return nullptr; return nullptr;
} }
auto param = primitive->Value()->value_as_BroadcastTo(); auto param = primitive->Value()->value_as_BroadcastTo();
parameter->op_parameter_.type_ = primitive->Type(); broadcast_param->op_parameter_.type_ = primitive->Type();
auto dst_shape = param->dst_shape(); auto dst_shape = param->dst_shape();
parameter->shape_size_ = dst_shape->size(); broadcast_param->shape_size_ = dst_shape->size();
for (size_t i = 0; i < parameter->shape_size_; ++i) { for (size_t i = 0; i < broadcast_param->shape_size_; ++i) {
parameter->shape_[i] = dst_shape->Get(i); broadcast_param->shape_[i] = dst_shape->Get(i);
} }
return parameter; return reinterpret_cast<OpParameter *>(broadcast_param);
} }
ReshapeParameter *PopulateReshapeParam(const lite::Primitive *primitive) { OpParameter *PopulateReshapeParameter(const lite::Primitive *primitive) {
ReshapeParameter *parameter = new (std::nothrow) ReshapeParameter(); ReshapeParameter *reshape_param = new (std::nothrow) ReshapeParameter();
if (parameter == nullptr) { if (reshape_param == nullptr) {
MS_LOG(ERROR) << "new ReshapeParameter failed."; MS_LOG(ERROR) << "new ReshapeParameter failed.";
return nullptr; return nullptr;
} }
parameter->op_parameter_.type_ = primitive->Type(); reshape_param->op_parameter_.type_ = primitive->Type();
return parameter; return reinterpret_cast<OpParameter *>(reshape_param);
} }
ReverseParameter *PopulateReverseParameter(const lite::Primitive *primitive) { OpParameter *PopulateReverseParameter(const lite::Primitive *primitive) {
auto reverse_attr = primitive->Value()->value_as_Reverse(); auto reverse_attr = primitive->Value()->value_as_Reverse();
ReverseParameter *parameter = new (std::nothrow) ReverseParameter(); ReverseParameter *reverse_param = new (std::nothrow) ReverseParameter();
if (parameter == nullptr) { if (reverse_param == nullptr) {
MS_LOG(ERROR) << "new ReverseParameter failed."; MS_LOG(ERROR) << "new ReverseParameter failed.";
return nullptr; return nullptr;
} }
reverse_param->op_parameter_.type_ = primitive->Type();
auto flatAxis = reverse_attr->axis(); auto flatAxis = reverse_attr->axis();
parameter->num_axis_ = flatAxis->size(); reverse_param->num_axis_ = flatAxis->size();
int i = 0; int i = 0;
for (auto iter = flatAxis->begin(); iter != flatAxis->end(); iter++) { for (auto iter = flatAxis->begin(); iter != flatAxis->end(); iter++) {
parameter->axis_[i++] = *iter; reverse_param->axis_[i++] = *iter;
} }
return parameter; return reinterpret_cast<OpParameter *>(reverse_param);
} }
UnsqueezeParameter *PopulateUnsqueezeParameter(const lite::Primitive *primitive) { OpParameter *PopulateUnsqueezeParameter(const lite::Primitive *primitive) {
auto unsqueeze_attr = primitive->Value()->value_as_Unsqueeze(); auto unsqueeze_attr = primitive->Value()->value_as_Unsqueeze();
UnsqueezeParameter *parameter = new (std::nothrow) UnsqueezeParameter(); UnsqueezeParameter *unsqueeze_param = new (std::nothrow) UnsqueezeParameter();
if (parameter == nullptr) { if (unsqueeze_param == nullptr) {
MS_LOG(ERROR) << "new ReverseParameter failed."; MS_LOG(ERROR) << "new ReverseParameter failed.";
return nullptr; return nullptr;
} }
unsqueeze_param->op_parameter_.type_ = primitive->Type();
auto flatAxis = unsqueeze_attr->axis(); auto flatAxis = unsqueeze_attr->axis();
parameter->num_dim_ = flatAxis->size(); unsqueeze_param->num_dim_ = flatAxis->size();
int i = 0; int i = 0;
for (auto iter = flatAxis->begin(); iter != flatAxis->end(); iter++) { for (auto iter = flatAxis->begin(); iter != flatAxis->end(); iter++) {
parameter->dims_[i++] = *iter; unsqueeze_param->dims_[i++] = *iter;
} }
return parameter; return reinterpret_cast<OpParameter *>(unsqueeze_param);
} }
StackParameter *PopulateStackParam(const lite::Primitive *primitive) { OpParameter *PopulateStackParameter(const lite::Primitive *primitive) {
StackParameter *parameter = new (std::nothrow) StackParameter(); StackParameter *stack_param = new (std::nothrow) StackParameter();
if (parameter == nullptr) { if (stack_param == nullptr) {
MS_LOG(ERROR) << "new StackParameter failed."; MS_LOG(ERROR) << "new StackParameter failed.";
return nullptr; return nullptr;
} }
auto param = primitive->Value()->value_as_Stack(); auto param = primitive->Value()->value_as_Stack();
parameter->op_parameter_.type_ = primitive->Type(); stack_param->op_parameter_.type_ = primitive->Type();
parameter->axis_ = param->axis(); stack_param->axis_ = param->axis();
return parameter; return reinterpret_cast<OpParameter *>(stack_param);
} }
UnstackParameter *PopulateUnstackParam(const lite::Primitive *primitive) { OpParameter *PopulateUnstackParameter(const lite::Primitive *primitive) {
UnstackParameter *parameter = new (std::nothrow) UnstackParameter(); UnstackParameter *unstack_param = new (std::nothrow) UnstackParameter();
if (parameter == nullptr) { if (unstack_param == nullptr) {
MS_LOG(ERROR) << "new UnstackParameter failed."; MS_LOG(ERROR) << "new UnstackParameter failed.";
return nullptr; return nullptr;
} }
auto param = primitive->Value()->value_as_Unstack(); auto param = primitive->Value()->value_as_Unstack();
parameter->op_parameter_.type_ = primitive->Type(); unstack_param->op_parameter_.type_ = primitive->Type();
parameter->num_ = param->num(); unstack_param->num_ = param->num();
parameter->axis_ = param->axis(); unstack_param->axis_ = param->axis();
return parameter; return reinterpret_cast<OpParameter *>(unstack_param);
} }
ReverseSequenceParameter *PopulateReverseSequenceParam(const lite::Primitive *primitive) { OpParameter *PopulateReverseSequenceParameter(const lite::Primitive *primitive) {
ReverseSequenceParameter *parameter = new (std::nothrow) ReverseSequenceParameter(); ReverseSequenceParameter *reverse_sequence_param = new (std::nothrow) ReverseSequenceParameter();
if (parameter == nullptr) { if (reverse_sequence_param == nullptr) {
MS_LOG(ERROR) << "new ReverseSequenceParameter failed."; MS_LOG(ERROR) << "new ReverseSequenceParameter failed.";
return nullptr; return nullptr;
} }
auto param = primitive->Value()->value_as_ReverseSequence(); auto param = primitive->Value()->value_as_ReverseSequence();
parameter->op_parameter_.type_ = primitive->Type(); reverse_sequence_param->op_parameter_.type_ = primitive->Type();
parameter->seq_axis_ = param->seqAxis(); reverse_sequence_param->seq_axis_ = param->seqAxis();
parameter->batch_axis_ = param->batchAxis(); reverse_sequence_param->batch_axis_ = param->batchAxis();
return parameter; return reinterpret_cast<OpParameter *>(reverse_sequence_param);
} }
UniqueParameter *PopulateUniqueParam(const lite::Primitive *primitive) { OpParameter *PopulateUniqueParameter(const lite::Primitive *primitive) {
UniqueParameter *parameter = new (std::nothrow) UniqueParameter(); UniqueParameter *unique_param = new (std::nothrow) UniqueParameter();
if (parameter == nullptr) { if (unique_param == nullptr) {
MS_LOG(ERROR) << "new PopulateUniqueParam failed."; MS_LOG(ERROR) << "new PopulateUniqueParam failed.";
return nullptr; return nullptr;
} }
parameter->op_parameter_.type_ = primitive->Type(); unique_param->op_parameter_.type_ = primitive->Type();
return parameter; return reinterpret_cast<OpParameter *>(unique_param);
} }
DepthToSpaceParameter *PopulateDepthToSpaceParam(const lite::Primitive *primitive) { OpParameter *PopulateDepthToSpaceParameter(const lite::Primitive *primitive) {
DepthToSpaceParameter *parameter = new (std::nothrow) DepthToSpaceParameter(); DepthToSpaceParameter *depth_space_param = new (std::nothrow) DepthToSpaceParameter();
if (parameter == nullptr) { if (depth_space_param == nullptr) {
MS_LOG(ERROR) << "new DepthToSpaceParameter failed."; MS_LOG(ERROR) << "new DepthToSpaceParameter failed.";
return nullptr; return nullptr;
} }
auto param = primitive->Value()->value_as_DepthToSpace(); auto param = primitive->Value()->value_as_DepthToSpace();
parameter->op_parameter_.type_ = primitive->Type(); depth_space_param->op_parameter_.type_ = primitive->Type();
parameter->block_size_ = param->blockSize(); depth_space_param->block_size_ = param->blockSize();
return parameter; return reinterpret_cast<OpParameter *>(depth_space_param);
} }
SpaceToDepthParameter *PopulateSpaceToDepthParam(const lite::Primitive *primitive) { OpParameter *PopulateSpaceToDepthParameter(const lite::Primitive *primitive) {
SpaceToDepthParameter *parameter = new (std::nothrow) SpaceToDepthParameter(); SpaceToDepthParameter *space_depth_param = new (std::nothrow) SpaceToDepthParameter();
if (parameter == nullptr) { if (space_depth_param == nullptr) {
MS_LOG(ERROR) << "new SpaceToDepthParameter failed."; MS_LOG(ERROR) << "new SpaceToDepthspace_depth_param failed.";
return nullptr; return nullptr;
} }
space_depth_param->op_parameter_.type_ = primitive->Type();
auto param = primitive->Value()->value_as_DepthToSpace(); auto param = primitive->Value()->value_as_DepthToSpace();
parameter->op_parameter_.type_ = primitive->Type(); space_depth_param->op_parameter_.type_ = primitive->Type();
parameter->block_size_ = param->blockSize(); space_depth_param->block_size_ = param->blockSize();
if (param->format() != schema::Format_NHWC) { if (param->format() != schema::Format_NHWC) {
MS_LOG(ERROR) << "Currently only NHWC format is supported."; MS_LOG(ERROR) << "Currently only NHWC format is supported.";
return nullptr; return nullptr;
} }
return parameter; return reinterpret_cast<OpParameter *>(space_depth_param);
} }
ResizeParameter *PopulateResizeParameter(const lite::Primitive *primitive) { OpParameter *PopulateSpaceToBatchParameter(const lite::Primitive *primitive) {
ResizeParameter *parameter = new (std::nothrow) ResizeParameter(); SpaceToBatchParameter *space_batch_param = new (std::nothrow) SpaceToBatchParameter();
if (parameter == nullptr) { if (space_batch_param == nullptr) {
MS_LOG(ERROR) << "new SpaceToBatchParameter failed.";
return nullptr;
}
space_batch_param->op_parameter_.type_ = primitive->Type();
space_batch_param->op_parameter_.type_ = primitive->Type();
auto block_sizes = ((lite::SpaceToBatch *)primitive)->BlockSizes();
(void)memcpy(space_batch_param->block_sizes_, (block_sizes.data()), block_sizes.size() * sizeof(int));
auto paddings = ((lite::SpaceToBatch *)primitive)->Paddings();
(void)memcpy(space_batch_param->paddings_, (paddings.data()), paddings.size() * sizeof(int));
auto in_shape = ((lite::SpaceToBatch *)primitive)->InShape();
(void)memcpy(space_batch_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int));
auto padded_in_shape = ((lite::SpaceToBatch *)primitive)->PaddedInShape();
(void)memcpy(space_batch_param->padded_in_shape_, (padded_in_shape.data()), padded_in_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(space_batch_param);
}
OpParameter *PopulateResizeParameter(const lite::Primitive *primitive) {
ResizeParameter *resize_param = new (std::nothrow) ResizeParameter();
if (resize_param == nullptr) {
MS_LOG(ERROR) << "new ResizeParameter failed."; MS_LOG(ERROR) << "new ResizeParameter failed.";
return nullptr; return nullptr;
} }
resize_param->op_parameter_.type_ = primitive->Type();
auto param = primitive->Value()->value_as_Resize(); auto param = primitive->Value()->value_as_Resize();
parameter->method_ = param->method(); resize_param->method_ = param->method();
parameter->new_height_ = param->newHeight(); resize_param->new_height_ = param->newHeight();
parameter->new_width_ = param->newWidth(); resize_param->new_width_ = param->newWidth();
parameter->align_corners_ = param->alignCorners(); resize_param->align_corners_ = param->alignCorners();
parameter->preserve_aspect_ratio_ = param->preserveAspectRatio(); resize_param->preserve_aspect_ratio_ = param->preserveAspectRatio();
return parameter; return reinterpret_cast<OpParameter *>(resize_param);
} }
BatchToSpaceParameter *PopulateBatchToSpaceParameter(const lite::Primitive *primitive) { OpParameter *PopulateBatchToSpaceParameter(const lite::Primitive *primitive) {
BatchToSpaceParameter *parameter = new (std::nothrow) BatchToSpaceParameter(); BatchToSpaceParameter *batch_space_param = new (std::nothrow) BatchToSpaceParameter();
if (parameter == nullptr) { if (batch_space_param == nullptr) {
MS_LOG(ERROR) << "New BatchToSpaceParameter fail!"; MS_LOG(ERROR) << "New BatchToSpaceParameter fail!";
return nullptr; return nullptr;
} }
batch_space_param->op_parameter_.type_ = primitive->Type();
auto param = primitive->Value()->value_as_BatchToSpace(); auto param = primitive->Value()->value_as_BatchToSpace();
auto block_shape = param->blockShape(); auto block_shape = param->blockShape();
if (block_shape->size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { if (block_shape->size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) {
...@@ -928,308 +975,271 @@ BatchToSpaceParameter *PopulateBatchToSpaceParameter(const lite::Primitive *prim ...@@ -928,308 +975,271 @@ BatchToSpaceParameter *PopulateBatchToSpaceParameter(const lite::Primitive *prim
} }
for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) { for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
parameter->block_shape_[i] = block_shape->Get(i); batch_space_param->block_shape_[i] = block_shape->Get(i);
} }
for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) { for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) {
parameter->crops_[i] = crops->Get(i); batch_space_param->crops_[i] = crops->Get(i);
} }
return parameter; return reinterpret_cast<OpParameter *>(batch_space_param);
} }
CropParameter *PopulateCropParameter(const lite::Primitive *primitive) { OpParameter *PopulateCropParameter(const lite::Primitive *primitive) {
auto param = primitive->Value()->value_as_Crop(); auto param = primitive->Value()->value_as_Crop();
auto param_offset = param->offsets(); auto param_offset = param->offsets();
if (param_offset->size() > CROP_OFFSET_MAX_SIZE) { if (param_offset->size() > CROP_OFFSET_MAX_SIZE) {
MS_LOG(ERROR) << "parameter offset size(" << param_offset->size() << ") should <= " << CROP_OFFSET_MAX_SIZE; MS_LOG(ERROR) << "crop_param offset size(" << param_offset->size() << ") should <= " << CROP_OFFSET_MAX_SIZE;
return nullptr; return nullptr;
} }
CropParameter *parameter = new (std::nothrow) CropParameter(); CropParameter *crop_param = new (std::nothrow) CropParameter();
if (parameter == nullptr) { if (crop_param == nullptr) {
MS_LOG(ERROR) << "new CropParameter fail!"; MS_LOG(ERROR) << "new CropParameter fail!";
return nullptr; return nullptr;
} }
parameter->axis_ = param->axis(); crop_param->op_parameter_.type_ = primitive->Type();
parameter->offset_size_ = param_offset->size(); crop_param->axis_ = param->axis();
crop_param->offset_size_ = param_offset->size();
for (int i = 0; i < param_offset->size(); ++i) { for (int i = 0; i < param_offset->size(); ++i) {
parameter->offset_[i] = param_offset->Get(i); crop_param->offset_[i] = param_offset->Get(i);
} }
return parameter; return reinterpret_cast<OpParameter *>(crop_param);
} }
OneHotParameter *PopulateOneHotParameter(const lite::Primitive *primitive) { OpParameter *PopulateOneHotParameter(const lite::Primitive *primitive) {
OneHotParameter *parameter = new (std::nothrow) OneHotParameter(); OneHotParameter *one_hot_param = new (std::nothrow) OneHotParameter();
if (parameter == nullptr) { if (one_hot_param == nullptr) {
MS_LOG(ERROR) << "new OneHotParameter fail!"; MS_LOG(ERROR) << "new OneHotParameter fail!";
return nullptr; return nullptr;
} }
one_hot_param->op_parameter_.type_ = primitive->Type();
auto param = primitive->Value()->value_as_OneHot(); auto param = primitive->Value()->value_as_OneHot();
if (param == nullptr) { if (param == nullptr) {
delete (parameter); delete (one_hot_param);
MS_LOG(ERROR) << "get OneHot param nullptr."; MS_LOG(ERROR) << "get OneHot param nullptr.";
return nullptr; return nullptr;
} }
parameter->axis_ = param->axis(); one_hot_param->axis_ = param->axis();
return parameter; return reinterpret_cast<OpParameter *>(one_hot_param);
} }
FlattenParameter *PopulateFlattenParameter(const lite::Primitive *primitive) { OpParameter *PopulateFlattenParameter(const lite::Primitive *primitive) {
FlattenParameter *parameter = new (std::nothrow) FlattenParameter(); FlattenParameter *flatten_param = new (std::nothrow) FlattenParameter();
if (parameter == nullptr) { if (flatten_param == nullptr) {
MS_LOG(ERROR) << "new FlattenParameter fail!"; MS_LOG(ERROR) << "new FlattenParameter fail!";
return nullptr; return nullptr;
} }
return parameter; flatten_param->op_parameter_.type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(flatten_param);
} }
DequantizeParameter *PopulateDequantizeParameter(const lite::Primitive *primitive) { OpParameter *PopulateDequantizeParameter(const lite::Primitive *primitive) {
DequantizeParameter *parameter = new (std::nothrow) DequantizeParameter(); DequantizeParameter *dequantize_parameter = new (std::nothrow) DequantizeParameter();
if (parameter == nullptr) { if (dequantize_parameter == nullptr) {
MS_LOG(ERROR) << "new DequantizeParameter fail!"; MS_LOG(ERROR) << "new DequantizeParameter fail!";
return nullptr; return nullptr;
} }
return parameter; dequantize_parameter->op_parameter_.type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(dequantize_parameter);
} }
QuantizeParameter *PopulateQuantizeParameter(const lite::Primitive *primitive) { OpParameter *PopulateQuantizeParameter(const lite::Primitive *primitive) {
QuantizeParameter *parameter = new (std::nothrow) QuantizeParameter(); QuantizeParameter *quantize_parameter = new (std::nothrow) QuantizeParameter();
if (parameter == nullptr) { if (quantize_parameter == nullptr) {
MS_LOG(ERROR) << "new QuantizeParameter fail!"; MS_LOG(ERROR) << "new QuantizeParameter fail!";
return nullptr; return nullptr;
} }
return parameter; quantize_parameter->op_parameter_.type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(quantize_parameter);
} }
StridedSliceParameter *PopulateStridedSliceParam(const lite::Primitive *primitive) { OpParameter *PopulateStridedSliceParameter(const lite::Primitive *primitive) {
StridedSliceParameter *parameter = new (std::nothrow) StridedSliceParameter(); StridedSliceParameter *strided_slice_param = new (std::nothrow) StridedSliceParameter();
if (parameter == nullptr) { if (strided_slice_param == nullptr) {
MS_LOG(ERROR) << "new StridedSliceParameter failed."; MS_LOG(ERROR) << "new StridedSliceParameter failed.";
return nullptr; return nullptr;
} }
parameter->op_parameter_.type_ = primitive->Type(); strided_slice_param->op_parameter_.type_ = primitive->Type();
auto n_dims = ((lite::StridedSlice *)primitive)->NDims(); auto n_dims = ((lite::StridedSlice *)primitive)->NDims();
parameter->num_axes_ = n_dims; strided_slice_param->num_axes_ = n_dims;
auto begin = ((lite::StridedSlice *)primitive)->UpdatedBegins(); auto begin = ((lite::StridedSlice *)primitive)->UpdatedBegins();
(void)memcpy(parameter->begins_, (begin.data()), begin.size() * sizeof(int)); (void)memcpy(strided_slice_param->begins_, (begin.data()), begin.size() * sizeof(int));
auto end = ((lite::StridedSlice *)primitive)->UpdatedEnds(); auto end = ((lite::StridedSlice *)primitive)->UpdatedEnds();
(void)memcpy(parameter->ends_, (end.data()), end.size() * sizeof(int)); (void)memcpy(strided_slice_param->ends_, (end.data()), end.size() * sizeof(int));
auto stride = ((lite::StridedSlice *)primitive)->UpdatedStrides(); auto stride = ((lite::StridedSlice *)primitive)->UpdatedStrides();
(void)memcpy(parameter->strides_, (stride.data()), stride.size() * sizeof(int)); (void)memcpy(strided_slice_param->strides_, (stride.data()), stride.size() * sizeof(int));
auto in_shape = ((lite::StridedSlice *)primitive)->UpdatedInShape(); auto in_shape = ((lite::StridedSlice *)primitive)->UpdatedInShape();
(void)memcpy(parameter->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); (void)memcpy(strided_slice_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int));
return parameter; return reinterpret_cast<OpParameter *>(strided_slice_param);
} }
OpParameter *PopulateAddNParam(const lite::Primitive *primitive) { OpParameter *PopulateAddNParameter(const lite::Primitive *primitive) {
auto parameter = new (std::nothrow) OpParameter(); auto addn_param = new (std::nothrow) OpParameter();
if (parameter == nullptr) { if (addn_param == nullptr) {
MS_LOG(ERROR) << "new OpParameter fail!"; MS_LOG(ERROR) << "new OpParameter fail!";
return nullptr; return nullptr;
} }
parameter->type_ = primitive->Type(); addn_param->type_ = primitive->Type();
return parameter; return reinterpret_cast<OpParameter *>(addn_param);
} }
PriorBoxParameter *PopulatePriorBoxParameter(const lite::Primitive *primitive) { OpParameter *PopulatePriorBoxParameter(const lite::Primitive *primitive) {
PriorBoxParameter *param = new (std::nothrow) PriorBoxParameter(); PriorBoxParameter *prior_box_param = new (std::nothrow) PriorBoxParameter();
if (param == nullptr) { if (prior_box_param == nullptr) {
MS_LOG(ERROR) << "new PriorBoxParameter failed."; MS_LOG(ERROR) << "new PriorBoxParameter failed.";
return nullptr; return nullptr;
} }
param->op_parameter_.type_ = primitive->Type(); prior_box_param->op_parameter_.type_ = primitive->Type();
auto prior_box_param = primitive->Value()->value_as_PriorBox(); auto prior_box_attr = primitive->Value()->value_as_PriorBox();
if (prior_box_param->min_sizes()->size() > PRIOR_BOX_MAX_NUM) { if (prior_box_attr->min_sizes()->size() > PRIOR_BOX_MAX_NUM) {
MS_LOG(ERROR) << "PriorBox min_sizes size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got " MS_LOG(ERROR) << "PriorBox min_sizes size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got "
<< prior_box_param->min_sizes(); << prior_box_attr->min_sizes();
delete (param); delete (prior_box_param);
return nullptr; return nullptr;
} }
param->min_sizes_size = prior_box_param->min_sizes()->size(); prior_box_param->min_sizes_size = prior_box_attr->min_sizes()->size();
if (prior_box_param->max_sizes()->size() > PRIOR_BOX_MAX_NUM) { if (prior_box_attr->max_sizes()->size() > PRIOR_BOX_MAX_NUM) {
MS_LOG(ERROR) << "PriorBox max_sizes size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got " MS_LOG(ERROR) << "PriorBox max_sizes size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got "
<< prior_box_param->max_sizes(); << prior_box_attr->max_sizes();
delete (param); delete (prior_box_param);
return nullptr; return nullptr;
} }
param->max_sizes_size = prior_box_param->max_sizes()->size(); prior_box_param->max_sizes_size = prior_box_attr->max_sizes()->size();
(void)memcpy(param->max_sizes, prior_box_param->max_sizes()->data(), (void)memcpy(prior_box_param->max_sizes, prior_box_attr->max_sizes()->data(),
prior_box_param->max_sizes()->size() * sizeof(int32_t)); prior_box_attr->max_sizes()->size() * sizeof(int32_t));
(void)memcpy(param->min_sizes, prior_box_param->min_sizes()->data(), (void)memcpy(prior_box_param->min_sizes, prior_box_attr->min_sizes()->data(),
prior_box_param->min_sizes()->size() * sizeof(int32_t)); prior_box_attr->min_sizes()->size() * sizeof(int32_t));
if (prior_box_param->aspect_ratios()->size() > PRIOR_BOX_MAX_NUM) { if (prior_box_attr->aspect_ratios()->size() > PRIOR_BOX_MAX_NUM) {
MS_LOG(ERROR) << "PriorBox aspect_ratios size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got " MS_LOG(ERROR) << "PriorBox aspect_ratios size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got "
<< prior_box_param->aspect_ratios(); << prior_box_attr->aspect_ratios();
delete (param); delete (prior_box_param);
return nullptr; return nullptr;
} }
param->aspect_ratios_size = prior_box_param->aspect_ratios()->size(); prior_box_param->aspect_ratios_size = prior_box_attr->aspect_ratios()->size();
(void)memcpy(param->aspect_ratios, prior_box_param->aspect_ratios()->data(), (void)memcpy(prior_box_param->aspect_ratios, prior_box_attr->aspect_ratios()->data(),
prior_box_param->aspect_ratios()->size() * sizeof(float)); prior_box_attr->aspect_ratios()->size() * sizeof(float));
if (prior_box_param->variances()->size() != PRIOR_BOX_VAR_NUM) { if (prior_box_attr->variances()->size() != PRIOR_BOX_VAR_NUM) {
MS_LOG(ERROR) << "PriorBox variances size should be " << PRIOR_BOX_VAR_NUM << ", got " MS_LOG(ERROR) << "PriorBox variances size should be " << PRIOR_BOX_VAR_NUM << ", got "
<< prior_box_param->variances()->size(); << prior_box_attr->variances()->size();
delete (param); delete (prior_box_param);
return nullptr; return nullptr;
} }
(void)memcpy(param->variances, prior_box_param->variances()->data(), PRIOR_BOX_VAR_NUM * sizeof(float)); (void)memcpy(prior_box_param->variances, prior_box_attr->variances()->data(), PRIOR_BOX_VAR_NUM * sizeof(float));
param->flip = prior_box_param->flip(); prior_box_param->flip = prior_box_attr->flip();
param->clip = prior_box_param->clip(); prior_box_param->clip = prior_box_attr->clip();
param->offset = prior_box_param->offset(); prior_box_param->offset = prior_box_attr->offset();
param->image_size_h = prior_box_param->image_size_h(); prior_box_param->image_size_h = prior_box_attr->image_size_h();
param->image_size_w = prior_box_param->image_size_w(); prior_box_param->image_size_w = prior_box_attr->image_size_w();
param->step_h = prior_box_param->step_h(); prior_box_param->step_h = prior_box_attr->step_h();
param->step_w = prior_box_param->step_w(); prior_box_param->step_w = prior_box_attr->step_w();
return param; return reinterpret_cast<OpParameter *>(prior_box_param);
} }
SpaceToBatchParameter *PopulateSpaceToBatchParam(const lite::Primitive *primitive) { PopulateParameterRegistry::PopulateParameterRegistry() {
SpaceToBatchParameter *parameter = new (std::nothrow) SpaceToBatchParameter(); populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter;
if (parameter == nullptr) { populate_parameter_funcs_[schema::PrimitiveType_Activation] = PopulateActivationParameter;
MS_LOG(ERROR) << "new SpaceToBatchParameter failed."; populate_parameter_funcs_[schema::PrimitiveType_Conv2D] = PopulateConvParameter;
return nullptr; populate_parameter_funcs_[schema::PrimitiveType_Reduce] = PopulateReduceParameter;
} populate_parameter_funcs_[schema::PrimitiveType_Pooling] = PopulatePoolingParameter;
parameter->op_parameter_.type_ = primitive->Type(); populate_parameter_funcs_[schema::PrimitiveType_DepthwiseConv2D] = PopulateConvDwParameter;
auto block_sizes = ((lite::SpaceToBatch *)primitive)->BlockSizes(); populate_parameter_funcs_[schema::PrimitiveType_DeDepthwiseConv2D] = PopulateDeconvDwParameter;
(void)memcpy(parameter->block_sizes_, (block_sizes.data()), block_sizes.size() * sizeof(int)); populate_parameter_funcs_[schema::PrimitiveType_DeConv2D] = PopulateDeconvParameter;
auto paddings = ((lite::SpaceToBatch *)primitive)->Paddings(); populate_parameter_funcs_[schema::PrimitiveType_FusedBatchNorm] = PopulateFusedBatchNorm;
(void)memcpy(parameter->paddings_, (paddings.data()), paddings.size() * sizeof(int)); populate_parameter_funcs_[schema::PrimitiveType_FullConnection] = PopulateFullconnectionParameter;
auto in_shape = ((lite::SpaceToBatch *)primitive)->InShape(); populate_parameter_funcs_[schema::PrimitiveType_Power] = PopulatePowerParameter;
(void)memcpy(parameter->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); populate_parameter_funcs_[schema::PrimitiveType_LocalResponseNormalization] = PopulateLocalResponseNormParameter;
auto padded_in_shape = ((lite::SpaceToBatch *)primitive)->PaddedInShape(); populate_parameter_funcs_[schema::PrimitiveType_Range] = PopulateRangeParameter;
(void)memcpy(parameter->padded_in_shape_, (padded_in_shape.data()), padded_in_shape.size() * sizeof(int)); populate_parameter_funcs_[schema::PrimitiveType_Transpose] = PopulateTransposeParameter;
return parameter; populate_parameter_funcs_[schema::PrimitiveType_Mul] = PopulateArithmetic;
populate_parameter_funcs_[schema::PrimitiveType_Add] = PopulateArithmetic;
populate_parameter_funcs_[schema::PrimitiveType_Sub] = PopulateArithmetic;
populate_parameter_funcs_[schema::PrimitiveType_Div] = PopulateArithmetic;
populate_parameter_funcs_[schema::PrimitiveType_FloorDiv] = PopulateArithmetic;
populate_parameter_funcs_[schema::PrimitiveType_FloorMod] = PopulateArithmetic;
populate_parameter_funcs_[schema::PrimitiveType_SquaredDifference] = PopulateArithmetic;
populate_parameter_funcs_[schema::PrimitiveType_BiasAdd] = PopulateArithmetic;
populate_parameter_funcs_[schema::PrimitiveType_Eltwise] = PopulateEltwiseParameter;
populate_parameter_funcs_[schema::PrimitiveType_ExpandDims] = PopulateExpandDimsParameter;
populate_parameter_funcs_[schema::PrimitiveType_Abs] = PopulateArithmeticSelf;
populate_parameter_funcs_[schema::PrimitiveType_Cos] = PopulateArithmeticSelf;
populate_parameter_funcs_[schema::PrimitiveType_Sin] = PopulateArithmeticSelf;
populate_parameter_funcs_[schema::PrimitiveType_Exp] = PopulateArithmeticSelf;
populate_parameter_funcs_[schema::PrimitiveType_Log] = PopulateArithmeticSelf;
populate_parameter_funcs_[schema::PrimitiveType_Square] = PopulateArithmeticSelf;
populate_parameter_funcs_[schema::PrimitiveType_Sqrt] = PopulateArithmeticSelf;
populate_parameter_funcs_[schema::PrimitiveType_Rsqrt] = PopulateArithmeticSelf;
populate_parameter_funcs_[schema::PrimitiveType_LogicalNot] = PopulateArithmeticSelf;
populate_parameter_funcs_[schema::PrimitiveType_Floor] = PopulateArithmeticSelf;
populate_parameter_funcs_[schema::PrimitiveType_Ceil] = PopulateArithmeticSelf;
populate_parameter_funcs_[schema::PrimitiveType_ArgMax] = PopulateArgMaxParameter;
populate_parameter_funcs_[schema::PrimitiveType_ArgMin] = PopulateArgMinParameter;
populate_parameter_funcs_[schema::PrimitiveType_Cast] = PopulateCastParameter;
populate_parameter_funcs_[schema::PrimitiveType_Scale] = PopulateScaleParameter;
populate_parameter_funcs_[schema::PrimitiveType_Reshape] = PopulateReshapeParameter;
populate_parameter_funcs_[schema::PrimitiveType_Concat] = PopulateConcatParameter;
populate_parameter_funcs_[schema::PrimitiveType_Tile] = PopulateTileParameter;
populate_parameter_funcs_[schema::PrimitiveType_TopK] = PopulateTopKParameter;
populate_parameter_funcs_[schema::PrimitiveType_Fill] = PopulateFillParameter;
populate_parameter_funcs_[schema::PrimitiveType_Gather] = PopulateGatherParameter;
populate_parameter_funcs_[schema::PrimitiveType_GatherNd] = PopulateGatherNdParameter;
populate_parameter_funcs_[schema::PrimitiveType_Slice] = PopulateSliceParameter;
populate_parameter_funcs_[schema::PrimitiveType_BroadcastTo] = PopulateBroadcastToParameter;
populate_parameter_funcs_[schema::PrimitiveType_Reverse] = PopulateReverseParameter;
populate_parameter_funcs_[schema::PrimitiveType_Stack] = PopulateStackParameter;
populate_parameter_funcs_[schema::PrimitiveType_Unstack] = PopulateUnstackParameter;
populate_parameter_funcs_[schema::PrimitiveType_ReverseSequence] = PopulateReverseSequenceParameter;
populate_parameter_funcs_[schema::PrimitiveType_Unique] = PopulateUniqueParameter;
populate_parameter_funcs_[schema::PrimitiveType_DepthToSpace] = PopulateDepthToSpaceParameter;
populate_parameter_funcs_[schema::PrimitiveType_Nchw2Nhwc] = PopulateNchw2NhwcParameter;
populate_parameter_funcs_[schema::PrimitiveType_Nhwc2Nchw] = PopulateNhwc2NchwParameter;
populate_parameter_funcs_[schema::PrimitiveType_Pad] = PopulatePadParameter;
populate_parameter_funcs_[schema::PrimitiveType_Resize] = PopulateResizeParameter;
populate_parameter_funcs_[schema::PrimitiveType_BatchToSpace] = PopulateBatchToSpaceParameter;
populate_parameter_funcs_[schema::PrimitiveType_SpaceToDepth] = PopulateSpaceToDepthParameter;
populate_parameter_funcs_[schema::PrimitiveType_SpaceToBatch] = PopulateSpaceToBatchParameter;
populate_parameter_funcs_[schema::PrimitiveType_Crop] = PopulateCropParameter;
populate_parameter_funcs_[schema::PrimitiveType_Unsqueeze] = PopulateUnsqueezeParameter;
populate_parameter_funcs_[schema::PrimitiveType_Flatten] = PopulateFlattenParameter;
populate_parameter_funcs_[schema::PrimitiveType_MatMul] = PopulateMatMulParameter;
populate_parameter_funcs_[schema::PrimitiveType_OneHot] = PopulateOneHotParameter;
populate_parameter_funcs_[schema::PrimitiveType_AddN] = PopulateAddNParameter;
populate_parameter_funcs_[schema::PrimitiveType_StridedSlice] = PopulateStridedSliceParameter;
populate_parameter_funcs_[schema::PrimitiveType_ScatterND] = PopulateScatterNDParameter;
populate_parameter_funcs_[schema::PrimitiveType_Square] = PopulateSqueezeParameter;
populate_parameter_funcs_[schema::PrimitiveType_Split] = PopulateSplitParameter;
populate_parameter_funcs_[schema::PrimitiveType_PriorBox] = PopulatePriorBoxParameter;
populate_parameter_funcs_[schema::PrimitiveType_OnnxInt8Dequantize] = PopulateDequantizeParameter;
populate_parameter_funcs_[schema::PrimitiveType_OnnxInt8Quantize] = PopulateQuantizeParameter;
}
PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {
static PopulateParameterRegistry populate_parameter_instance;
return &populate_parameter_instance;
}
PopulateParameterFunc PopulateParameterRegistry::GetParameterFunc(const schema::PrimitiveType &type) {
return populate_parameter_funcs_[type];
} }
OpParameter *PopulateParameter(const lite::Primitive *primitive) { OpParameter *PopulateParameter(const lite::Primitive *primitive) {
MS_EXCEPTION_IF_NULL(primitive); if (primitive == nullptr) {
MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op.";
return nullptr;
}
auto op_type = primitive->Type(); auto op_type = primitive->Type();
switch (op_type) { auto func = PopulateParameterRegistry::GetInstance()->GetParameterFunc(op_type);
case schema::PrimitiveType_SoftMax: if (func == nullptr) {
return reinterpret_cast<OpParameter *>(PopulateSoftmaxParameter(primitive)); MS_LOG(ERROR) << "Get nullptr for Op Parameter Func.";
case schema::PrimitiveType_Activation: return nullptr;
return reinterpret_cast<OpParameter *>(PopulateActivationParameter(primitive));
case schema::PrimitiveType_Conv2D:
return reinterpret_cast<OpParameter *>(PopulateConvParameter(primitive));
case schema::PrimitiveType_Reduce:
return reinterpret_cast<OpParameter *>(PopulateReduceParameter(primitive));
case schema::PrimitiveType_Pooling:
return reinterpret_cast<OpParameter *>(PopulatePoolingParam(primitive));
case schema::PrimitiveType_DepthwiseConv2D:
return reinterpret_cast<OpParameter *>(PopulateConvDwParameter(primitive));
case schema::PrimitiveType_DeDepthwiseConv2D:
return reinterpret_cast<OpParameter *>(PopulateDeconvDwParameter(primitive));
case schema::PrimitiveType_DeConv2D:
return reinterpret_cast<OpParameter *>(PopulateDeconvParameter(primitive));
case schema::PrimitiveType_FusedBatchNorm:
return reinterpret_cast<OpParameter *>(PopulateFusedBatchNorm(primitive));
case schema::PrimitiveType_FullConnection:
return reinterpret_cast<OpParameter *>(PopulateFullconnectionParameter(primitive));
case schema::PrimitiveType_Power:
return reinterpret_cast<OpParameter *>(PopulatePowerParameter(primitive));
case schema::PrimitiveType_LocalResponseNormalization:
return reinterpret_cast<OpParameter *>(PopulateLocalResponseNormParameter(primitive));
case schema::PrimitiveType_Range:
return reinterpret_cast<OpParameter *>(PopulateRangeParameter(primitive));
case schema::PrimitiveType_Transpose:
return reinterpret_cast<OpParameter *>(PopulateTransposeParameter(primitive));
case schema::PrimitiveType_Mul:
case schema::PrimitiveType_Add:
case schema::PrimitiveType_Sub:
case schema::PrimitiveType_Div:
case schema::PrimitiveType_FloorDiv:
case schema::PrimitiveType_FloorMod:
case schema::PrimitiveType_SquaredDifference:
return reinterpret_cast<OpParameter *>(PopulateArithmetic(primitive));
case schema::PrimitiveType_BiasAdd:
return reinterpret_cast<OpParameter *>(new ArithmeticParameter());
case schema::PrimitiveType_Eltwise:
return reinterpret_cast<OpParameter *>(PopulateEltwiseParam(primitive));
case schema::PrimitiveType_ExpandDims:
return reinterpret_cast<OpParameter *>(PopulateExpandDimsParam(primitive));
case schema::PrimitiveType_Abs:
case schema::PrimitiveType_Cos:
case schema::PrimitiveType_Sin:
case schema::PrimitiveType_Exp:
case schema::PrimitiveType_Log:
case schema::PrimitiveType_Square:
case schema::PrimitiveType_Sqrt:
case schema::PrimitiveType_Rsqrt:
case schema::PrimitiveType_LogicalNot:
case schema::PrimitiveType_Floor:
return reinterpret_cast<OpParameter *>(PopulateArithmeticSelf(primitive));
case schema::PrimitiveType_ArgMax:
return reinterpret_cast<OpParameter *>(PopulateArgMaxParam(primitive));
case schema::PrimitiveType_ArgMin:
return reinterpret_cast<OpParameter *>(PopulateArgMinParam(primitive));
case schema::PrimitiveType_Cast:
return reinterpret_cast<OpParameter *>(PopulateCastParam(primitive));
case schema::PrimitiveType_Ceil:
return reinterpret_cast<OpParameter *>(PopulateCeilParameter(primitive));
case schema::PrimitiveType_Scale:
return reinterpret_cast<OpParameter *>(PopulateScaleParameter(primitive));
case schema::PrimitiveType_Reshape:
return reinterpret_cast<OpParameter *>(PopulateReshapeParam(primitive));
case schema::PrimitiveType_Concat:
return reinterpret_cast<OpParameter *>(PopulateConcatParameter(primitive));
case schema::PrimitiveType_Tile:
return reinterpret_cast<OpParameter *>(PopulateTileParameter(primitive));
case schema::PrimitiveType_TopK:
return reinterpret_cast<OpParameter *>(PopulateTopKParameter(primitive));
case schema::PrimitiveType_Fill:
return reinterpret_cast<OpParameter *>(PopulateFillParam(primitive));
case schema::PrimitiveType_Gather:
return reinterpret_cast<OpParameter *>(PopulateGatherParameter(primitive));
case schema::PrimitiveType_GatherNd:
return reinterpret_cast<OpParameter *>(PopulateGatherNdParameter(primitive));
case schema::PrimitiveType_Slice:
return reinterpret_cast<OpParameter *>(PopulateSliceParam(primitive));
case schema::PrimitiveType_BroadcastTo:
return reinterpret_cast<OpParameter *>(PopulateBroadcastToParam(primitive));
case schema::PrimitiveType_Reverse:
return reinterpret_cast<OpParameter *>(PopulateReverseParameter(primitive));
case schema::PrimitiveType_Stack:
return reinterpret_cast<OpParameter *>(PopulateStackParam(primitive));
case schema::PrimitiveType_Unstack:
return reinterpret_cast<OpParameter *>(PopulateUnstackParam(primitive));
case schema::PrimitiveType_ReverseSequence:
return reinterpret_cast<OpParameter *>(PopulateReverseSequenceParam(primitive));
case schema::PrimitiveType_Unique:
return reinterpret_cast<OpParameter *>(PopulateUniqueParam(primitive));
case schema::PrimitiveType_DepthToSpace:
return reinterpret_cast<OpParameter *>(PopulateDepthToSpaceParam(primitive));
case schema::PrimitiveType_Nchw2Nhwc:
return reinterpret_cast<OpParameter *>(PopulateNchw2NhwcParameter(primitive));
case schema::PrimitiveType_Nhwc2Nchw:
return reinterpret_cast<OpParameter *>(PopulateNhwc2NchwParameter(primitive));
case schema::PrimitiveType_Pad:
return reinterpret_cast<OpParameter *>(PopulatePadParameter(primitive));
case schema::PrimitiveType_Resize:
return reinterpret_cast<OpParameter *>(PopulateResizeParameter(primitive));
case schema::PrimitiveType_BatchToSpace:
return reinterpret_cast<OpParameter *>(PopulateBatchToSpaceParameter(primitive));
case schema::PrimitiveType_Crop:
return reinterpret_cast<OpParameter *>(PopulateCropParameter(primitive));
case schema::PrimitiveType_Unsqueeze:
return reinterpret_cast<OpParameter *>(PopulateUnsqueezeParameter(primitive));
case schema::PrimitiveType_Flatten:
return reinterpret_cast<OpParameter *>(PopulateFlattenParameter(primitive));
case schema::PrimitiveType_MatMul:
return reinterpret_cast<OpParameter *>(PopulateMatMulParameter(primitive));
case schema::PrimitiveType_OneHot:
return reinterpret_cast<OpParameter *>(PopulateOneHotParameter(primitive));
case schema::PrimitiveType_AddN:
return reinterpret_cast<OpParameter *>(PopulateAddNParam(primitive));
case schema::PrimitiveType_PriorBox:
return reinterpret_cast<OpParameter *>(PopulatePriorBoxParameter(primitive));
case schema::PrimitiveType_OnnxInt8Dequantize:
return reinterpret_cast<OpParameter *>(PopulateDequantizeParameter(primitive));
case schema::PrimitiveType_OnnxInt8Quantize:
return reinterpret_cast<OpParameter *>(PopulateQuantizeParameter(primitive));
default:
break;
} }
auto *parameter = func(primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "Get nullptr for Op Parameter.";
return nullptr; return nullptr;
}
return parameter;
} }
} // namespace mindspore::kernel } // namespace mindspore::kernel
...@@ -22,7 +22,20 @@ ...@@ -22,7 +22,20 @@
#include "src/runtime/kernel/arm/opclib/op_base.h" #include "src/runtime/kernel/arm/opclib/op_base.h"
namespace mindspore::kernel { namespace mindspore::kernel {
typedef OpParameter *(*PopulateParameterFunc)(const lite::Primitive *);
class PopulateParameterRegistry {
public:
PopulateParameterRegistry();
~PopulateParameterRegistry() = default;
static PopulateParameterRegistry *GetInstance();
PopulateParameterFunc GetParameterFunc(const schema::PrimitiveType &type);
protected:
PopulateParameterFunc populate_parameter_funcs_[schema::PrimitiveType_MAX + 1];
};
OpParameter *PopulateParameter(const lite::Primitive *primitive); OpParameter *PopulateParameter(const lite::Primitive *primitive);
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_ #endif // MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "src/runtime/kernel/arm/opclib/arithmetic_common.h" #include "src/runtime/kernel/arm/opclib/arithmetic_common.h"
struct ArithmeticParameter { struct ArithmeticParameter {
OpParameter op_parameter; OpParameter op_parameter_;
bool broadcasting_; bool broadcasting_;
size_t ndim_; size_t ndim_;
int in_shape0_[5]; int in_shape0_[5];
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "src/runtime/kernel/arm/opclib/op_base.h" #include "src/runtime/kernel/arm/opclib/op_base.h"
struct SoftmaxParameter { struct SoftmaxParameter {
OpParameter op_parameter; OpParameter op_parameter_;
int32_t axis_; int32_t axis_;
int element_size_; int element_size_;
int n_dim_; int n_dim_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册