diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 6161a92eb0d8ae93d117e66186207b9d86ca775c..46a6dcd0453400d76509470275277d9b107738b0 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -69,320 +69,339 @@ #include "src/runtime/kernel/arm/opclib/fp32/quantize.h" namespace mindspore::kernel { -FillParameter *PopulateFillParam(const lite::Primitive *primitive) { +OpParameter *PopulateFillParameter(const lite::Primitive *primitive) { auto param = primitive->Value()->value_as_Fill(); - FillParameter *parameter = new (std::nothrow) FillParameter(); - if (parameter == nullptr) { + FillParameter *fill_param = new (std::nothrow) FillParameter(); + if (fill_param == nullptr) { MS_LOG(ERROR) << "new FillParameter failed."; return nullptr; } + fill_param->op_parameter_.type_ = primitive->Type(); auto flatDims = param->dims(); - parameter->num_dims_ = flatDims->size(); + fill_param->num_dims_ = flatDims->size(); int i = 0; for (auto iter = flatDims->begin(); iter != flatDims->end(); iter++) { - parameter->dims_[i++] = *iter; + fill_param->dims_[i++] = *iter; } - return parameter; + return reinterpret_cast(fill_param); } -ExpandDimsParameter *PopulateExpandDimsParam(const lite::Primitive *primitive) { +OpParameter *PopulateExpandDimsParameter(const lite::Primitive *primitive) { auto param = primitive->Value()->value_as_ExpandDims(); - ExpandDimsParameter *parameter = new (std::nothrow) ExpandDimsParameter(); - if (parameter == nullptr) { + ExpandDimsParameter *expand_dims_param = new (std::nothrow) ExpandDimsParameter(); + if (expand_dims_param == nullptr) { MS_LOG(ERROR) << "new ExpandDimsParameter failed."; return nullptr; } - parameter->dim_ = param->dim(); - return parameter; + expand_dims_param->op_parameter_.type_ = primitive->Type(); + expand_dims_param->dim_ = param->dim(); + return reinterpret_cast(expand_dims_param); } -PoolingParameter *PopulatePoolingParam(const lite::Primitive *primitive) { +OpParameter *PopulatePoolingParameter(const lite::Primitive *primitive) { auto pooling_primitive = primitive->Value()->value_as_Pooling(); // todo use malloc instead - PoolingParameter *parameter = new (std::nothrow) PoolingParameter(); - if (parameter == nullptr) { + PoolingParameter *pooling_param = new (std::nothrow) PoolingParameter(); + if (pooling_param == nullptr) { MS_LOG(ERROR) << "new PoolingParameter failed."; return nullptr; } - parameter->global_ = pooling_primitive->global(); - parameter->window_w_ = pooling_primitive->windowW(); - parameter->window_h_ = pooling_primitive->windowH(); + pooling_param->op_parameter_.type_ = primitive->Type(); + pooling_param->global_ = pooling_primitive->global(); + pooling_param->window_w_ = pooling_primitive->windowW(); + pooling_param->window_h_ = pooling_primitive->windowH(); // todo format auto pooling_lite_primitive = (lite::Pooling *)primitive; MS_ASSERT(nullptr != pooling_lite_primitive); - parameter->pad_u_ = pooling_lite_primitive->PadUp(); - parameter->pad_d_ = pooling_lite_primitive->PadDown(); - parameter->pad_l_ = pooling_lite_primitive->PadLeft(); - parameter->pad_r_ = pooling_lite_primitive->PadRight(); - parameter->stride_w_ = pooling_primitive->strideW(); - parameter->stride_h_ = pooling_primitive->strideH(); + pooling_param->pad_u_ = pooling_lite_primitive->PadUp(); + pooling_param->pad_d_ = pooling_lite_primitive->PadDown(); + pooling_param->pad_l_ = pooling_lite_primitive->PadLeft(); + pooling_param->pad_r_ = pooling_lite_primitive->PadRight(); + pooling_param->stride_w_ = pooling_primitive->strideW(); + pooling_param->stride_h_ = pooling_primitive->strideH(); auto pool_mode = pooling_primitive->poolingMode(); switch (pool_mode) { case schema::PoolMode_MAX_POOLING: - parameter->max_pooling_ = true; - parameter->avg_pooling_ = false; + pooling_param->max_pooling_ = true; + pooling_param->avg_pooling_ = false; break; case schema::PoolMode_MEAN_POOLING: - parameter->max_pooling_ = false; - parameter->avg_pooling_ = true; + pooling_param->max_pooling_ = false; + pooling_param->avg_pooling_ = true; break; default: - parameter->max_pooling_ = false; - parameter->avg_pooling_ = false; + pooling_param->max_pooling_ = false; + pooling_param->avg_pooling_ = false; break; } auto round_mode = pooling_primitive->roundMode(); switch (round_mode) { case schema::RoundMode_FLOOR: - parameter->round_floor_ = true; - parameter->round_ceil_ = false; + pooling_param->round_floor_ = true; + pooling_param->round_ceil_ = false; break; case schema::RoundMode_CEIL: - parameter->round_floor_ = false; - parameter->round_ceil_ = true; + pooling_param->round_floor_ = false; + pooling_param->round_ceil_ = true; break; default: - parameter->round_floor_ = false; - parameter->round_ceil_ = false; + pooling_param->round_floor_ = false; + pooling_param->round_ceil_ = false; break; } - return parameter; + return reinterpret_cast(pooling_param); } -MatMulParameter *PopulateFullconnectionParameter(const lite::Primitive *primitive) { +OpParameter *PopulateFullconnectionParameter(const lite::Primitive *primitive) { auto param = primitive->Value()->value_as_FullConnection(); - MatMulParameter *parameter = new (std::nothrow) MatMulParameter(); - if (parameter == nullptr) { + MatMulParameter *matmul_param = new (std::nothrow) MatMulParameter(); + if (matmul_param == nullptr) { MS_LOG(ERROR) << "new FullconnectionParameter failed."; return nullptr; } - parameter->b_transpose_ = true; - parameter->a_transpose_ = false; - parameter->has_bias_ = param->hasBias(); - parameter->minf_ = -FLT_MAX; - parameter->maxf_ = FLT_MAX; - return parameter; + matmul_param->op_parameter_.type_ = primitive->Type(); + matmul_param->b_transpose_ = true; + matmul_param->a_transpose_ = false; + matmul_param->has_bias_ = param->hasBias(); + matmul_param->minf_ = -FLT_MAX; + matmul_param->maxf_ = FLT_MAX; + return reinterpret_cast(matmul_param); } -MatMulParameter *PopulateMatMulParameter(const lite::Primitive *primitive) { +OpParameter *PopulateMatMulParameter(const lite::Primitive *primitive) { auto param = primitive->Value()->value_as_MatMul(); - MatMulParameter *parameter = new (std::nothrow) MatMulParameter(); - if (parameter == nullptr) { + MatMulParameter *matmul_param = new (std::nothrow) MatMulParameter(); + if (matmul_param == nullptr) { MS_LOG(ERROR) << "new FullconnectionParameter failed."; return nullptr; } - parameter->b_transpose_ = param->transposeB(); - parameter->a_transpose_ = param->transposeA(); - parameter->has_bias_ = false; - parameter->minf_ = -FLT_MAX; - parameter->maxf_ = FLT_MAX; - return parameter; + matmul_param->op_parameter_.type_ = primitive->Type(); + matmul_param->b_transpose_ = param->transposeB(); + matmul_param->a_transpose_ = param->transposeA(); + matmul_param->has_bias_ = false; + matmul_param->minf_ = -FLT_MAX; + matmul_param->maxf_ = FLT_MAX; + return reinterpret_cast(matmul_param); } -ConvParameter *PopulateConvParameter(const lite::Primitive *primitive) { - ConvParameter *parameter = new (std::nothrow) ConvParameter(); - if (parameter == nullptr) { +OpParameter *PopulateConvParameter(const lite::Primitive *primitive) { + ConvParameter *conv_param = new (std::nothrow) ConvParameter(); + if (conv_param == nullptr) { MS_LOG(ERROR) << "new ConvParameter failed."; return nullptr; } + conv_param->op_parameter_.type_ = primitive->Type(); auto conv_primitive = primitive->Value()->value_as_Conv2D(); - parameter->kernel_h_ = conv_primitive->kernelH(); - parameter->kernel_w_ = conv_primitive->kernelW(); + conv_param->kernel_h_ = conv_primitive->kernelH(); + conv_param->kernel_w_ = conv_primitive->kernelW(); // todo format - parameter->group_ = conv_primitive->group(); - parameter->stride_h_ = conv_primitive->strideH(); - parameter->stride_w_ = conv_primitive->strideW(); + conv_param->group_ = conv_primitive->group(); + conv_param->stride_h_ = conv_primitive->strideH(); + conv_param->stride_w_ = conv_primitive->strideW(); auto conv2d_lite_primitive = (lite::Conv2D *)primitive; MS_ASSERT(nullptr != conv2d_lite_primitive); - parameter->pad_u_ = conv2d_lite_primitive->PadUp(); - parameter->pad_d_ = conv2d_lite_primitive->PadDown(); - parameter->pad_l_ = conv2d_lite_primitive->PadLeft(); - parameter->pad_r_ = conv2d_lite_primitive->PadRight(); - parameter->pad_h_ = conv2d_lite_primitive->PadUp(); - parameter->pad_w_ = conv2d_lite_primitive->PadLeft(); - parameter->dilation_h_ = conv_primitive->dilateH(); - parameter->dilation_w_ = conv_primitive->dilateW(); - parameter->input_channel_ = conv_primitive->channelIn(); - parameter->output_channel_ = conv_primitive->channelOut(); - parameter->group_ = conv_primitive->group(); + conv_param->pad_u_ = conv2d_lite_primitive->PadUp(); + conv_param->pad_d_ = conv2d_lite_primitive->PadDown(); + conv_param->pad_l_ = conv2d_lite_primitive->PadLeft(); + conv_param->pad_r_ = conv2d_lite_primitive->PadRight(); + conv_param->pad_h_ = conv2d_lite_primitive->PadUp(); + conv_param->pad_w_ = conv2d_lite_primitive->PadLeft(); + conv_param->dilation_h_ = conv_primitive->dilateH(); + conv_param->dilation_w_ = conv_primitive->dilateW(); + conv_param->input_channel_ = conv_primitive->channelIn(); + conv_param->output_channel_ = conv_primitive->channelOut(); + conv_param->group_ = conv_primitive->group(); auto act_type = conv_primitive->activationType(); switch (act_type) { case schema::ActivationType_RELU: - parameter->is_relu_ = true; - parameter->is_relu6_ = false; + conv_param->is_relu_ = true; + conv_param->is_relu6_ = false; break; case schema::ActivationType_RELU6: - parameter->is_relu_ = false; - parameter->is_relu6_ = true; + conv_param->is_relu_ = false; + conv_param->is_relu6_ = true; break; default: - parameter->is_relu_ = false; - parameter->is_relu6_ = false; + conv_param->is_relu_ = false; + conv_param->is_relu6_ = false; break; } - return parameter; + return reinterpret_cast(conv_param); } -ConvParameter *PopulateConvDwParameter(const lite::Primitive *primitive) { - ConvParameter *parameter = new (std::nothrow) ConvParameter(); - if (parameter == nullptr) { +OpParameter *PopulateConvDwParameter(const lite::Primitive *primitive) { + ConvParameter *conv_param = new (std::nothrow) ConvParameter(); + if (conv_param == nullptr) { MS_LOG(ERROR) << "new ConvParameter failed."; return nullptr; } + conv_param->op_parameter_.type_ = primitive->Type(); auto conv_primitive = primitive->Value()->value_as_DepthwiseConv2D(); - parameter->kernel_h_ = conv_primitive->kernelH(); - parameter->kernel_w_ = conv_primitive->kernelW(); + conv_param->kernel_h_ = conv_primitive->kernelH(); + conv_param->kernel_w_ = conv_primitive->kernelW(); // todo format, group - parameter->stride_h_ = conv_primitive->strideH(); - parameter->stride_w_ = conv_primitive->strideW(); + conv_param->stride_h_ = conv_primitive->strideH(); + conv_param->stride_w_ = conv_primitive->strideW(); auto pad_mode = conv_primitive->padMode(); auto convdw_lite_primitive = (lite::DepthwiseConv2D *)primitive; MS_ASSERT(nullptr != convdw_lite_primitive); - parameter->pad_u_ = convdw_lite_primitive->PadUp(); - parameter->pad_d_ = convdw_lite_primitive->PadDown(); - parameter->pad_l_ = convdw_lite_primitive->PadLeft(); - parameter->pad_r_ = convdw_lite_primitive->PadRight(); - parameter->pad_h_ = convdw_lite_primitive->PadUp(); - parameter->pad_w_ = convdw_lite_primitive->PadLeft(); - parameter->dilation_h_ = conv_primitive->dilateH(); - parameter->dilation_w_ = conv_primitive->dilateW(); + conv_param->pad_u_ = convdw_lite_primitive->PadUp(); + conv_param->pad_d_ = convdw_lite_primitive->PadDown(); + conv_param->pad_l_ = convdw_lite_primitive->PadLeft(); + conv_param->pad_r_ = convdw_lite_primitive->PadRight(); + conv_param->pad_h_ = convdw_lite_primitive->PadUp(); + conv_param->pad_w_ = convdw_lite_primitive->PadLeft(); + conv_param->dilation_h_ = conv_primitive->dilateH(); + conv_param->dilation_w_ = conv_primitive->dilateW(); auto act_type = conv_primitive->activationType(); switch (act_type) { case schema::ActivationType_RELU: - parameter->is_relu_ = true; - parameter->is_relu6_ = false; + conv_param->is_relu_ = true; + conv_param->is_relu6_ = false; break; case schema::ActivationType_RELU6: - parameter->is_relu_ = false; - parameter->is_relu6_ = true; + conv_param->is_relu_ = false; + conv_param->is_relu6_ = true; break; default: - parameter->is_relu_ = false; - parameter->is_relu6_ = false; + conv_param->is_relu_ = false; + conv_param->is_relu6_ = false; break; } - return parameter; + return reinterpret_cast(conv_param); } -ConvParameter *PopulateDeconvDwParameter(const lite::Primitive *primitive) { - ConvParameter *parameter = new ConvParameter(); +OpParameter *PopulateDeconvDwParameter(const lite::Primitive *primitive) { + ConvParameter *conv_param = new ConvParameter(); + if (conv_param == nullptr) { + MS_LOG(ERROR) << "new ConvParameter failed."; + return nullptr; + } + conv_param->op_parameter_.type_ = primitive->Type(); auto conv_primitive = primitive->Value()->value_as_DeDepthwiseConv2D(); - parameter->kernel_h_ = conv_primitive->kernelH(); - parameter->kernel_w_ = conv_primitive->kernelW(); + conv_param->kernel_h_ = conv_primitive->kernelH(); + conv_param->kernel_w_ = conv_primitive->kernelW(); // todo format, group - parameter->stride_h_ = conv_primitive->strideH(); - parameter->stride_w_ = conv_primitive->strideW(); + conv_param->stride_h_ = conv_primitive->strideH(); + conv_param->stride_w_ = conv_primitive->strideW(); auto deconvdw_lite_primitive = (lite::DeconvDepthwiseConv2D *)primitive; MS_ASSERT(nullptr != deconvdw_lite_primitive); - parameter->pad_u_ = deconvdw_lite_primitive->PadUp(); - parameter->pad_d_ = deconvdw_lite_primitive->PadDown(); - parameter->pad_l_ = deconvdw_lite_primitive->PadLeft(); - parameter->pad_r_ = deconvdw_lite_primitive->PadRight(); - parameter->pad_h_ = deconvdw_lite_primitive->PadUp(); - parameter->pad_w_ = deconvdw_lite_primitive->PadLeft(); - parameter->dilation_h_ = conv_primitive->dilateH(); - parameter->dilation_w_ = conv_primitive->dilateW(); + conv_param->pad_u_ = deconvdw_lite_primitive->PadUp(); + conv_param->pad_d_ = deconvdw_lite_primitive->PadDown(); + conv_param->pad_l_ = deconvdw_lite_primitive->PadLeft(); + conv_param->pad_r_ = deconvdw_lite_primitive->PadRight(); + conv_param->pad_h_ = deconvdw_lite_primitive->PadUp(); + conv_param->pad_w_ = deconvdw_lite_primitive->PadLeft(); + conv_param->dilation_h_ = conv_primitive->dilateH(); + conv_param->dilation_w_ = conv_primitive->dilateW(); auto act_type = conv_primitive->activationType(); switch (act_type) { case schema::ActivationType_RELU: - parameter->is_relu_ = true; - parameter->is_relu6_ = false; + conv_param->is_relu_ = true; + conv_param->is_relu6_ = false; break; case schema::ActivationType_RELU6: - parameter->is_relu_ = false; - parameter->is_relu6_ = true; + conv_param->is_relu_ = false; + conv_param->is_relu6_ = true; break; default: - parameter->is_relu_ = false; - parameter->is_relu6_ = false; + conv_param->is_relu_ = false; + conv_param->is_relu6_ = false; break; } - return parameter; + return reinterpret_cast(conv_param); } -ConvParameter *PopulateDeconvParameter(const lite::Primitive *primitive) { - ConvParameter *parameter = new ConvParameter(); +OpParameter *PopulateDeconvParameter(const lite::Primitive *primitive) { + ConvParameter *conv_param = new ConvParameter(); + if (conv_param == nullptr) { + MS_LOG(ERROR) << "new ConvParameter failed."; + return nullptr; + } + conv_param->op_parameter_.type_ = primitive->Type(); auto conv_primitive = primitive->Value()->value_as_DeConv2D(); - parameter->kernel_h_ = conv_primitive->kernelH(); - parameter->kernel_w_ = conv_primitive->kernelW(); - parameter->stride_h_ = conv_primitive->strideH(); - parameter->stride_w_ = conv_primitive->strideW(); + conv_param->kernel_h_ = conv_primitive->kernelH(); + conv_param->kernel_w_ = conv_primitive->kernelW(); + conv_param->stride_h_ = conv_primitive->strideH(); + conv_param->stride_w_ = conv_primitive->strideW(); auto deconv_lite_primitive = (lite::DeConv2D *)primitive; MS_ASSERT(nullptr != deconvdw_lite_primitive); - parameter->pad_u_ = deconv_lite_primitive->PadUp(); - parameter->pad_d_ = deconv_lite_primitive->PadDown(); - parameter->pad_l_ = deconv_lite_primitive->PadLeft(); - parameter->pad_r_ = deconv_lite_primitive->PadRight(); - parameter->pad_h_ = deconv_lite_primitive->PadUp(); - parameter->pad_w_ = deconv_lite_primitive->PadLeft(); - parameter->dilation_h_ = conv_primitive->dilateH(); - parameter->dilation_w_ = conv_primitive->dilateW(); + conv_param->pad_u_ = deconv_lite_primitive->PadUp(); + conv_param->pad_d_ = deconv_lite_primitive->PadDown(); + conv_param->pad_l_ = deconv_lite_primitive->PadLeft(); + conv_param->pad_r_ = deconv_lite_primitive->PadRight(); + conv_param->pad_h_ = deconv_lite_primitive->PadUp(); + conv_param->pad_w_ = deconv_lite_primitive->PadLeft(); + conv_param->dilation_h_ = conv_primitive->dilateH(); + conv_param->dilation_w_ = conv_primitive->dilateW(); auto act_type = conv_primitive->activationType(); switch (act_type) { case schema::ActivationType_RELU: - parameter->is_relu_ = true; - parameter->is_relu6_ = false; + conv_param->is_relu_ = true; + conv_param->is_relu6_ = false; break; case schema::ActivationType_RELU6: - parameter->is_relu_ = false; - parameter->is_relu6_ = true; + conv_param->is_relu_ = false; + conv_param->is_relu6_ = true; break; default: - parameter->is_relu_ = false; - parameter->is_relu6_ = false; + conv_param->is_relu_ = false; + conv_param->is_relu6_ = false; break; } - return parameter; + return reinterpret_cast(conv_param); } -SoftmaxParameter *PopulateSoftmaxParameter(const lite::Primitive *primitive) { +OpParameter *PopulateSoftmaxParameter(const lite::Primitive *primitive) { auto softmax_primitive = primitive->Value()->value_as_SoftMax(); - SoftmaxParameter *parameter = new (std::nothrow) SoftmaxParameter(); - if (parameter == nullptr) { + SoftmaxParameter *softmax_param = new (std::nothrow) SoftmaxParameter(); + if (softmax_param == nullptr) { MS_LOG(ERROR) << "new SoftmaxParameter failed."; return nullptr; } - parameter->axis_ = softmax_primitive->axis(); - return parameter; + softmax_param->op_parameter_.type_ = primitive->Type(); + softmax_param->axis_ = softmax_primitive->axis(); + return reinterpret_cast(softmax_param); } -ReduceParameter *PopulateReduceParameter(const lite::Primitive *primitive) { - ReduceParameter *parameter = new (std::nothrow) ReduceParameter(); - if (parameter == nullptr) { +OpParameter *PopulateReduceParameter(const lite::Primitive *primitive) { + ReduceParameter *reduce_param = new (std::nothrow) ReduceParameter(); + if (reduce_param == nullptr) { MS_LOG(ERROR) << "new ReduceParameter failed."; return nullptr; } + reduce_param->op_parameter_.type_ = primitive->Type(); auto reduce = primitive->Value()->value_as_Reduce(); - parameter->keep_dims_ = reduce->keepDims(); + reduce_param->keep_dims_ = reduce->keepDims(); auto axisVector = reduce->axes(); if (axisVector->size() > REDUCE_MAX_AXES_NUM) { MS_LOG(ERROR) << "Reduce axes size " << axisVector->size() << " exceed limit " << REDUCE_MAX_AXES_NUM; - delete (parameter); + delete (reduce_param); return nullptr; } - parameter->num_axes_ = static_cast(axisVector->size()); + reduce_param->num_axes_ = static_cast(axisVector->size()); int i = 0; for (auto iter = axisVector->begin(); iter != axisVector->end(); iter++) { - parameter->axes_[i++] = *iter; + reduce_param->axes_[i++] = *iter; } - parameter->mode_ = static_cast(reduce->mode()); - return parameter; + reduce_param->mode_ = static_cast(reduce->mode()); + return reinterpret_cast(reduce_param); } -PadParameter *PopulatePadParameter(const lite::Primitive *primitive) { +OpParameter *PopulatePadParameter(const lite::Primitive *primitive) { PadParameter *pad_param = new (std::nothrow) PadParameter(); if (pad_param == nullptr) { MS_LOG(ERROR) << "new PadParameter failed."; return nullptr; } + pad_param->op_parameter_.type_ = primitive->Type(); auto pad_node = primitive->Value()->value_as_Pad(); - pad_param->pad_mode_ = pad_node->paddingMode(); if (pad_param->pad_mode_ == schema::PaddingMode_CONSTANT) { pad_param->constant_value_ = pad_node->constantValue(); @@ -402,218 +421,212 @@ PadParameter *PopulatePadParameter(const lite::Primitive *primitive) { for (size_t i = 0; i < size; i++) { pad_param->paddings_[MAX_PAD_SIZE - size + i] = (*(pad_node->paddings()))[i]; } - return pad_param; + return reinterpret_cast(pad_param); } -ActivationParameter *PopulateActivationParameter(const lite::Primitive *primitive) { - ActivationParameter *parameter = new (std::nothrow) ActivationParameter(); - if (parameter == nullptr) { +OpParameter *PopulateActivationParameter(const lite::Primitive *primitive) { + ActivationParameter *act_param = new (std::nothrow) ActivationParameter(); + if (act_param == nullptr) { MS_LOG(ERROR) << "new ActivationParameter failed."; return nullptr; } auto activation = primitive->Value()->value_as_Activation(); - parameter->type_ = static_cast(activation->type()); - return parameter; + act_param->type_ = static_cast(activation->type()); + return reinterpret_cast(act_param); } -FusedBatchNormParameter *PopulateFusedBatchNorm(const lite::Primitive *primitive) { - FusedBatchNormParameter *parameter = new (std::nothrow) FusedBatchNormParameter(); - if (parameter == nullptr) { +OpParameter *PopulateFusedBatchNorm(const lite::Primitive *primitive) { + FusedBatchNormParameter *fuse_batch_norm_param = new (std::nothrow) FusedBatchNormParameter(); + if (fuse_batch_norm_param == nullptr) { MS_LOG(ERROR) << "new FusedBatchNormParameter failed."; return nullptr; } + fuse_batch_norm_param->op_parameter_.type_ = primitive->Type(); auto param = primitive->Value()->value_as_FusedBatchNorm(); - parameter->epsilon_ = param->epsilon(); - return parameter; + fuse_batch_norm_param->epsilon_ = param->epsilon(); + return reinterpret_cast(fuse_batch_norm_param); } -ArithmeticParameter *PopulateArithmetic(const lite::Primitive *primitive) { - ArithmeticParameter *parameter = new (std::nothrow) ArithmeticParameter(); - if (parameter == nullptr) { +OpParameter *PopulateArithmetic(const lite::Primitive *primitive) { + ArithmeticParameter *arithmetic_param = new (std::nothrow) ArithmeticParameter(); + if (arithmetic_param == nullptr) { MS_LOG(ERROR) << "new ArithmeticParameter failed."; return nullptr; } - parameter->op_parameter.type_ = primitive->Type(); - parameter->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); - parameter->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); + arithmetic_param->op_parameter_.type_ = primitive->Type(); + arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); + arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); - (void)memcpy(parameter->in_shape0_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + (void)memcpy(arithmetic_param->in_shape0_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); - (void)memcpy(parameter->in_shape1_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + (void)memcpy(arithmetic_param->in_shape1_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); - (void)memcpy(parameter->out_shape_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); - return parameter; + (void)memcpy(arithmetic_param->out_shape_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + return reinterpret_cast(arithmetic_param); } -ArithmeticParameter *PopulateEltwiseParam(const lite::Primitive *primitive) { - ArithmeticParameter *parameter = new (std::nothrow) ArithmeticParameter(); - if (parameter == nullptr) { +OpParameter *PopulateEltwiseParameter(const lite::Primitive *primitive) { + ArithmeticParameter *arithmetic_param = new (std::nothrow) ArithmeticParameter(); + if (arithmetic_param == nullptr) { MS_LOG(ERROR) << "new ArithmeticParameter failed."; return nullptr; } auto eltwise = primitive->Value()->value_as_Eltwise(); switch (eltwise->mode()) { case schema::EltwiseMode_PROD: - parameter->op_parameter.type_ = schema::PrimitiveType_Mul; + arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Mul; break; case schema::EltwiseMode_SUM: - parameter->op_parameter.type_ = schema::PrimitiveType_Add; + arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Add; break; case schema::EltwiseMode_MAXIMUM: - parameter->op_parameter.type_ = schema::PrimitiveType_Maximum; + arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Maximum; break; default: - delete parameter; + delete arithmetic_param; return nullptr; } - return parameter; + return reinterpret_cast(arithmetic_param); } -ArithmeticSelfParameter *PopulateArithmeticSelf(const lite::Primitive *primitive) { - ArithmeticSelfParameter *parameter = new (std::nothrow) ArithmeticSelfParameter(); - if (parameter == nullptr) { +OpParameter *PopulateArithmeticSelf(const lite::Primitive *primitive) { + ArithmeticSelfParameter *arithmetic_self_param = new (std::nothrow) ArithmeticSelfParameter(); + if (arithmetic_self_param == nullptr) { MS_LOG(ERROR) << "new ArithmeticParameter failed."; return nullptr; } - parameter->op_parameter_.type_ = primitive->Type(); - return parameter; + arithmetic_self_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(arithmetic_self_param); } -PowerParameter *PopulatePowerParameter(const lite::Primitive *primitive) { - PowerParameter *parameter = new (std::nothrow) PowerParameter(); - if (parameter == nullptr) { +OpParameter *PopulatePowerParameter(const lite::Primitive *primitive) { + PowerParameter *power_param = new (std::nothrow) PowerParameter(); + if (power_param == nullptr) { MS_LOG(ERROR) << "new PowerParameter failed."; return nullptr; } + power_param->op_parameter_.type_ = primitive->Type(); auto power = primitive->Value()->value_as_Power(); - parameter->power_ = power->power(); - parameter->scale_ = power->scale(); - parameter->shift_ = power->shift(); - return parameter; + power_param->power_ = power->power(); + power_param->scale_ = power->scale(); + power_param->shift_ = power->shift(); + return reinterpret_cast(power_param); } -ArgMinMaxParameter *PopulateArgMaxParam(const lite::Primitive *primitive) { - ArgMinMaxParameter *parameter = new (std::nothrow) ArgMinMaxParameter(); - if (parameter == nullptr) { +OpParameter *PopulateArgMaxParameter(const lite::Primitive *primitive) { + ArgMinMaxParameter *arg_param = new (std::nothrow) ArgMinMaxParameter(); + if (arg_param == nullptr) { MS_LOG(ERROR) << "new ArgMinMaxParameter failed."; return nullptr; } + arg_param->op_parameter_.type_ = primitive->Type(); auto param = primitive->Value()->value_as_ArgMax(); - parameter->op_parameter_.type_ = primitive->Type(); - parameter->axis_ = param->axis(); - parameter->topk_ = param->topK(); - parameter->axis_type_ = param->axisType(); - parameter->out_value_ = param->outMaxValue(); - parameter->keep_dims_ = param->keepDims(); - return parameter; + arg_param->axis_ = param->axis(); + arg_param->topk_ = param->topK(); + arg_param->axis_type_ = param->axisType(); + arg_param->out_value_ = param->outMaxValue(); + arg_param->keep_dims_ = param->keepDims(); + return reinterpret_cast(arg_param); } -ArgMinMaxParameter *PopulateArgMinParam(const lite::Primitive *primitive) { - ArgMinMaxParameter *parameter = new (std::nothrow) ArgMinMaxParameter(); - if (parameter == nullptr) { +OpParameter *PopulateArgMinParameter(const lite::Primitive *primitive) { + ArgMinMaxParameter *arg_param = new (std::nothrow) ArgMinMaxParameter(); + if (arg_param == nullptr) { MS_LOG(ERROR) << "new ArgMinMaxParameter failed."; return nullptr; } + arg_param->op_parameter_.type_ = primitive->Type(); auto param = primitive->Value()->value_as_ArgMin(); - parameter->op_parameter_.type_ = primitive->Type(); - parameter->axis_ = param->axis(); - parameter->topk_ = param->topK(); - parameter->axis_type_ = param->axisType(); - parameter->out_value_ = param->outMaxValue(); - parameter->keep_dims_ = param->keepDims(); - return parameter; + arg_param->axis_ = param->axis(); + arg_param->topk_ = param->topK(); + arg_param->axis_type_ = param->axisType(); + arg_param->out_value_ = param->outMaxValue(); + arg_param->keep_dims_ = param->keepDims(); + return reinterpret_cast(arg_param); } -CastParameter *PopulateCastParam(const lite::Primitive *primitive) { - CastParameter *parameter = new (std::nothrow) CastParameter(); - if (parameter == nullptr) { +OpParameter *PopulateCastParameter(const lite::Primitive *primitive) { + CastParameter *cast_param = new (std::nothrow) CastParameter(); + if (cast_param == nullptr) { MS_LOG(ERROR) << "new CastParameter failed."; return nullptr; } + cast_param->op_parameter_.type_ = primitive->Type(); auto param = primitive->Value()->value_as_Cast(); - parameter->op_parameter_.type_ = primitive->Type(); - parameter->src_type_ = param->srcT(); - parameter->dst_type_ = param->dstT(); - return parameter; + cast_param->src_type_ = param->srcT(); + cast_param->dst_type_ = param->dstT(); + return reinterpret_cast(cast_param); } -LocalResponseNormParameter *PopulateLocalResponseNormParameter(const lite::Primitive *primitive) { +OpParameter *PopulateLocalResponseNormParameter(const lite::Primitive *primitive) { auto local_response_norm_attr = primitive->Value()->value_as_LocalResponseNormalization(); - LocalResponseNormParameter *parameter = new (std::nothrow) LocalResponseNormParameter(); - if (parameter == nullptr) { + LocalResponseNormParameter *lrn_param = new (std::nothrow) LocalResponseNormParameter(); + if (lrn_param == nullptr) { MS_LOG(ERROR) << "new LocalResponseNormParameter failed."; return nullptr; } - parameter->depth_radius_ = local_response_norm_attr->depth_radius(); - parameter->bias_ = local_response_norm_attr->bias(); - parameter->alpha_ = local_response_norm_attr->alpha(); - parameter->beta_ = local_response_norm_attr->beta(); - return parameter; + lrn_param->op_parameter_.type_ = primitive->Type(); + lrn_param->depth_radius_ = local_response_norm_attr->depth_radius(); + lrn_param->bias_ = local_response_norm_attr->bias(); + lrn_param->alpha_ = local_response_norm_attr->alpha(); + lrn_param->beta_ = local_response_norm_attr->beta(); + return reinterpret_cast(lrn_param); } -RangeParameter *PopulateRangeParameter(const lite::Primitive *primitive) { +OpParameter *PopulateRangeParameter(const lite::Primitive *primitive) { auto range_attr = primitive->Value()->value_as_Range(); - RangeParameter *parameter = new (std::nothrow) RangeParameter(); - if (parameter == nullptr) { + RangeParameter *range_param = new (std::nothrow) RangeParameter(); + if (range_param == nullptr) { MS_LOG(ERROR) << "new RangeParameter failed."; return nullptr; } - parameter->start_ = range_attr->start(); - parameter->limit_ = range_attr->limit(); - parameter->delta_ = range_attr->delta(); - parameter->dType_ = range_attr->dType(); - return parameter; + range_param->op_parameter_.type_ = primitive->Type(); + range_param->start_ = range_attr->start(); + range_param->limit_ = range_attr->limit(); + range_param->delta_ = range_attr->delta(); + range_param->dType_ = range_attr->dType(); + return reinterpret_cast(range_param); } -OpParameter *PopulateCeilParameter(const lite::Primitive *primitive) { - OpParameter *parameter = new (std::nothrow) OpParameter(); - if (parameter == nullptr) { - MS_LOG(ERROR) << "new OpParameter failed."; - return nullptr; - } - parameter->type_ = primitive->Type(); - return parameter; -} - -ConcatParameter *PopulateConcatParameter(const lite::Primitive *primitive) { - ConcatParameter *parameter = new (std::nothrow) ConcatParameter(); - if (parameter == nullptr) { +OpParameter *PopulateConcatParameter(const lite::Primitive *primitive) { + ConcatParameter *concat_param = new (std::nothrow) ConcatParameter(); + if (concat_param == nullptr) { MS_LOG(ERROR) << "new ConcatParameter failed."; return nullptr; } - parameter->op_parameter_.type_ = primitive->Type(); + concat_param->op_parameter_.type_ = primitive->Type(); auto param = primitive->Value()->value_as_Concat(); - parameter->axis_ = param->axis(); - return parameter; + concat_param->axis_ = param->axis(); + return reinterpret_cast(concat_param); } -TileParameter *PopulateTileParameter(const lite::Primitive *primitive) { - TileParameter *parameter = new (std::nothrow) TileParameter(); - if (parameter == nullptr) { +OpParameter *PopulateTileParameter(const lite::Primitive *primitive) { + TileParameter *tile_param = new (std::nothrow) TileParameter(); + if (tile_param == nullptr) { MS_LOG(ERROR) << "new TileParameter failed."; return nullptr; } - parameter->op_parameter_.type_ = primitive->Type(); + tile_param->op_parameter_.type_ = primitive->Type(); auto param = primitive->Value()->value_as_Tile(); auto multiples = param->multiples(); - parameter->in_dim_ = multiples->size(); - for (size_t i = 0; i < parameter->in_dim_; ++i) { - parameter->multiples_[i] = multiples->Get(i); + tile_param->in_dim_ = multiples->size(); + for (size_t i = 0; i < tile_param->in_dim_; ++i) { + tile_param->multiples_[i] = multiples->Get(i); } - return parameter; + return reinterpret_cast(tile_param); } -TopkParameter *PopulateTopKParameter(const lite::Primitive *primitive) { - TopkParameter *parameter = new (std::nothrow) TopkParameter(); - if (parameter == nullptr) { +OpParameter *PopulateTopKParameter(const lite::Primitive *primitive) { + TopkParameter *topk_param = new (std::nothrow) TopkParameter(); + if (topk_param == nullptr) { MS_LOG(ERROR) << "new TopkParameter failed."; return nullptr; } - parameter->op_parameter_.type_ = primitive->Type(); + topk_param->op_parameter_.type_ = primitive->Type(); auto param = primitive->Value()->value_as_TopK(); - parameter->k_ = param->k(); - parameter->sorted_ = param->sorted(); - return parameter; + topk_param->k_ = param->k(); + topk_param->sorted_ = param->sorted(); + return reinterpret_cast(topk_param); } OpParameter *PopulateNhwc2NchwParameter(const lite::Primitive *primitive) { @@ -636,64 +649,64 @@ OpParameter *PopulateNchw2NhwcParameter(const lite::Primitive *primitive) { return parameter; } -TransposeParameter *PopulateTransposeParameter(const lite::Primitive *primitive) { - TransposeParameter *parameter = new (std::nothrow) TransposeParameter(); - if (parameter == nullptr) { +OpParameter *PopulateTransposeParameter(const lite::Primitive *primitive) { + TransposeParameter *transpose_param = new (std::nothrow) TransposeParameter(); + if (transpose_param == nullptr) { MS_LOG(ERROR) << "new TransposeParameter failed."; return nullptr; } auto param = primitive->Value()->value_as_Transpose(); - parameter->op_parameter_.type_ = primitive->Type(); + transpose_param->op_parameter_.type_ = primitive->Type(); auto perm_vector_ = param->perm(); int i = 0; for (auto iter = perm_vector_->begin(); iter != perm_vector_->end(); iter++) { - parameter->perm_[i++] = *iter; + transpose_param->perm_[i++] = *iter; } - parameter->num_axes_ = i; - parameter->conjugate_ = param->conjugate(); - return parameter; + transpose_param->num_axes_ = i; + transpose_param->conjugate_ = param->conjugate(); + return reinterpret_cast(transpose_param); } -SplitParameter *PopulateSplitParameter(const lite::Primitive *primitive) { - SplitParameter *parameter = new (std::nothrow) SplitParameter(); - if (parameter == nullptr) { +OpParameter *PopulateSplitParameter(const lite::Primitive *primitive) { + SplitParameter *split_param = new (std::nothrow) SplitParameter(); + if (split_param == nullptr) { MS_LOG(ERROR) << "new SplitParameter failed."; return nullptr; } auto param = primitive->Value()->value_as_Split(); - parameter->op_parameter_.type_ = primitive->Type(); - parameter->num_split_ = param->numberSplit(); + split_param->op_parameter_.type_ = primitive->Type(); + split_param->num_split_ = param->numberSplit(); auto split_sizes_vector_ = param->sizeSplits(); int i = 0; for (auto iter = split_sizes_vector_->begin(); iter != split_sizes_vector_->end(); iter++) { - parameter->split_sizes_[i++] = *iter; + split_param->split_sizes_[i++] = *iter; } - parameter->split_dim_ = param->splitDim(); - parameter->num_split_ = param->numberSplit(); - return parameter; + split_param->split_dim_ = param->splitDim(); + split_param->num_split_ = param->numberSplit(); + return reinterpret_cast(split_param); } -SqueezeParameter *PopulateSqueezeParameter(const lite::Primitive *primitive) { - SqueezeParameter *parameter = new (std::nothrow) SqueezeParameter(); - if (parameter == nullptr) { +OpParameter *PopulateSqueezeParameter(const lite::Primitive *primitive) { + SqueezeParameter *squeeze_param = new (std::nothrow) SqueezeParameter(); + if (squeeze_param == nullptr) { MS_LOG(ERROR) << "new SqueezeParameter failed."; return nullptr; } - parameter->op_parameter_.type_ = primitive->Type(); - return parameter; + squeeze_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(squeeze_param); } -ScaleParameter *PopulateScaleParameter(const lite::Primitive *primitive) { +OpParameter *PopulateScaleParameter(const lite::Primitive *primitive) { if (primitive == nullptr) { MS_LOG(ERROR) << "input primitive is nullptr"; return nullptr; } - ScaleParameter *parameter = new (std::nothrow) ScaleParameter(); - if (parameter == nullptr) { + ScaleParameter *scale_param = new (std::nothrow) ScaleParameter(); + if (scale_param == nullptr) { MS_LOG(ERROR) << "new ScaleParameter failed."; return nullptr; } - parameter->op_parameter_.type_ = primitive->Type(); + scale_param->op_parameter_.type_ = primitive->Type(); auto param = primitive->Value()->value_as_Scale(); if (param == nullptr) { MS_LOG(ERROR) << "value_as_Scale return nullptr"; @@ -701,219 +714,253 @@ ScaleParameter *PopulateScaleParameter(const lite::Primitive *primitive) { } // NCHW todo use enum if (param->format() == schema::Format_NCHW) { - parameter->axis_ = 1; - parameter->num_axis_ = 1; + scale_param->axis_ = 1; + scale_param->num_axis_ = 1; } else if (param->format() == schema::Format_NHWC) { - parameter->axis_ = 3; - parameter->num_axis_ = 1; + scale_param->axis_ = 3; + scale_param->num_axis_ = 1; } - return parameter; + return reinterpret_cast(scale_param); } -GatherParameter *PopulateGatherParameter(const lite::Primitive *primitive) { +OpParameter *PopulateGatherParameter(const lite::Primitive *primitive) { auto gather_attr = primitive->Value()->value_as_Gather(); - GatherParameter *parameter = new (std::nothrow) GatherParameter(); - if (parameter == nullptr) { + GatherParameter *gather_param = new (std::nothrow) GatherParameter(); + if (gather_param == nullptr) { MS_LOG(ERROR) << "new GatherParameter failed."; return nullptr; } - parameter->axis_ = gather_attr->axis(); - parameter->batchDims_ = gather_attr->batchDims(); - return parameter; + gather_param->op_parameter_.type_ = primitive->Type(); + gather_param->axis_ = gather_attr->axis(); + gather_param->batchDims_ = gather_attr->batchDims(); + return reinterpret_cast(gather_param); } -GatherNdParameter *PopulateGatherNdParameter(const lite::Primitive *primitive) { - GatherNdParameter *parameter = new (std::nothrow) GatherNdParameter(); - MS_ASSERT(paramter != nullptr); +OpParameter *PopulateGatherNdParameter(const lite::Primitive *primitive) { + GatherNdParameter *gather_nd_param = new (std::nothrow) GatherNdParameter(); + if (gather_nd_param == nullptr) { + MS_LOG(ERROR) << "new GatherNDParameter failed."; + return nullptr; + } + gather_nd_param->op_parameter_.type_ = primitive->Type(); auto gatherNd_attr = primitive->Value()->value_as_GatherNd(); - parameter->batchDims_ = gatherNd_attr->batchDims(); - return parameter; + gather_nd_param->batchDims_ = gatherNd_attr->batchDims(); + return reinterpret_cast(gather_nd_param); } -ScatterNDParameter *PopulateScatterNDParameter(const lite::Primitive *primitive) { - ScatterNDParameter *parameter = new (std::nothrow) ScatterNDParameter(); +OpParameter *PopulateScatterNDParameter(const lite::Primitive *primitive) { + ScatterNDParameter *scatter_nd_param = new (std::nothrow) ScatterNDParameter(); + if (scatter_nd_param == nullptr) { + MS_LOG(ERROR) << "new ScatterNDParameter failed."; + return nullptr; + } + scatter_nd_param->op_parameter_.type_ = primitive->Type(); MS_ASSERT(paramter != nullptr); - return parameter; + return reinterpret_cast(scatter_nd_param); } -SliceParameter *PopulateSliceParam(const lite::Primitive *primitive) { - SliceParameter *parameter = new (std::nothrow) SliceParameter(); - if (parameter == nullptr) { +OpParameter *PopulateSliceParameter(const lite::Primitive *primitive) { + SliceParameter *slice_param = new (std::nothrow) SliceParameter(); + if (slice_param == nullptr) { MS_LOG(ERROR) << "new SliceParameter failed."; return nullptr; } auto param = primitive->Value()->value_as_Slice(); - parameter->op_parameter_.type_ = primitive->Type(); + slice_param->op_parameter_.type_ = primitive->Type(); auto param_begin = param->begin(); auto param_size = param->size(); if (param_begin->size() != param_size->size()) { - delete parameter; + delete slice_param; return nullptr; } - parameter->param_length_ = static_cast(param_begin->size()); - for (int32_t i = 0; i < parameter->param_length_; ++i) { - parameter->begin_[i] = param_begin->Get(i); - parameter->size_[i] = param_size->Get(i); + slice_param->param_length_ = static_cast(param_begin->size()); + for (int32_t i = 0; i < slice_param->param_length_; ++i) { + slice_param->begin_[i] = param_begin->Get(i); + slice_param->size_[i] = param_size->Get(i); } - return parameter; + return reinterpret_cast(slice_param); } -BroadcastToParameter *PopulateBroadcastToParam(const lite::Primitive *primitive) { - BroadcastToParameter *parameter = new (std::nothrow) BroadcastToParameter(); - if (parameter == nullptr) { +OpParameter *PopulateBroadcastToParameter(const lite::Primitive *primitive) { + BroadcastToParameter *broadcast_param = new (std::nothrow) BroadcastToParameter(); + if (broadcast_param == nullptr) { MS_LOG(ERROR) << "new BroadcastToParameter failed."; return nullptr; } auto param = primitive->Value()->value_as_BroadcastTo(); - parameter->op_parameter_.type_ = primitive->Type(); + broadcast_param->op_parameter_.type_ = primitive->Type(); auto dst_shape = param->dst_shape(); - parameter->shape_size_ = dst_shape->size(); - for (size_t i = 0; i < parameter->shape_size_; ++i) { - parameter->shape_[i] = dst_shape->Get(i); + broadcast_param->shape_size_ = dst_shape->size(); + for (size_t i = 0; i < broadcast_param->shape_size_; ++i) { + broadcast_param->shape_[i] = dst_shape->Get(i); } - return parameter; + return reinterpret_cast(broadcast_param); } -ReshapeParameter *PopulateReshapeParam(const lite::Primitive *primitive) { - ReshapeParameter *parameter = new (std::nothrow) ReshapeParameter(); - if (parameter == nullptr) { +OpParameter *PopulateReshapeParameter(const lite::Primitive *primitive) { + ReshapeParameter *reshape_param = new (std::nothrow) ReshapeParameter(); + if (reshape_param == nullptr) { MS_LOG(ERROR) << "new ReshapeParameter failed."; return nullptr; } - parameter->op_parameter_.type_ = primitive->Type(); - return parameter; + reshape_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(reshape_param); } -ReverseParameter *PopulateReverseParameter(const lite::Primitive *primitive) { +OpParameter *PopulateReverseParameter(const lite::Primitive *primitive) { auto reverse_attr = primitive->Value()->value_as_Reverse(); - ReverseParameter *parameter = new (std::nothrow) ReverseParameter(); - if (parameter == nullptr) { + ReverseParameter *reverse_param = new (std::nothrow) ReverseParameter(); + if (reverse_param == nullptr) { MS_LOG(ERROR) << "new ReverseParameter failed."; return nullptr; } + reverse_param->op_parameter_.type_ = primitive->Type(); auto flatAxis = reverse_attr->axis(); - parameter->num_axis_ = flatAxis->size(); + reverse_param->num_axis_ = flatAxis->size(); int i = 0; for (auto iter = flatAxis->begin(); iter != flatAxis->end(); iter++) { - parameter->axis_[i++] = *iter; + reverse_param->axis_[i++] = *iter; } - return parameter; + return reinterpret_cast(reverse_param); } -UnsqueezeParameter *PopulateUnsqueezeParameter(const lite::Primitive *primitive) { +OpParameter *PopulateUnsqueezeParameter(const lite::Primitive *primitive) { auto unsqueeze_attr = primitive->Value()->value_as_Unsqueeze(); - UnsqueezeParameter *parameter = new (std::nothrow) UnsqueezeParameter(); - if (parameter == nullptr) { + UnsqueezeParameter *unsqueeze_param = new (std::nothrow) UnsqueezeParameter(); + if (unsqueeze_param == nullptr) { MS_LOG(ERROR) << "new ReverseParameter failed."; return nullptr; } + unsqueeze_param->op_parameter_.type_ = primitive->Type(); auto flatAxis = unsqueeze_attr->axis(); - parameter->num_dim_ = flatAxis->size(); + unsqueeze_param->num_dim_ = flatAxis->size(); int i = 0; for (auto iter = flatAxis->begin(); iter != flatAxis->end(); iter++) { - parameter->dims_[i++] = *iter; + unsqueeze_param->dims_[i++] = *iter; } - return parameter; + return reinterpret_cast(unsqueeze_param); } -StackParameter *PopulateStackParam(const lite::Primitive *primitive) { - StackParameter *parameter = new (std::nothrow) StackParameter(); - if (parameter == nullptr) { +OpParameter *PopulateStackParameter(const lite::Primitive *primitive) { + StackParameter *stack_param = new (std::nothrow) StackParameter(); + if (stack_param == nullptr) { MS_LOG(ERROR) << "new StackParameter failed."; return nullptr; } auto param = primitive->Value()->value_as_Stack(); - parameter->op_parameter_.type_ = primitive->Type(); - parameter->axis_ = param->axis(); - return parameter; + stack_param->op_parameter_.type_ = primitive->Type(); + stack_param->axis_ = param->axis(); + return reinterpret_cast(stack_param); } -UnstackParameter *PopulateUnstackParam(const lite::Primitive *primitive) { - UnstackParameter *parameter = new (std::nothrow) UnstackParameter(); - if (parameter == nullptr) { +OpParameter *PopulateUnstackParameter(const lite::Primitive *primitive) { + UnstackParameter *unstack_param = new (std::nothrow) UnstackParameter(); + if (unstack_param == nullptr) { MS_LOG(ERROR) << "new UnstackParameter failed."; return nullptr; } auto param = primitive->Value()->value_as_Unstack(); - parameter->op_parameter_.type_ = primitive->Type(); - parameter->num_ = param->num(); - parameter->axis_ = param->axis(); - return parameter; + unstack_param->op_parameter_.type_ = primitive->Type(); + unstack_param->num_ = param->num(); + unstack_param->axis_ = param->axis(); + return reinterpret_cast(unstack_param); } -ReverseSequenceParameter *PopulateReverseSequenceParam(const lite::Primitive *primitive) { - ReverseSequenceParameter *parameter = new (std::nothrow) ReverseSequenceParameter(); - if (parameter == nullptr) { +OpParameter *PopulateReverseSequenceParameter(const lite::Primitive *primitive) { + ReverseSequenceParameter *reverse_sequence_param = new (std::nothrow) ReverseSequenceParameter(); + if (reverse_sequence_param == nullptr) { MS_LOG(ERROR) << "new ReverseSequenceParameter failed."; return nullptr; } auto param = primitive->Value()->value_as_ReverseSequence(); - parameter->op_parameter_.type_ = primitive->Type(); - parameter->seq_axis_ = param->seqAxis(); - parameter->batch_axis_ = param->batchAxis(); - return parameter; + reverse_sequence_param->op_parameter_.type_ = primitive->Type(); + reverse_sequence_param->seq_axis_ = param->seqAxis(); + reverse_sequence_param->batch_axis_ = param->batchAxis(); + return reinterpret_cast(reverse_sequence_param); } -UniqueParameter *PopulateUniqueParam(const lite::Primitive *primitive) { - UniqueParameter *parameter = new (std::nothrow) UniqueParameter(); - if (parameter == nullptr) { +OpParameter *PopulateUniqueParameter(const lite::Primitive *primitive) { + UniqueParameter *unique_param = new (std::nothrow) UniqueParameter(); + if (unique_param == nullptr) { MS_LOG(ERROR) << "new PopulateUniqueParam failed."; return nullptr; } - parameter->op_parameter_.type_ = primitive->Type(); - return parameter; + unique_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(unique_param); } -DepthToSpaceParameter *PopulateDepthToSpaceParam(const lite::Primitive *primitive) { - DepthToSpaceParameter *parameter = new (std::nothrow) DepthToSpaceParameter(); - if (parameter == nullptr) { +OpParameter *PopulateDepthToSpaceParameter(const lite::Primitive *primitive) { + DepthToSpaceParameter *depth_space_param = new (std::nothrow) DepthToSpaceParameter(); + if (depth_space_param == nullptr) { MS_LOG(ERROR) << "new DepthToSpaceParameter failed."; return nullptr; } auto param = primitive->Value()->value_as_DepthToSpace(); - parameter->op_parameter_.type_ = primitive->Type(); - parameter->block_size_ = param->blockSize(); - return parameter; + depth_space_param->op_parameter_.type_ = primitive->Type(); + depth_space_param->block_size_ = param->blockSize(); + return reinterpret_cast(depth_space_param); } -SpaceToDepthParameter *PopulateSpaceToDepthParam(const lite::Primitive *primitive) { - SpaceToDepthParameter *parameter = new (std::nothrow) SpaceToDepthParameter(); - if (parameter == nullptr) { - MS_LOG(ERROR) << "new SpaceToDepthParameter failed."; +OpParameter *PopulateSpaceToDepthParameter(const lite::Primitive *primitive) { + SpaceToDepthParameter *space_depth_param = new (std::nothrow) SpaceToDepthParameter(); + if (space_depth_param == nullptr) { + MS_LOG(ERROR) << "new SpaceToDepthspace_depth_param failed."; return nullptr; } + space_depth_param->op_parameter_.type_ = primitive->Type(); auto param = primitive->Value()->value_as_DepthToSpace(); - parameter->op_parameter_.type_ = primitive->Type(); - parameter->block_size_ = param->blockSize(); + space_depth_param->op_parameter_.type_ = primitive->Type(); + space_depth_param->block_size_ = param->blockSize(); if (param->format() != schema::Format_NHWC) { MS_LOG(ERROR) << "Currently only NHWC format is supported."; return nullptr; } - return parameter; + return reinterpret_cast(space_depth_param); } -ResizeParameter *PopulateResizeParameter(const lite::Primitive *primitive) { - ResizeParameter *parameter = new (std::nothrow) ResizeParameter(); - if (parameter == nullptr) { +OpParameter *PopulateSpaceToBatchParameter(const lite::Primitive *primitive) { + SpaceToBatchParameter *space_batch_param = new (std::nothrow) SpaceToBatchParameter(); + if (space_batch_param == nullptr) { + MS_LOG(ERROR) << "new SpaceToBatchParameter failed."; + return nullptr; + } + space_batch_param->op_parameter_.type_ = primitive->Type(); + space_batch_param->op_parameter_.type_ = primitive->Type(); + auto block_sizes = ((lite::SpaceToBatch *)primitive)->BlockSizes(); + (void)memcpy(space_batch_param->block_sizes_, (block_sizes.data()), block_sizes.size() * sizeof(int)); + auto paddings = ((lite::SpaceToBatch *)primitive)->Paddings(); + (void)memcpy(space_batch_param->paddings_, (paddings.data()), paddings.size() * sizeof(int)); + auto in_shape = ((lite::SpaceToBatch *)primitive)->InShape(); + (void)memcpy(space_batch_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); + auto padded_in_shape = ((lite::SpaceToBatch *)primitive)->PaddedInShape(); + (void)memcpy(space_batch_param->padded_in_shape_, (padded_in_shape.data()), padded_in_shape.size() * sizeof(int)); + return reinterpret_cast(space_batch_param); +} + +OpParameter *PopulateResizeParameter(const lite::Primitive *primitive) { + ResizeParameter *resize_param = new (std::nothrow) ResizeParameter(); + if (resize_param == nullptr) { MS_LOG(ERROR) << "new ResizeParameter failed."; return nullptr; } + resize_param->op_parameter_.type_ = primitive->Type(); auto param = primitive->Value()->value_as_Resize(); - parameter->method_ = param->method(); - parameter->new_height_ = param->newHeight(); - parameter->new_width_ = param->newWidth(); - parameter->align_corners_ = param->alignCorners(); - parameter->preserve_aspect_ratio_ = param->preserveAspectRatio(); - return parameter; + resize_param->method_ = param->method(); + resize_param->new_height_ = param->newHeight(); + resize_param->new_width_ = param->newWidth(); + resize_param->align_corners_ = param->alignCorners(); + resize_param->preserve_aspect_ratio_ = param->preserveAspectRatio(); + return reinterpret_cast(resize_param); } -BatchToSpaceParameter *PopulateBatchToSpaceParameter(const lite::Primitive *primitive) { - BatchToSpaceParameter *parameter = new (std::nothrow) BatchToSpaceParameter(); - if (parameter == nullptr) { +OpParameter *PopulateBatchToSpaceParameter(const lite::Primitive *primitive) { + BatchToSpaceParameter *batch_space_param = new (std::nothrow) BatchToSpaceParameter(); + if (batch_space_param == nullptr) { MS_LOG(ERROR) << "New BatchToSpaceParameter fail!"; return nullptr; } + batch_space_param->op_parameter_.type_ = primitive->Type(); auto param = primitive->Value()->value_as_BatchToSpace(); auto block_shape = param->blockShape(); if (block_shape->size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { @@ -928,308 +975,271 @@ BatchToSpaceParameter *PopulateBatchToSpaceParameter(const lite::Primitive *prim } for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) { - parameter->block_shape_[i] = block_shape->Get(i); + batch_space_param->block_shape_[i] = block_shape->Get(i); } for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) { - parameter->crops_[i] = crops->Get(i); + batch_space_param->crops_[i] = crops->Get(i); } - return parameter; + return reinterpret_cast(batch_space_param); } -CropParameter *PopulateCropParameter(const lite::Primitive *primitive) { +OpParameter *PopulateCropParameter(const lite::Primitive *primitive) { auto param = primitive->Value()->value_as_Crop(); auto param_offset = param->offsets(); if (param_offset->size() > CROP_OFFSET_MAX_SIZE) { - MS_LOG(ERROR) << "parameter offset size(" << param_offset->size() << ") should <= " << CROP_OFFSET_MAX_SIZE; + MS_LOG(ERROR) << "crop_param offset size(" << param_offset->size() << ") should <= " << CROP_OFFSET_MAX_SIZE; return nullptr; } - CropParameter *parameter = new (std::nothrow) CropParameter(); - if (parameter == nullptr) { + CropParameter *crop_param = new (std::nothrow) CropParameter(); + if (crop_param == nullptr) { MS_LOG(ERROR) << "new CropParameter fail!"; return nullptr; } - parameter->axis_ = param->axis(); - parameter->offset_size_ = param_offset->size(); + crop_param->op_parameter_.type_ = primitive->Type(); + crop_param->axis_ = param->axis(); + crop_param->offset_size_ = param_offset->size(); for (int i = 0; i < param_offset->size(); ++i) { - parameter->offset_[i] = param_offset->Get(i); + crop_param->offset_[i] = param_offset->Get(i); } - return parameter; + return reinterpret_cast(crop_param); } -OneHotParameter *PopulateOneHotParameter(const lite::Primitive *primitive) { - OneHotParameter *parameter = new (std::nothrow) OneHotParameter(); - if (parameter == nullptr) { +OpParameter *PopulateOneHotParameter(const lite::Primitive *primitive) { + OneHotParameter *one_hot_param = new (std::nothrow) OneHotParameter(); + if (one_hot_param == nullptr) { MS_LOG(ERROR) << "new OneHotParameter fail!"; return nullptr; } + one_hot_param->op_parameter_.type_ = primitive->Type(); auto param = primitive->Value()->value_as_OneHot(); if (param == nullptr) { - delete (parameter); + delete (one_hot_param); MS_LOG(ERROR) << "get OneHot param nullptr."; return nullptr; } - parameter->axis_ = param->axis(); - return parameter; + one_hot_param->axis_ = param->axis(); + return reinterpret_cast(one_hot_param); } -FlattenParameter *PopulateFlattenParameter(const lite::Primitive *primitive) { - FlattenParameter *parameter = new (std::nothrow) FlattenParameter(); - if (parameter == nullptr) { +OpParameter *PopulateFlattenParameter(const lite::Primitive *primitive) { + FlattenParameter *flatten_param = new (std::nothrow) FlattenParameter(); + if (flatten_param == nullptr) { MS_LOG(ERROR) << "new FlattenParameter fail!"; return nullptr; } - return parameter; + flatten_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(flatten_param); } -DequantizeParameter *PopulateDequantizeParameter(const lite::Primitive *primitive) { - DequantizeParameter *parameter = new (std::nothrow) DequantizeParameter(); - if (parameter == nullptr) { +OpParameter *PopulateDequantizeParameter(const lite::Primitive *primitive) { + DequantizeParameter *dequantize_parameter = new (std::nothrow) DequantizeParameter(); + if (dequantize_parameter == nullptr) { MS_LOG(ERROR) << "new DequantizeParameter fail!"; return nullptr; } - return parameter; + dequantize_parameter->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(dequantize_parameter); } -QuantizeParameter *PopulateQuantizeParameter(const lite::Primitive *primitive) { - QuantizeParameter *parameter = new (std::nothrow) QuantizeParameter(); - if (parameter == nullptr) { +OpParameter *PopulateQuantizeParameter(const lite::Primitive *primitive) { + QuantizeParameter *quantize_parameter = new (std::nothrow) QuantizeParameter(); + if (quantize_parameter == nullptr) { MS_LOG(ERROR) << "new QuantizeParameter fail!"; return nullptr; } - return parameter; + quantize_parameter->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(quantize_parameter); } -StridedSliceParameter *PopulateStridedSliceParam(const lite::Primitive *primitive) { - StridedSliceParameter *parameter = new (std::nothrow) StridedSliceParameter(); - if (parameter == nullptr) { +OpParameter *PopulateStridedSliceParameter(const lite::Primitive *primitive) { + StridedSliceParameter *strided_slice_param = new (std::nothrow) StridedSliceParameter(); + if (strided_slice_param == nullptr) { MS_LOG(ERROR) << "new StridedSliceParameter failed."; return nullptr; } - parameter->op_parameter_.type_ = primitive->Type(); + strided_slice_param->op_parameter_.type_ = primitive->Type(); auto n_dims = ((lite::StridedSlice *)primitive)->NDims(); - parameter->num_axes_ = n_dims; + strided_slice_param->num_axes_ = n_dims; auto begin = ((lite::StridedSlice *)primitive)->UpdatedBegins(); - (void)memcpy(parameter->begins_, (begin.data()), begin.size() * sizeof(int)); + (void)memcpy(strided_slice_param->begins_, (begin.data()), begin.size() * sizeof(int)); auto end = ((lite::StridedSlice *)primitive)->UpdatedEnds(); - (void)memcpy(parameter->ends_, (end.data()), end.size() * sizeof(int)); + (void)memcpy(strided_slice_param->ends_, (end.data()), end.size() * sizeof(int)); auto stride = ((lite::StridedSlice *)primitive)->UpdatedStrides(); - (void)memcpy(parameter->strides_, (stride.data()), stride.size() * sizeof(int)); + (void)memcpy(strided_slice_param->strides_, (stride.data()), stride.size() * sizeof(int)); auto in_shape = ((lite::StridedSlice *)primitive)->UpdatedInShape(); - (void)memcpy(parameter->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); - return parameter; + (void)memcpy(strided_slice_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); + return reinterpret_cast(strided_slice_param); } -OpParameter *PopulateAddNParam(const lite::Primitive *primitive) { - auto parameter = new (std::nothrow) OpParameter(); - if (parameter == nullptr) { +OpParameter *PopulateAddNParameter(const lite::Primitive *primitive) { + auto addn_param = new (std::nothrow) OpParameter(); + if (addn_param == nullptr) { MS_LOG(ERROR) << "new OpParameter fail!"; return nullptr; } - parameter->type_ = primitive->Type(); - return parameter; + addn_param->type_ = primitive->Type(); + return reinterpret_cast(addn_param); } -PriorBoxParameter *PopulatePriorBoxParameter(const lite::Primitive *primitive) { - PriorBoxParameter *param = new (std::nothrow) PriorBoxParameter(); - if (param == nullptr) { +OpParameter *PopulatePriorBoxParameter(const lite::Primitive *primitive) { + PriorBoxParameter *prior_box_param = new (std::nothrow) PriorBoxParameter(); + if (prior_box_param == nullptr) { MS_LOG(ERROR) << "new PriorBoxParameter failed."; return nullptr; } - param->op_parameter_.type_ = primitive->Type(); - auto prior_box_param = primitive->Value()->value_as_PriorBox(); + prior_box_param->op_parameter_.type_ = primitive->Type(); + auto prior_box_attr = primitive->Value()->value_as_PriorBox(); - if (prior_box_param->min_sizes()->size() > PRIOR_BOX_MAX_NUM) { + if (prior_box_attr->min_sizes()->size() > PRIOR_BOX_MAX_NUM) { MS_LOG(ERROR) << "PriorBox min_sizes size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got " - << prior_box_param->min_sizes(); - delete (param); + << prior_box_attr->min_sizes(); + delete (prior_box_param); return nullptr; } - param->min_sizes_size = prior_box_param->min_sizes()->size(); - if (prior_box_param->max_sizes()->size() > PRIOR_BOX_MAX_NUM) { + prior_box_param->min_sizes_size = prior_box_attr->min_sizes()->size(); + if (prior_box_attr->max_sizes()->size() > PRIOR_BOX_MAX_NUM) { MS_LOG(ERROR) << "PriorBox max_sizes size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got " - << prior_box_param->max_sizes(); - delete (param); + << prior_box_attr->max_sizes(); + delete (prior_box_param); return nullptr; } - param->max_sizes_size = prior_box_param->max_sizes()->size(); - (void)memcpy(param->max_sizes, prior_box_param->max_sizes()->data(), - prior_box_param->max_sizes()->size() * sizeof(int32_t)); - (void)memcpy(param->min_sizes, prior_box_param->min_sizes()->data(), - prior_box_param->min_sizes()->size() * sizeof(int32_t)); + prior_box_param->max_sizes_size = prior_box_attr->max_sizes()->size(); + (void)memcpy(prior_box_param->max_sizes, prior_box_attr->max_sizes()->data(), + prior_box_attr->max_sizes()->size() * sizeof(int32_t)); + (void)memcpy(prior_box_param->min_sizes, prior_box_attr->min_sizes()->data(), + prior_box_attr->min_sizes()->size() * sizeof(int32_t)); - if (prior_box_param->aspect_ratios()->size() > PRIOR_BOX_MAX_NUM) { + if (prior_box_attr->aspect_ratios()->size() > PRIOR_BOX_MAX_NUM) { MS_LOG(ERROR) << "PriorBox aspect_ratios size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got " - << prior_box_param->aspect_ratios(); - delete (param); + << prior_box_attr->aspect_ratios(); + delete (prior_box_param); return nullptr; } - param->aspect_ratios_size = prior_box_param->aspect_ratios()->size(); - (void)memcpy(param->aspect_ratios, prior_box_param->aspect_ratios()->data(), - prior_box_param->aspect_ratios()->size() * sizeof(float)); - if (prior_box_param->variances()->size() != PRIOR_BOX_VAR_NUM) { + prior_box_param->aspect_ratios_size = prior_box_attr->aspect_ratios()->size(); + (void)memcpy(prior_box_param->aspect_ratios, prior_box_attr->aspect_ratios()->data(), + prior_box_attr->aspect_ratios()->size() * sizeof(float)); + if (prior_box_attr->variances()->size() != PRIOR_BOX_VAR_NUM) { MS_LOG(ERROR) << "PriorBox variances size should be " << PRIOR_BOX_VAR_NUM << ", got " - << prior_box_param->variances()->size(); - delete (param); - return nullptr; - } - (void)memcpy(param->variances, prior_box_param->variances()->data(), PRIOR_BOX_VAR_NUM * sizeof(float)); - param->flip = prior_box_param->flip(); - param->clip = prior_box_param->clip(); - param->offset = prior_box_param->offset(); - param->image_size_h = prior_box_param->image_size_h(); - param->image_size_w = prior_box_param->image_size_w(); - param->step_h = prior_box_param->step_h(); - param->step_w = prior_box_param->step_w(); - return param; + << prior_box_attr->variances()->size(); + delete (prior_box_param); + return nullptr; + } + (void)memcpy(prior_box_param->variances, prior_box_attr->variances()->data(), PRIOR_BOX_VAR_NUM * sizeof(float)); + prior_box_param->flip = prior_box_attr->flip(); + prior_box_param->clip = prior_box_attr->clip(); + prior_box_param->offset = prior_box_attr->offset(); + prior_box_param->image_size_h = prior_box_attr->image_size_h(); + prior_box_param->image_size_w = prior_box_attr->image_size_w(); + prior_box_param->step_h = prior_box_attr->step_h(); + prior_box_param->step_w = prior_box_attr->step_w(); + return reinterpret_cast(prior_box_param); } -SpaceToBatchParameter *PopulateSpaceToBatchParam(const lite::Primitive *primitive) { - SpaceToBatchParameter *parameter = new (std::nothrow) SpaceToBatchParameter(); - if (parameter == nullptr) { - MS_LOG(ERROR) << "new SpaceToBatchParameter failed."; - return nullptr; - } - parameter->op_parameter_.type_ = primitive->Type(); - auto block_sizes = ((lite::SpaceToBatch *)primitive)->BlockSizes(); - (void)memcpy(parameter->block_sizes_, (block_sizes.data()), block_sizes.size() * sizeof(int)); - auto paddings = ((lite::SpaceToBatch *)primitive)->Paddings(); - (void)memcpy(parameter->paddings_, (paddings.data()), paddings.size() * sizeof(int)); - auto in_shape = ((lite::SpaceToBatch *)primitive)->InShape(); - (void)memcpy(parameter->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); - auto padded_in_shape = ((lite::SpaceToBatch *)primitive)->PaddedInShape(); - (void)memcpy(parameter->padded_in_shape_, (padded_in_shape.data()), padded_in_shape.size() * sizeof(int)); - return parameter; +PopulateParameterRegistry::PopulateParameterRegistry() { + populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter; + populate_parameter_funcs_[schema::PrimitiveType_Activation] = PopulateActivationParameter; + populate_parameter_funcs_[schema::PrimitiveType_Conv2D] = PopulateConvParameter; + populate_parameter_funcs_[schema::PrimitiveType_Reduce] = PopulateReduceParameter; + populate_parameter_funcs_[schema::PrimitiveType_Pooling] = PopulatePoolingParameter; + populate_parameter_funcs_[schema::PrimitiveType_DepthwiseConv2D] = PopulateConvDwParameter; + populate_parameter_funcs_[schema::PrimitiveType_DeDepthwiseConv2D] = PopulateDeconvDwParameter; + populate_parameter_funcs_[schema::PrimitiveType_DeConv2D] = PopulateDeconvParameter; + populate_parameter_funcs_[schema::PrimitiveType_FusedBatchNorm] = PopulateFusedBatchNorm; + populate_parameter_funcs_[schema::PrimitiveType_FullConnection] = PopulateFullconnectionParameter; + populate_parameter_funcs_[schema::PrimitiveType_Power] = PopulatePowerParameter; + populate_parameter_funcs_[schema::PrimitiveType_LocalResponseNormalization] = PopulateLocalResponseNormParameter; + populate_parameter_funcs_[schema::PrimitiveType_Range] = PopulateRangeParameter; + populate_parameter_funcs_[schema::PrimitiveType_Transpose] = PopulateTransposeParameter; + populate_parameter_funcs_[schema::PrimitiveType_Mul] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_Add] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_Sub] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_Div] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_FloorDiv] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_FloorMod] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_SquaredDifference] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_BiasAdd] = PopulateArithmetic; + populate_parameter_funcs_[schema::PrimitiveType_Eltwise] = PopulateEltwiseParameter; + populate_parameter_funcs_[schema::PrimitiveType_ExpandDims] = PopulateExpandDimsParameter; + populate_parameter_funcs_[schema::PrimitiveType_Abs] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Cos] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Sin] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Exp] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Log] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Square] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Sqrt] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Rsqrt] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_LogicalNot] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Floor] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_Ceil] = PopulateArithmeticSelf; + populate_parameter_funcs_[schema::PrimitiveType_ArgMax] = PopulateArgMaxParameter; + populate_parameter_funcs_[schema::PrimitiveType_ArgMin] = PopulateArgMinParameter; + populate_parameter_funcs_[schema::PrimitiveType_Cast] = PopulateCastParameter; + populate_parameter_funcs_[schema::PrimitiveType_Scale] = PopulateScaleParameter; + populate_parameter_funcs_[schema::PrimitiveType_Reshape] = PopulateReshapeParameter; + populate_parameter_funcs_[schema::PrimitiveType_Concat] = PopulateConcatParameter; + populate_parameter_funcs_[schema::PrimitiveType_Tile] = PopulateTileParameter; + populate_parameter_funcs_[schema::PrimitiveType_TopK] = PopulateTopKParameter; + populate_parameter_funcs_[schema::PrimitiveType_Fill] = PopulateFillParameter; + populate_parameter_funcs_[schema::PrimitiveType_Gather] = PopulateGatherParameter; + populate_parameter_funcs_[schema::PrimitiveType_GatherNd] = PopulateGatherNdParameter; + populate_parameter_funcs_[schema::PrimitiveType_Slice] = PopulateSliceParameter; + populate_parameter_funcs_[schema::PrimitiveType_BroadcastTo] = PopulateBroadcastToParameter; + populate_parameter_funcs_[schema::PrimitiveType_Reverse] = PopulateReverseParameter; + populate_parameter_funcs_[schema::PrimitiveType_Stack] = PopulateStackParameter; + populate_parameter_funcs_[schema::PrimitiveType_Unstack] = PopulateUnstackParameter; + populate_parameter_funcs_[schema::PrimitiveType_ReverseSequence] = PopulateReverseSequenceParameter; + populate_parameter_funcs_[schema::PrimitiveType_Unique] = PopulateUniqueParameter; + populate_parameter_funcs_[schema::PrimitiveType_DepthToSpace] = PopulateDepthToSpaceParameter; + populate_parameter_funcs_[schema::PrimitiveType_Nchw2Nhwc] = PopulateNchw2NhwcParameter; + populate_parameter_funcs_[schema::PrimitiveType_Nhwc2Nchw] = PopulateNhwc2NchwParameter; + populate_parameter_funcs_[schema::PrimitiveType_Pad] = PopulatePadParameter; + populate_parameter_funcs_[schema::PrimitiveType_Resize] = PopulateResizeParameter; + populate_parameter_funcs_[schema::PrimitiveType_BatchToSpace] = PopulateBatchToSpaceParameter; + populate_parameter_funcs_[schema::PrimitiveType_SpaceToDepth] = PopulateSpaceToDepthParameter; + populate_parameter_funcs_[schema::PrimitiveType_SpaceToBatch] = PopulateSpaceToBatchParameter; + populate_parameter_funcs_[schema::PrimitiveType_Crop] = PopulateCropParameter; + populate_parameter_funcs_[schema::PrimitiveType_Unsqueeze] = PopulateUnsqueezeParameter; + populate_parameter_funcs_[schema::PrimitiveType_Flatten] = PopulateFlattenParameter; + populate_parameter_funcs_[schema::PrimitiveType_MatMul] = PopulateMatMulParameter; + populate_parameter_funcs_[schema::PrimitiveType_OneHot] = PopulateOneHotParameter; + populate_parameter_funcs_[schema::PrimitiveType_AddN] = PopulateAddNParameter; + populate_parameter_funcs_[schema::PrimitiveType_StridedSlice] = PopulateStridedSliceParameter; + populate_parameter_funcs_[schema::PrimitiveType_ScatterND] = PopulateScatterNDParameter; + populate_parameter_funcs_[schema::PrimitiveType_Square] = PopulateSqueezeParameter; + populate_parameter_funcs_[schema::PrimitiveType_Split] = PopulateSplitParameter; + populate_parameter_funcs_[schema::PrimitiveType_PriorBox] = PopulatePriorBoxParameter; + populate_parameter_funcs_[schema::PrimitiveType_OnnxInt8Dequantize] = PopulateDequantizeParameter; + populate_parameter_funcs_[schema::PrimitiveType_OnnxInt8Quantize] = PopulateQuantizeParameter; +} + +PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() { + static PopulateParameterRegistry populate_parameter_instance; + return &populate_parameter_instance; +} + +PopulateParameterFunc PopulateParameterRegistry::GetParameterFunc(const schema::PrimitiveType &type) { + return populate_parameter_funcs_[type]; } OpParameter *PopulateParameter(const lite::Primitive *primitive) { - MS_EXCEPTION_IF_NULL(primitive); + if (primitive == nullptr) { + MS_LOG(ERROR) << "Primitive is nullptr when populating parameter for op."; + return nullptr; + } + auto op_type = primitive->Type(); - switch (op_type) { - case schema::PrimitiveType_SoftMax: - return reinterpret_cast(PopulateSoftmaxParameter(primitive)); - case schema::PrimitiveType_Activation: - return reinterpret_cast(PopulateActivationParameter(primitive)); - case schema::PrimitiveType_Conv2D: - return reinterpret_cast(PopulateConvParameter(primitive)); - case schema::PrimitiveType_Reduce: - return reinterpret_cast(PopulateReduceParameter(primitive)); - case schema::PrimitiveType_Pooling: - return reinterpret_cast(PopulatePoolingParam(primitive)); - case schema::PrimitiveType_DepthwiseConv2D: - return reinterpret_cast(PopulateConvDwParameter(primitive)); - case schema::PrimitiveType_DeDepthwiseConv2D: - return reinterpret_cast(PopulateDeconvDwParameter(primitive)); - case schema::PrimitiveType_DeConv2D: - return reinterpret_cast(PopulateDeconvParameter(primitive)); - case schema::PrimitiveType_FusedBatchNorm: - return reinterpret_cast(PopulateFusedBatchNorm(primitive)); - case schema::PrimitiveType_FullConnection: - return reinterpret_cast(PopulateFullconnectionParameter(primitive)); - case schema::PrimitiveType_Power: - return reinterpret_cast(PopulatePowerParameter(primitive)); - case schema::PrimitiveType_LocalResponseNormalization: - return reinterpret_cast(PopulateLocalResponseNormParameter(primitive)); - case schema::PrimitiveType_Range: - return reinterpret_cast(PopulateRangeParameter(primitive)); - case schema::PrimitiveType_Transpose: - return reinterpret_cast(PopulateTransposeParameter(primitive)); - case schema::PrimitiveType_Mul: - case schema::PrimitiveType_Add: - case schema::PrimitiveType_Sub: - case schema::PrimitiveType_Div: - case schema::PrimitiveType_FloorDiv: - case schema::PrimitiveType_FloorMod: - case schema::PrimitiveType_SquaredDifference: - return reinterpret_cast(PopulateArithmetic(primitive)); - case schema::PrimitiveType_BiasAdd: - return reinterpret_cast(new ArithmeticParameter()); - case schema::PrimitiveType_Eltwise: - return reinterpret_cast(PopulateEltwiseParam(primitive)); - case schema::PrimitiveType_ExpandDims: - return reinterpret_cast(PopulateExpandDimsParam(primitive)); - case schema::PrimitiveType_Abs: - case schema::PrimitiveType_Cos: - case schema::PrimitiveType_Sin: - case schema::PrimitiveType_Exp: - case schema::PrimitiveType_Log: - case schema::PrimitiveType_Square: - case schema::PrimitiveType_Sqrt: - case schema::PrimitiveType_Rsqrt: - case schema::PrimitiveType_LogicalNot: - case schema::PrimitiveType_Floor: - return reinterpret_cast(PopulateArithmeticSelf(primitive)); - case schema::PrimitiveType_ArgMax: - return reinterpret_cast(PopulateArgMaxParam(primitive)); - case schema::PrimitiveType_ArgMin: - return reinterpret_cast(PopulateArgMinParam(primitive)); - case schema::PrimitiveType_Cast: - return reinterpret_cast(PopulateCastParam(primitive)); - case schema::PrimitiveType_Ceil: - return reinterpret_cast(PopulateCeilParameter(primitive)); - case schema::PrimitiveType_Scale: - return reinterpret_cast(PopulateScaleParameter(primitive)); - case schema::PrimitiveType_Reshape: - return reinterpret_cast(PopulateReshapeParam(primitive)); - case schema::PrimitiveType_Concat: - return reinterpret_cast(PopulateConcatParameter(primitive)); - case schema::PrimitiveType_Tile: - return reinterpret_cast(PopulateTileParameter(primitive)); - case schema::PrimitiveType_TopK: - return reinterpret_cast(PopulateTopKParameter(primitive)); - case schema::PrimitiveType_Fill: - return reinterpret_cast(PopulateFillParam(primitive)); - case schema::PrimitiveType_Gather: - return reinterpret_cast(PopulateGatherParameter(primitive)); - case schema::PrimitiveType_GatherNd: - return reinterpret_cast(PopulateGatherNdParameter(primitive)); - case schema::PrimitiveType_Slice: - return reinterpret_cast(PopulateSliceParam(primitive)); - case schema::PrimitiveType_BroadcastTo: - return reinterpret_cast(PopulateBroadcastToParam(primitive)); - case schema::PrimitiveType_Reverse: - return reinterpret_cast(PopulateReverseParameter(primitive)); - case schema::PrimitiveType_Stack: - return reinterpret_cast(PopulateStackParam(primitive)); - case schema::PrimitiveType_Unstack: - return reinterpret_cast(PopulateUnstackParam(primitive)); - case schema::PrimitiveType_ReverseSequence: - return reinterpret_cast(PopulateReverseSequenceParam(primitive)); - case schema::PrimitiveType_Unique: - return reinterpret_cast(PopulateUniqueParam(primitive)); - case schema::PrimitiveType_DepthToSpace: - return reinterpret_cast(PopulateDepthToSpaceParam(primitive)); - case schema::PrimitiveType_Nchw2Nhwc: - return reinterpret_cast(PopulateNchw2NhwcParameter(primitive)); - case schema::PrimitiveType_Nhwc2Nchw: - return reinterpret_cast(PopulateNhwc2NchwParameter(primitive)); - case schema::PrimitiveType_Pad: - return reinterpret_cast(PopulatePadParameter(primitive)); - case schema::PrimitiveType_Resize: - return reinterpret_cast(PopulateResizeParameter(primitive)); - case schema::PrimitiveType_BatchToSpace: - return reinterpret_cast(PopulateBatchToSpaceParameter(primitive)); - case schema::PrimitiveType_Crop: - return reinterpret_cast(PopulateCropParameter(primitive)); - case schema::PrimitiveType_Unsqueeze: - return reinterpret_cast(PopulateUnsqueezeParameter(primitive)); - case schema::PrimitiveType_Flatten: - return reinterpret_cast(PopulateFlattenParameter(primitive)); - case schema::PrimitiveType_MatMul: - return reinterpret_cast(PopulateMatMulParameter(primitive)); - case schema::PrimitiveType_OneHot: - return reinterpret_cast(PopulateOneHotParameter(primitive)); - case schema::PrimitiveType_AddN: - return reinterpret_cast(PopulateAddNParam(primitive)); - case schema::PrimitiveType_PriorBox: - return reinterpret_cast(PopulatePriorBoxParameter(primitive)); - case schema::PrimitiveType_OnnxInt8Dequantize: - return reinterpret_cast(PopulateDequantizeParameter(primitive)); - case schema::PrimitiveType_OnnxInt8Quantize: - return reinterpret_cast(PopulateQuantizeParameter(primitive)); - default: - break; + auto func = PopulateParameterRegistry::GetInstance()->GetParameterFunc(op_type); + if (func == nullptr) { + MS_LOG(ERROR) << "Get nullptr for Op Parameter Func."; + return nullptr; } - return nullptr; + + auto *parameter = func(primitive); + if (parameter == nullptr) { + MS_LOG(ERROR) << "Get nullptr for Op Parameter."; + return nullptr; + } + return parameter; } } // namespace mindspore::kernel diff --git a/mindspore/lite/src/populate_parameter.h b/mindspore/lite/src/populate_parameter.h index f2e9ab6afebf1586b14ba9a5a495d15558723fe6..be92c856cb238b3b4c687acd420da53f7cf25191 100644 --- a/mindspore/lite/src/populate_parameter.h +++ b/mindspore/lite/src/populate_parameter.h @@ -22,7 +22,20 @@ #include "src/runtime/kernel/arm/opclib/op_base.h" namespace mindspore::kernel { +typedef OpParameter *(*PopulateParameterFunc)(const lite::Primitive *); + +class PopulateParameterRegistry { + public: + PopulateParameterRegistry(); + ~PopulateParameterRegistry() = default; + + static PopulateParameterRegistry *GetInstance(); + PopulateParameterFunc GetParameterFunc(const schema::PrimitiveType &type); + + protected: + PopulateParameterFunc populate_parameter_funcs_[schema::PrimitiveType_MAX + 1]; +}; + OpParameter *PopulateParameter(const lite::Primitive *primitive); } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_POPULATE_PARAMETER_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_common.h b/mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_common.h index 15d132c38fe808113cf6db85d8f568bafba715cc..57ac591572d48e27890f28d16b015e50fc3ba654 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_common.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/arithmetic_common.h @@ -24,7 +24,7 @@ #include "src/runtime/kernel/arm/opclib/arithmetic_common.h" struct ArithmeticParameter { - OpParameter op_parameter; + OpParameter op_parameter_; bool broadcasting_; size_t ndim_; int in_shape0_[5]; diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h index 2f91d15836d11fcce11aacf51888e9e71d4fb191..0787e1682a4332dd830aca94cd8f08557159a855 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/softmax.h @@ -20,7 +20,7 @@ #include "src/runtime/kernel/arm/opclib/op_base.h" struct SoftmaxParameter { - OpParameter op_parameter; + OpParameter op_parameter_; int32_t axis_; int element_size_; int n_dim_;