提交 6763b63c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5022 delete GetPrimitiveT

Merge pull request !5022 from yeyunpeng2020/master
...@@ -73,13 +73,11 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) { ...@@ -73,13 +73,11 @@ void Converter::FreeFuncGraph(const FuncGraphPtr &func_graph) {
return; return;
} }
if (primT->value.type == schema::PrimitiveType_TupleGetItem || if (primT->value.type == schema::PrimitiveType_TupleGetItem ||
primT->value.type == schema::PrimitiveType_MakeTuple || primT->value.type == schema::PrimitiveType_MakeTuple || primT->value.type == schema::PrimitiveType_Return) {
primT->value.type == schema::PrimitiveType_Return) {
delete primT; delete primT;
primitiveT_value->SetPrimitiveT(nullptr); primitiveT_value->SetPrimitiveT(nullptr);
} }
} }
return;
} }
MetaGraphT *Converter::Convert(const converter::Flags *flag) { MetaGraphT *Converter::Convert(const converter::Flags *flag) {
// parse the model and weight file to generate inference data structure // parse the model and weight file to generate inference data structure
......
...@@ -93,7 +93,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { ...@@ -93,7 +93,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
return false; return false;
} }
auto type = primitiveT_value->GetPrimitiveT()->value.type; auto type = (schema::PrimitiveType)primitiveT_value->Type();
MS_LOG(INFO) << "Primitive type: " << type; MS_LOG(INFO) << "Primitive type: " << type;
static const std::vector<schema::PrimitiveType> uint8OpList = { static const std::vector<schema::PrimitiveType> uint8OpList = {
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
......
...@@ -170,7 +170,7 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) { ...@@ -170,7 +170,7 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
if (a.m_ptr->isa<lite::PrimitiveC>() && b.m_ptr->isa<lite::PrimitiveC>()) { if (a.m_ptr->isa<lite::PrimitiveC>() && b.m_ptr->isa<lite::PrimitiveC>()) {
auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>(); auto a_value_node_ptr = a.m_ptr->cast<PrimitiveCPtr>();
auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>(); auto b_value_node_ptr = b.m_ptr->cast<PrimitiveCPtr>();
return a_value_node_ptr->GetPrimitiveT()->value.type == b_value_node_ptr->GetPrimitiveT()->value.type; return a_value_node_ptr->Type() == b_value_node_ptr->Type();
} }
return a == b; return a == b;
...@@ -316,7 +316,7 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) { ...@@ -316,7 +316,7 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
if (utils::isa<PrimitiveCPtr>(value)) { if (utils::isa<PrimitiveCPtr>(value)) {
auto primitive = value->cast<PrimitiveCPtr>(); auto primitive = value->cast<PrimitiveCPtr>();
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
return primitive->GetPrimitiveT()->value.type; return (schema::PrimitiveType)primitive->Type();
} else if (utils::isa<Primitive>(value)) { } else if (utils::isa<Primitive>(value)) {
auto primitive = value->cast<PrimitivePtr>(); auto primitive = value->cast<PrimitivePtr>();
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);
......
...@@ -73,26 +73,6 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) { ...@@ -73,26 +73,6 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) {
} }
return input_tensors; return input_tensors;
} }
schema::Primitive *PackPrimitiveT(const CNodePtr &cnode) {
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitiveT_value == nullptr) {
MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
return nullptr;
}
auto *lite_primitive = primitiveT_value->GetPrimitiveT();
if (lite_primitive == nullptr) {
MS_LOG(ERROR) << "Primitive in primitiveT_value is nullptr";
return nullptr;
}
flatbuffers::FlatBufferBuilder builder(1024);
auto offset = schema::Primitive::Pack(builder, lite_primitive);
builder.Finish(offset);
auto buf = builder.GetBufferPointer();
auto primitive = flatbuffers::GetRoot<schema::Primitive>(buf);
return const_cast<schema::Primitive *>(primitive);
}
const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { const ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
auto parameter = func_graph->add_parameter(); auto parameter = func_graph->add_parameter();
std::vector<int> shape; std::vector<int> shape;
...@@ -175,16 +155,10 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An ...@@ -175,16 +155,10 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
} }
MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope(); MS_LOG(INFO) << "Begin fold node:" << input_node->fullname_with_scope();
auto output_nums = GetOutputTensorNum(input_cnode); auto output_nums = GetOutputTensorNum(input_cnode);
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
std::vector<Tensor *> output_tensors{output_nums, new Tensor()}; std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
auto scheam_primitive = PackPrimitiveT(input_cnode); primitiveT_value->InferShape(input_tensors, output_tensors);
auto lite_primitive = mindspore::lite::PrimitiveC::UnPackFromSchemaPrimitive(scheam_primitive); auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, primitiveT_value.get());
if (lite_primitive == nullptr) {
MS_LOG(ERROR) << "constant_folding schedule node lite primitive nullptr";
FreeInputTensor(&input_tensors);
return nullptr;
}
lite_primitive->InferShape(input_tensors, output_tensors);
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, lite_primitive);
if (lite_kernel == nullptr) { if (lite_kernel == nullptr) {
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr"; MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
FreeInputTensor(&input_tensors); FreeInputTensor(&input_tensors);
......
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include "utils/utils.h" #include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h" #include "securec/include/securec.h"
#include "src/ops/batch_norm.h"
#include "src/ops/fused_batchnorm.h"
namespace mindspore::opt { namespace mindspore::opt {
namespace { namespace {
...@@ -94,7 +96,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const { ...@@ -94,7 +96,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
auto bn_mean_var = std::make_shared<CondVar>(IsParamNode); auto bn_mean_var = std::make_shared<CondVar>(IsParamNode);
auto bn_variable_var = std::make_shared<CondVar>(IsParamNode); auto bn_variable_var = std::make_shared<CondVar>(IsParamNode);
auto bn_other_var = std::make_shared<SeqVar>(); auto bn_other_var = std::make_shared<SeqVar>();
return VectorRef({bn_var, conv_var, bn_mean_var, bn_variable_var, bn_other_var});; return VectorRef({bn_var, conv_var, bn_mean_var, bn_variable_var, bn_other_var});
} }
// BatchNorm weight Tensor definition: // BatchNorm weight Tensor definition:
// caffe // caffe
...@@ -106,7 +108,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const { ...@@ -106,7 +108,7 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
// estimated_mean --2 // estimated_mean --2
// estimated_variance --3 // estimated_variance --3
const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale, const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale,
float *trans_bias) const { float *trans_bias) const {
MS_ASSERT(bn_node != nullptr); MS_ASSERT(bn_node != nullptr);
AnfNodePtr bn_mean_node = nullptr; AnfNodePtr bn_mean_node = nullptr;
AnfNodePtr bn_variance_node = nullptr; AnfNodePtr bn_variance_node = nullptr;
...@@ -119,13 +121,19 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern ...@@ -119,13 +121,19 @@ const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kern
bn_variance_node = bn_node->input(kCaffeBNVarIndex); bn_variance_node = bn_node->input(kCaffeBNVarIndex);
CheckIfNodeIsParam(bn_mean_node); CheckIfNodeIsParam(bn_mean_node);
CheckIfNodeIsParam(bn_variance_node); CheckIfNodeIsParam(bn_variance_node);
eps = primitiveT_value->GetPrimitiveT()->value.AsBatchNorm()->epsilon; MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::BatchNorm>>(primitiveT_value);
MS_ASSERT(primc != nullptr);
eps = primc->GetEpsilon();
} else if (GetCNodeType(bn_node) == schema::PrimitiveType_FusedBatchNorm) { } else if (GetCNodeType(bn_node) == schema::PrimitiveType_FusedBatchNorm) {
bn_scale_node = bn_node->input(kTFBNScaleIndex); bn_scale_node = bn_node->input(kTFBNScaleIndex);
bn_bias_node = bn_node->input(kTFBNBiasIndex); bn_bias_node = bn_node->input(kTFBNBiasIndex);
bn_mean_node = bn_node->input(kTFBNMeanIndex); bn_mean_node = bn_node->input(kTFBNMeanIndex);
bn_variance_node = bn_node->input(kTFBNVarIndex); bn_variance_node = bn_node->input(kTFBNVarIndex);
eps = primitiveT_value->GetPrimitiveT()->value.AsFusedBatchNorm()->epsilon; MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value));
auto primc = utils::cast<std::shared_ptr<mindspore::lite::FusedBatchNorm>>(primitiveT_value);
MS_ASSERT(primc != nullptr);
eps = primc->GetEpsilon();
} else { } else {
MS_LOG(EXCEPTION) << "not caffe or tf batchnorm op."; MS_LOG(EXCEPTION) << "not caffe or tf batchnorm op.";
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册