提交 3162b125 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5040 rename primitiveTvalue to primitive_c

Merge pull request !5040 from yeyunpeng2020/master
......@@ -165,14 +165,14 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
auto cnodes = func_graph->GetOrderedCnodes();
auto meta_graphT = std::make_unique<schema::MetaGraphT>();
for (const auto &cnode : cnodes) {
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitiveT_value == nullptr) {
MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return nullptr;
}
auto primT = primitiveT_value->GetPrimitiveT();
if (primitiveT_value->Type() == schema::PrimitiveType_TupleGetItem ||
primitiveT_value->Type() == schema::PrimitiveType_MakeTuple) {
auto primT = primitive_c->GetPrimitiveT();
if (primitive_c->Type() == schema::PrimitiveType_TupleGetItem ||
primitive_c->Type() == schema::PrimitiveType_MakeTuple) {
continue;
}
RemoveIfMakeTuple(cnode);
......@@ -196,7 +196,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
return nullptr;
}
SetOpOutputNode(cnode, meta_graphT, node.get());
ret = ConvertQuantParam(meta_graphT, primitiveT_value, node);
ret = ConvertQuantParam(meta_graphT, primitive_c, node);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvertQuantParam failed";
return nullptr;
......
......@@ -62,12 +62,12 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto cnodes = func_graph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitiveT_value == nullptr) {
MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return;
}
auto primT = primitiveT_value->GetPrimitiveT();
auto primT = primitive_c->GetPrimitiveT();
if (primT == nullptr) {
MS_LOG(ERROR) << "PrimitiveT is nullptr";
return;
......@@ -75,7 +75,7 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) {
if (primT->value.type == schema::PrimitiveType_TupleGetItem ||
primT->value.type == schema::PrimitiveType_MakeTuple || primT->value.type == schema::PrimitiveType_Return) {
delete primT;
primitiveT_value->SetPrimitiveT(nullptr);
primitive_c->SetPrimitiveT(nullptr);
}
}
}
......
......@@ -534,7 +534,7 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
return RET_OK;
}
STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitiveT_value,
STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c,
bool perchanel, bool depthwise) {
// const vector<int> dims = filter->dims;
// perlayer
......@@ -552,7 +552,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
return RET_ERROR;
}
auto status = QuantFilter(paramValue, primitiveT_value, QuantType_PostTraining, quant_max, quant_min, bit_num,
auto status = QuantFilter(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num,
perchanel, depthwise);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status;
......@@ -573,8 +573,8 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
return RET_OK;
}
STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitiveT_value) {
if (primitiveT_value == nullptr || bias == nullptr) {
STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitive_c) {
if (primitive_c == nullptr || bias == nullptr) {
MS_LOG(ERROR) << "null pointer!";
return RET_NULL_PTR;
}
......@@ -583,7 +583,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
auto bias_default_param = bias_parameter_ptr->default_param();
auto bias_param = std::dynamic_pointer_cast<ParamValueLite>(bias_default_param);
auto active_weight_quant_params = primitiveT_value->GetInputQuantParams();
auto active_weight_quant_params = primitive_c->GetInputQuantParams();
if (active_weight_quant_params.size() != 2) {
MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size();
return RET_ERROR;
......@@ -627,7 +627,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
quant_param.inited = true;
quant_params.emplace_back(quant_param);
}
primitiveT_value->AddInputQuantParam(quant_params);
primitive_c->AddInputQuantParam(quant_params);
// quant bias data
int32_t *quant_datas = new (std::nothrow) int32_t[shape_size];
if (quant_datas == nullptr) {
......@@ -683,18 +683,18 @@ STATUS PostTrainingQuantizer::QuantNode() {
MS_LOG(INFO) << cnode_name << " can not do quant";
continue;
}
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitiveT_value == nullptr) {
MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
continue;
}
if (input_scale.find(cnode) == input_scale.end()) {
primitiveT_value->SetQuantType(schema::QuantType_QUANT_NONE);
primitive_c->SetQuantType(schema::QuantType_QUANT_NONE);
continue;
}
primitiveT_value->ClearInputOutputQuantParam();
primitive_c->ClearInputOutputQuantParam();
auto op_name = cnode->fullname_with_scope();
auto op_type = (schema::PrimitiveType)primitiveT_value->Type();
auto op_type = (schema::PrimitiveType)primitive_c->Type();
MS_LOG(INFO) << "OpName: " << op_name;
if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
op_type != PrimitiveType_FullConnection) {
......@@ -715,35 +715,35 @@ STATUS PostTrainingQuantizer::QuantNode() {
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) {
MS_LOG(DEBUG) << "this parameter do quant";
DoWeightQuant(input_node, primitiveT_value, false, false);
DoWeightQuant(input_node, primitive_c, false, false);
} else {
MS_LOG(DEBUG) << "this parameter no need to do quant";
}
continue;
}
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
auto input_cnode_primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
if (input_cnode_primitiveT_value == nullptr) {
auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
if (input_cnode_primitive_c == nullptr) {
MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
<< " PrimitiveC is null";
continue;
}
if (!input_cnode_primitiveT_value->GetOutputQuantParams().empty()) {
for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) {
primitiveT_value->AddInputQuantParam(quant_param);
if (!input_cnode_primitive_c->GetOutputQuantParams().empty()) {
for (auto &quant_param : input_cnode_primitive_c->GetOutputQuantParams()) {
primitive_c->AddInputQuantParam(quant_param);
}
} else {
// do input quant
double scale = input_scale[cnode];
int32_t zp = input_zero_point[cnode];
DoQuantInput(scale, zp, &input_min_max[cnode], primitiveT_value);
DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c);
}
}
} else {
// do input quant
double scale = input_scale[cnode];
int32_t convInputzeropoint = input_zero_point[cnode];
DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitiveT_value);
DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c);
// do weight quant
auto weight = cnode->input(2);
bool depthwise = op_type == PrimitiveType_DepthwiseConv2D;
......@@ -751,18 +751,18 @@ STATUS PostTrainingQuantizer::QuantNode() {
if (op_type == PrimitiveType_FullConnection) {
perchannel = false;
}
DoWeightQuant(weight, primitiveT_value, perchannel, depthwise);
DoWeightQuant(weight, primitive_c, perchannel, depthwise);
// do bias quant
if (cnode->inputs().size() == 4) {
auto bias = cnode->input(3);
DoBiasQuant(bias, primitiveT_value);
DoBiasQuant(bias, primitive_c);
}
}
// do output quant
double OutputScale = output_scale[cnode];
int32_t OutputZeropoint = output_zeropoint[cnode];
DoQuantOutput(OutputScale, OutputZeropoint, &output_min_max[cnode], primitiveT_value);
primitiveT_value->SetQuantType(schema::QuantType_PostTraining);
DoQuantOutput(OutputScale, OutputZeropoint, &output_min_max[cnode], primitive_c);
primitive_c->SetQuantType(schema::QuantType_PostTraining);
}
return RET_OK;
}
......
......@@ -95,10 +95,10 @@ class PostTrainingQuantizer : public Quantizer {
STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);
STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);
STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitiveT_value, bool perchannel,
STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel,
bool depthwise);
STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitiveT_value);
STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitive_c);
};
struct DivergInfo;
......
......@@ -44,17 +44,17 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
bool first = true;
for (auto &cnode : cnodes) {
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
auto curnode_quant_type = schema::QuantType_QUANT_NONE;
if (primitiveT_value == nullptr) {
MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope();
if (primitive_c == nullptr) {
MS_LOG(WARNING) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
} else {
curnode_quant_type = primitiveT_value->GetQuantType();
curnode_quant_type = primitive_c->GetQuantType();
}
if (first) {
if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) {
auto value_node =
NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams().front());
NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_c->GetInputQuantParams().front());
std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)};
auto quant_cast_cnode = graph->NewCNode(op_inputs);
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast");
......@@ -72,24 +72,24 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
continue;
}
auto input_cnode = std::dynamic_pointer_cast<CNode>(input_node);
auto input_cnode_primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
if (input_cnode_primitiveT_value == nullptr) {
auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
if (input_cnode_primitive_c == nullptr) {
MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
<< " PrimitiveC is null";
continue;
}
auto input_cnode_quant_type = input_cnode_primitiveT_value->GetQuantType();
auto input_cnode_quant_type = input_cnode_primitive_c->GetQuantType();
if (curnode_quant_type != input_cnode_quant_type) {
ValueNodePtr value_node = nullptr;
if (curnode_quant_type == schema::QuantType_PostTraining &&
input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8,
primitiveT_value->GetInputQuantParams().front());
primitive_c->GetInputQuantParams().front());
} else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
input_cnode_quant_type == schema::QuantType_PostTraining) {
value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32,
input_cnode_primitiveT_value->GetInputQuantParams().front());
input_cnode_primitive_c->GetInputQuantParams().front());
}
if (value_node == nullptr) {
MS_LOG(WARNING) << "value_node is null! "
......
......@@ -87,13 +87,13 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
}
auto cnode = std::dynamic_pointer_cast<CNode>(node);
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitiveT_value == nullptr) {
MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope();
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(WARNING) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
return false;
}
auto type = (schema::PrimitiveType)primitiveT_value->Type();
auto type = (schema::PrimitiveType)primitive_c->Type();
MS_LOG(INFO) << "Primitive type: " << type;
static const std::vector<schema::PrimitiveType> uint8OpList = {
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
......@@ -279,7 +279,7 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
return RET_OK;
}
STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitiveT_value, QuantType quantType,
STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) {
auto dims = weight->tensor_shape();
if (per_channel) {
......@@ -450,7 +450,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
MS_LOG(ERROR) << "quant_params empty";
return RET_ERROR;
}
primitiveT_value->AddInputQuantParam(quant_params);
primitive_c->AddInputQuantParam(quant_params);
return RET_OK;
}
......
......@@ -118,7 +118,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
}();
}
STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitiveT_value, QuantType quantType,
STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
int quant_max, int quant_min, size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false,
bool depth_wise = false);
......
......@@ -135,6 +135,26 @@ void FreeInputTensor(std::vector<Tensor *> *input_tensor) {
}
return;
}
schema::Primitive *PackPrimitiveT(const CNodePtr &cnode) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return nullptr;
}
auto *lite_primitive = primitive_c->GetPrimitiveT();
if (lite_primitive == nullptr) {
MS_LOG(ERROR) << "Primitive in primitive_c is nullptr";
return nullptr;
}
flatbuffers::FlatBufferBuilder builder(1024);
auto offset = schema::Primitive::Pack(builder, lite_primitive);
builder.Finish(offset);
auto buf = builder.GetBufferPointer();
auto primitive = flatbuffers::GetRoot<schema::Primitive>(buf);
return const_cast<schema::Primitive *>(primitive);
}
const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
CheckIfFuncGraphIsNull(func_graph);
......@@ -155,10 +175,16 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
}
MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
auto output_nums = GetOutputTensorNum(input_cnode);
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
primitiveT_value->InferShape(input_tensors, output_tensors);
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, primitiveT_value.get());
auto scheam_primitive = PackPrimitiveT(input_cnode);
auto lite_primitive = mindspore::lite::PrimitiveC::UnPackFromSchemaPrimitive(scheam_primitive);
if (lite_primitive == nullptr) {
MS_LOG(ERROR) << "constant_folding schedule node lite primitive nullptr";
FreeInputTensor(&input_tensors);
return nullptr;
}
lite_primitive->InferShape(input_tensors, output_tensors);
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive);
if (lite_kernel == nullptr) {
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
FreeInputTensor(&input_tensors);
......
......@@ -62,17 +62,17 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
}
auto conv_node = pre_node->cast<CNodePtr>();
auto node_type = GetCNodeType(conv_node);
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0));
MS_ASSERT(primitiveT_value);
auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0));
MS_ASSERT(primitive_c);
if (node_type == schema::PrimitiveType_Conv2D) {
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value);
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c);
MS_ASSERT(primc != nullptr);
primc->SetActivationType(activation_type);
return pre_node;
} else if (node_type == schema::PrimitiveType_DepthwiseConv2D) {
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value);
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c);
MS_ASSERT(primc != nullptr);
primc->SetActivationType(activation_type);
return pre_node;
......
......@@ -160,22 +160,22 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons
auto conv_node = conv_node_anf->cast<CNodePtr>();
CheckIfCNodeIsNull(conv_node);
GenConvNewBias(func_graph, conv_node, add_node);
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0));
MS_ASSERT(primitiveT_value != nullptr);
auto type = primitiveT_value->Type();
auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0));
MS_ASSERT(primitive_c != nullptr);
auto type = primitive_c->Type();
if (type == schema::PrimitiveType_Conv2D) {
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value);
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c);
MS_ASSERT(primc != nullptr);
primc->SetHasBias(true);
} else if (type == schema::PrimitiveType_DepthwiseConv2D) {
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value);
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c);
MS_ASSERT(primc != nullptr);
primc->SetHasBias(true);
} else if (type == schema::PrimitiveType_DeConv2D) {
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DeConv2D>>(primitiveT_value));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::DeConv2D>>(primitiveT_value);
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DeConv2D>>(primitive_c));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::DeConv2D>>(primitive_c);
MS_ASSERT(primc != nullptr);
primc->SetHasBias(true);
} else {
......
......@@ -115,14 +115,14 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern
AnfNodePtr bn_scale_node = nullptr;
AnfNodePtr bn_bias_node = nullptr;
float eps = 0;
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(bn_node->input(0));
auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(bn_node->input(0));
if (GetCNodeType(bn_node) == schema::PrimitiveType_BatchNorm) {
bn_mean_node = bn_node->input(kCaffeBNMeanIndex);
bn_variance_node = bn_node->input(kCaffeBNVarIndex);
CheckIfNodeIsParam(bn_mean_node);
CheckIfNodeIsParam(bn_variance_node);
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value);
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::BatchNorm>>(primitive_c));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::BatchNorm>>(primitive_c);
MS_ASSERT(primc != nullptr);
eps = primc->GetEpsilon();
} else if (GetCNodeType(bn_node) == schema::PrimitiveType_FusedBatchNorm) {
......@@ -130,8 +130,8 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern
bn_bias_node = bn_node->input(kTFBNBiasIndex);
bn_mean_node = bn_node->input(kTFBNMeanIndex);
bn_variance_node = bn_node->input(kTFBNVarIndex);
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value);
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitive_c));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitive_c);
MS_ASSERT(primc != nullptr);
eps = primc->GetEpsilon();
} else {
......
......@@ -97,17 +97,17 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias);
delete[] trans_bias;
delete[] trans_scale;
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0));
MS_ASSERT(primitiveT_value != nullptr);
auto type = primitiveT_value->Type();
auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv_node->input(0));
MS_ASSERT(primitive_c != nullptr);
auto type = primitive_c->Type();
if (type == schema::PrimitiveType_Conv2D) {
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitiveT_value);
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c);
MS_ASSERT(primc != nullptr);
primc->SetHasBias(true);
} else if (type == schema::PrimitiveType_DepthwiseConv2D) {
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitiveT_value);
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c);
MS_ASSERT(primc != nullptr);
primc->SetHasBias(true);
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册