提交 6b46acb3 编写于 作者: Y yeyunpeng

change long to int64

上级 dde25759
......@@ -19,22 +19,22 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
long Crop::GetAxis() const { return this->primitive_->value.AsCrop()->axis; }
std::vector<long> Crop::GetOffsets() const { return this->primitive_->value.AsCrop()->offsets; }
int64_t Crop::GetAxis() const { return this->primitive_->value.AsCrop()->axis; }
std::vector<int64_t> Crop::GetOffsets() const { return this->primitive_->value.AsCrop()->offsets; }
void Crop::SetAxis(long axis) { this->primitive_->value.AsCrop()->axis = axis; }
void Crop::SetOffsets(const std::vector<long> &offsets) { this->primitive_->value.AsCrop()->offsets = offsets; }
void Crop::SetAxis(int64_t axis) { this->primitive_->value.AsCrop()->axis = axis; }
void Crop::SetOffsets(const std::vector<int64_t> &offsets) { this->primitive_->value.AsCrop()->offsets = offsets; }
#else
long Crop::GetAxis() const { return this->primitive_->value_as_Crop()->axis(); }
std::vector<long> Crop::GetOffsets() const {
int64_t Crop::GetAxis() const { return this->primitive_->value_as_Crop()->axis(); }
std::vector<int64_t> Crop::GetOffsets() const {
auto fb_vector = this->primitive_->value_as_Crop()->offsets();
return std::vector<long>(fb_vector->begin(), fb_vector->end());
return std::vector<int64_t>(fb_vector->begin(), fb_vector->end());
}
void Crop::SetAxis(long axis) {}
void Crop::SetOffsets(const std::vector<long> &offsets) {}
void Crop::SetAxis(int64_t axis) {}
void Crop::SetOffsets(const std::vector<int64_t> &offsets) {}
#endif
namespace {
constexpr int kCropOutputNum = 1;
......
......@@ -34,10 +34,10 @@ class Crop : public PrimitiveC {
explicit Crop(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
long GetAxis() const;
std::vector<long> GetOffsets() const;
void SetAxis(long axis);
void SetOffsets(const std::vector<long> &offsets);
int64_t GetAxis() const;
std::vector<int64_t> GetOffsets() const;
void SetAxis(int64_t axis);
void SetOffsets(const std::vector<int64_t> &offsets);
};
} // namespace lite
} // namespace mindspore
......
......@@ -31,16 +31,16 @@ float DetectionPostProcess::GetNmsIouThreshold() const {
float DetectionPostProcess::GetNmsScoreThreshold() const {
return this->primitive_->value.AsDetectionPostProcess()->NmsScoreThreshold;
}
long DetectionPostProcess::GetMaxDetections() const {
int64_t DetectionPostProcess::GetMaxDetections() const {
return this->primitive_->value.AsDetectionPostProcess()->MaxDetections;
}
long DetectionPostProcess::GetDetectionsPreClass() const {
int64_t DetectionPostProcess::GetDetectionsPreClass() const {
return this->primitive_->value.AsDetectionPostProcess()->DetectionsPreClass;
}
long DetectionPostProcess::GetMaxClassesPreDetection() const {
int64_t DetectionPostProcess::GetMaxClassesPreDetection() const {
return this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection;
}
long DetectionPostProcess::GetNumClasses() const {
int64_t DetectionPostProcess::GetNumClasses() const {
return this->primitive_->value.AsDetectionPostProcess()->NumClasses;
}
bool DetectionPostProcess::GetUseRegularNms() const {
......@@ -71,16 +71,16 @@ void DetectionPostProcess::SetNmsIouThreshold(float nms_iou_threshold) {
void DetectionPostProcess::SetNmsScoreThreshold(float nms_score_threshold) {
this->primitive_->value.AsDetectionPostProcess()->NmsScoreThreshold = nms_score_threshold;
}
void DetectionPostProcess::SetMaxDetections(long max_detections) {
void DetectionPostProcess::SetMaxDetections(int64_t max_detections) {
this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection = max_detections;
}
void DetectionPostProcess::SetDetectionsPreClass(long detections_pre_class) {
void DetectionPostProcess::SetDetectionsPreClass(int64_t detections_pre_class) {
this->primitive_->value.AsDetectionPostProcess()->DetectionsPreClass = detections_pre_class;
}
void DetectionPostProcess::SetMaxClassesPreDetection(long max_classes_pre_detection) {
void DetectionPostProcess::SetMaxClassesPreDetection(int64_t max_classes_pre_detection) {
this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection = max_classes_pre_detection;
}
void DetectionPostProcess::SetNumClasses(long num_classes) {
void DetectionPostProcess::SetNumClasses(int64_t num_classes) {
this->primitive_->value.AsDetectionPostProcess()->NumClasses = num_classes;
}
void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) {
......@@ -103,16 +103,16 @@ float DetectionPostProcess::GetNmsIouThreshold() const {
float DetectionPostProcess::GetNmsScoreThreshold() const {
return this->primitive_->value_as_DetectionPostProcess()->NmsScoreThreshold();
}
long DetectionPostProcess::GetMaxDetections() const {
int64_t DetectionPostProcess::GetMaxDetections() const {
return this->primitive_->value_as_DetectionPostProcess()->MaxDetections();
}
long DetectionPostProcess::GetDetectionsPreClass() const {
int64_t DetectionPostProcess::GetDetectionsPreClass() const {
return this->primitive_->value_as_DetectionPostProcess()->DetectionsPreClass();
}
long DetectionPostProcess::GetMaxClassesPreDetection() const {
int64_t DetectionPostProcess::GetMaxClassesPreDetection() const {
return this->primitive_->value_as_DetectionPostProcess()->MaxClassesPreDetection();
}
long DetectionPostProcess::GetNumClasses() const {
int64_t DetectionPostProcess::GetNumClasses() const {
return this->primitive_->value_as_DetectionPostProcess()->NumClasses();
}
bool DetectionPostProcess::GetUseRegularNms() const {
......@@ -127,10 +127,10 @@ void DetectionPostProcess::SetXScale(float x_scale) {}
void DetectionPostProcess::SetYScale(float y_scale) {}
void DetectionPostProcess::SetNmsIouThreshold(float nms_iou_threshold) {}
void DetectionPostProcess::SetNmsScoreThreshold(float nms_score_threshold) {}
void DetectionPostProcess::SetMaxDetections(long max_detections) {}
void DetectionPostProcess::SetDetectionsPreClass(long detections_pre_class) {}
void DetectionPostProcess::SetMaxClassesPreDetection(long max_classes_pre_detection) {}
void DetectionPostProcess::SetNumClasses(long num_classes) {}
void DetectionPostProcess::SetMaxDetections(int64_t max_detections) {}
void DetectionPostProcess::SetDetectionsPreClass(int64_t detections_pre_class) {}
void DetectionPostProcess::SetMaxClassesPreDetection(int64_t max_classes_pre_detection) {}
void DetectionPostProcess::SetNumClasses(int64_t num_classes) {}
void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) {}
#endif
} // namespace lite
......
......@@ -41,10 +41,10 @@ class DetectionPostProcess : public PrimitiveC {
float GetYScale() const;
float GetNmsIouThreshold() const;
float GetNmsScoreThreshold() const;
long GetMaxDetections() const;
long GetDetectionsPreClass() const;
long GetMaxClassesPreDetection() const;
long GetNumClasses() const;
int64_t GetMaxDetections() const;
int64_t GetDetectionsPreClass() const;
int64_t GetMaxClassesPreDetection() const;
int64_t GetNumClasses() const;
bool GetUseRegularNms() const;
void SetFormat(int format);
void SetInputSize(int input_size);
......@@ -54,10 +54,10 @@ class DetectionPostProcess : public PrimitiveC {
void SetYScale(float y_scale);
void SetNmsIouThreshold(float nms_iou_threshold);
void SetNmsScoreThreshold(float nms_score_threshold);
void SetMaxDetections(long max_detections);
void SetDetectionsPreClass(long detections_pre_class);
void SetMaxClassesPreDetection(long max_classes_pre_detection);
void SetNumClasses(long num_classes);
void SetMaxDetections(int64_t max_detections);
void SetDetectionsPreClass(int64_t detections_pre_class);
void SetMaxClassesPreDetection(int64_t max_classes_pre_detection);
void SetNumClasses(int64_t num_classes);
void SetUseRegularNms(bool use_regular_nms);
};
} // namespace lite
......
......@@ -19,18 +19,18 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<long> Permute::GetOrder() const { return this->primitive_->value.AsPermute()->order; }
std::vector<int64_t> Permute::GetOrder() const { return this->primitive_->value.AsPermute()->order; }
void Permute::SetOrder(const std::vector<long> &order) { this->primitive_->value.AsPermute()->order = order; }
void Permute::SetOrder(const std::vector<int64_t> &order) { this->primitive_->value.AsPermute()->order = order; }
#else
std::vector<long> Permute::GetOrder() const {
std::vector<int64_t> Permute::GetOrder() const {
auto fb_vector = this->primitive_->value_as_Permute()->order();
return std::vector<long>(fb_vector->begin(), fb_vector->end());
return std::vector<int64_t>(fb_vector->begin(), fb_vector->end());
}
void Permute::SetOrder(const std::vector<long> &order) {}
void Permute::SetOrder(const std::vector<int64_t> &order) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -33,8 +33,8 @@ class Permute : public PrimitiveC {
#else
explicit Permute(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
std::vector<long> GetOrder() const;
void SetOrder(const std::vector<long> &order);
std::vector<int64_t> GetOrder() const;
void SetOrder(const std::vector<int64_t> &order);
};
} // namespace lite
} // namespace mindspore
......
......@@ -410,11 +410,32 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT
return new Shape(primitive);
case schema::PrimitiveType_Unsqueeze:
return new Unsqueeze(primitive);
case schema::PrimitiveType_BatchToSpace:
return new BatchToSpace(primitive);
case schema::PrimitiveType_SpaceToBatch:
return new SpaceToBatch(primitive);
case schema::PrimitiveType_BroadcastTo:
return new BroadcastTo(primitive);
case schema::PrimitiveType_DepthToSpace:
return new DepthToSpace(primitive);
case schema::PrimitiveType_Lstm:
return new Lstm(primitive);
case schema::PrimitiveType_ZerosLike:
return new ZerosLike(primitive);
case schema::PrimitiveType_MakeTuple:
return new MakeTuple(primitive);
case schema::PrimitiveType_Where:
return new Where(primitive);
case schema::PrimitiveType_ScatterND:
return new ScatterND(primitive);
case schema::PrimitiveType_ConstantOfShape:
return new ConstantOfShape(primitive);
default:
MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitiveT : "
<< schema::EnumNamePrimitiveType(op_type);
return nullptr;
break;
}
return nullptr;
}
#else
PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *primitive) {
......@@ -433,6 +454,8 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
return new Reduce(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Pooling:
return new Pooling(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_ROIPooling:
return new ROIPooling(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_DepthwiseConv2D:
return new DepthwiseConv2D(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_FusedBatchNorm:
......@@ -443,6 +466,8 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
return new FullConnection(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Power:
return new Power(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Pad:
return new Pad(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Range:
return new Range(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Mul:
......@@ -469,20 +494,22 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
return new Scale(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Eltwise:
return new Eltwise(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Ceil:
return new Ceil(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Concat:
return new Concat(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Fill:
return new Fill(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Nhwc2Nchw:
return new Nhwc2Nchw(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Nchw2Nhwc:
return new Nchw2Nhwc(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Transpose:
return new Transpose(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Slice:
return new Slice(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Squeeze:
return new Squeeze(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Nchw2Nhwc:
return new Nchw2Nhwc(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Nhwc2Nchw:
return new Nhwc2Nchw(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Flatten:
return new Flatten(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Mean:
......@@ -521,8 +548,6 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
return new Maximum(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Minimum:
return new Minimum(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Pad:
return new Pad(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_StridedSlice:
return new StridedSlice(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Prelu:
......@@ -559,12 +584,12 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
return new GreaterEqual(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Floor:
return new Floor(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Ceil:
return new Ceil(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Split:
return new Split(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_OneHot:
return new OneHot(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_PriorBox:
return new PriorBox(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_SpaceToDepth:
return new SpaceToDepth(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Tile:
......@@ -591,7 +616,29 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(mindspore::schema::Primitive *
return new Shape(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Unsqueeze:
return new Unsqueeze(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_BatchToSpace:
return new BatchToSpace(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_SpaceToBatch:
return new SpaceToBatch(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_BroadcastTo:
return new BroadcastTo(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_DepthToSpace:
return new DepthToSpace(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Lstm:
return new Lstm(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_ZerosLike:
return new ZerosLike(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_MakeTuple:
return new MakeTuple(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_Where:
return new Where(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_ScatterND:
return new ScatterND(const_cast<schema::Primitive *>(primitive));
case schema::PrimitiveType_ConstantOfShape:
return new ConstantOfShape(const_cast<schema::Primitive *>(primitive));
default:
MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitive : "
<< schema::EnumNamePrimitiveType(op_type);
break;
}
return nullptr;
......
......@@ -25,10 +25,10 @@ namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Reshape::GetFormat() const { return this->primitive_->value.AsReshape()->format; }
std::vector<long> Reshape::GetShape() const { return this->primitive_->value.AsReshape()->shape; }
std::vector<int64_t> Reshape::GetShape() const { return this->primitive_->value.AsReshape()->shape; }
void Reshape::SetFormat(int format) { this->primitive_->value.AsReshape()->format = (schema::Format)format; }
void Reshape::SetShape(const std::vector<long> &shape) { this->primitive_->value.AsReshape()->shape = shape; }
void Reshape::SetShape(const std::vector<int64_t> &shape) { this->primitive_->value.AsReshape()->shape = shape; }
int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
auto attr = std::make_unique<schema::ReshapeT>();
......@@ -59,13 +59,13 @@ int Reshape::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
#else
int Reshape::GetFormat() const { return this->primitive_->value_as_Reshape()->format(); }
std::vector<long> Reshape::GetShape() const {
std::vector<int64_t> Reshape::GetShape() const {
auto fb_vector = this->primitive_->value_as_Reshape()->shape();
return std::vector<long>(fb_vector->begin(), fb_vector->end());
return std::vector<int64_t>(fb_vector->begin(), fb_vector->end());
}
void Reshape::SetFormat(int format) {}
void Reshape::SetShape(const std::vector<long> &shape) {}
void Reshape::SetShape(const std::vector<int64_t> &shape) {}
#endif
int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector<int> *out_shape) const {
......
......@@ -36,9 +36,9 @@ class Reshape : public PrimitiveC {
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetFormat() const;
std::vector<long> GetShape() const;
std::vector<int64_t> GetShape() const;
void SetFormat(int format);
void SetShape(const std::vector<long> &shape);
void SetShape(const std::vector<int64_t> &shape);
private:
int CalNewShape(const lite::tensor::Tensor *in_tensor, std::vector<int> *out_shape) const;
......
......@@ -21,15 +21,15 @@ namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Resize::GetFormat() const { return this->primitive_->value.AsResize()->format; }
int Resize::GetMethod() const { return this->primitive_->value.AsResize()->method; }
long Resize::GetNewHeight() const { return this->primitive_->value.AsResize()->newHeight; }
long Resize::GetNewWidth() const { return this->primitive_->value.AsResize()->newWidth; }
int64_t Resize::GetNewHeight() const { return this->primitive_->value.AsResize()->newHeight; }
int64_t Resize::GetNewWidth() const { return this->primitive_->value.AsResize()->newWidth; }
bool Resize::GetAlignCorners() const { return this->primitive_->value.AsResize()->alignCorners; }
bool Resize::GetPreserveAspectRatio() const { return this->primitive_->value.AsResize()->preserveAspectRatio; }
void Resize::SetFormat(int format) { this->primitive_->value.AsResize()->format = (schema::Format)format; }
void Resize::SetMethod(int method) { this->primitive_->value.AsResize()->method = (schema::ResizeMethod)method; }
void Resize::SetNewHeight(long new_height) { this->primitive_->value.AsResize()->newHeight = new_height; }
void Resize::SetNewWidth(long new_width) { this->primitive_->value.AsResize()->newWidth = new_width; }
void Resize::SetNewHeight(int64_t new_height) { this->primitive_->value.AsResize()->newHeight = new_height; }
void Resize::SetNewWidth(int64_t new_width) { this->primitive_->value.AsResize()->newWidth = new_width; }
void Resize::SetAlignCorners(bool align_corners) { this->primitive_->value.AsResize()->alignCorners = align_corners; }
void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) {
this->primitive_->value.AsResize()->preserveAspectRatio = preserve_aspect_ratio;
......@@ -39,15 +39,15 @@ void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) {
int Resize::GetFormat() const { return this->primitive_->value_as_Resize()->format(); }
int Resize::GetMethod() const { return this->primitive_->value_as_Resize()->method(); }
long Resize::GetNewHeight() const { return this->primitive_->value_as_Resize()->newHeight(); }
long Resize::GetNewWidth() const { return this->primitive_->value_as_Resize()->newWidth(); }
int64_t Resize::GetNewHeight() const { return this->primitive_->value_as_Resize()->newHeight(); }
int64_t Resize::GetNewWidth() const { return this->primitive_->value_as_Resize()->newWidth(); }
bool Resize::GetAlignCorners() const { return this->primitive_->value_as_Resize()->alignCorners(); }
bool Resize::GetPreserveAspectRatio() const { return this->primitive_->value_as_Resize()->preserveAspectRatio(); }
void Resize::SetFormat(int format) {}
void Resize::SetMethod(int method) {}
void Resize::SetNewHeight(long new_height) {}
void Resize::SetNewWidth(long new_width) {}
void Resize::SetNewHeight(int64_t new_height) {}
void Resize::SetNewWidth(int64_t new_width) {}
void Resize::SetAlignCorners(bool align_corners) {}
void Resize::SetPreserveAspectRatio(bool preserve_aspect_ratio) {}
#endif
......
......@@ -36,14 +36,14 @@ class Resize : public PrimitiveC {
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetFormat() const;
int GetMethod() const;
long GetNewHeight() const;
long GetNewWidth() const;
int64_t GetNewHeight() const;
int64_t GetNewWidth() const;
bool GetAlignCorners() const;
bool GetPreserveAspectRatio() const;
void SetFormat(int format);
void SetMethod(int method);
void SetNewHeight(long new_height);
void SetNewWidth(long new_width);
void SetNewHeight(int64_t new_height);
void SetNewWidth(int64_t new_width);
void SetAlignCorners(bool align_corners);
void SetPreserveAspectRatio(bool preserve_aspect_ratio);
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册