Commit e65ff9ca authored by mindspore-ci-bot, committed by Gitee

!4784 change ops getter

Merge pull request !4784 from yeyunpeng2020/master_cops_4
......@@ -14,10 +14,10 @@
* limitations under the License.
*/
#include "src/ops/constant_of_shape.h"
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
#include "src/ops/constant_of_shape.h"
namespace mindspore::lite {
namespace {
......@@ -25,9 +25,9 @@ constexpr int kShapeInputNum = 1;
constexpr int kShapeOutputNum = 1;
} // namespace
#ifdef PRIMITIVE_WRITEABLE
int ConstantOfShape::GetValue() const { return this->primitive->value.AsConstantOfShape()->Value; }
float ConstantOfShape::GetValue() const { return this->primitive->value.AsConstantOfShape()->value; }
void ConstantOfShape::SetValue(float value) { this->primitive->value.AsConstantOfShape()->Value = value; }
void ConstantOfShape::SetValue(float value) { this->primitive->value.AsConstantOfShape()->value = value; }
#else
......
......@@ -104,19 +104,18 @@ void Conv2D::SetActivationType(int activation_type) {}
#endif
void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output_w) {
MS_ASSERT(this->primitive != nullptr);
auto conv2DPrim = this->primitive->value_as_Conv2D();
int kernel_w = conv2DPrim->kernelW();
int kernel_h = conv2DPrim->kernelH();
int stride_w = conv2DPrim->strideW();
int stride_h = conv2DPrim->strideH();
int dilate_w = conv2DPrim->dilateW();
int dilate_h = conv2DPrim->dilateH();
pad_l_ = conv2DPrim->padLeft();
pad_u_ = conv2DPrim->padUp();
pad_d_ = conv2DPrim->padDown();
pad_r_ = conv2DPrim->padRight();
int kernel_w = GetKernelW();
int kernel_h = GetKernelH();
int stride_w = GetStrideW();
int stride_h = GetStrideH();
int dilate_w = GetDilateW();
int dilate_h = GetDilateH();
pad_l_ = GetPadLeft();
pad_u_ = GetPadUp();
pad_d_ = GetPadDown();
pad_r_ = GetPadRight();
if (conv2DPrim->padMode() == schema::PadMode_SAME) {
if (GetPadMode() == schema::PadMode_SAME) {
*output_w = std::ceil(static_cast<float>(input_w) / static_cast<float>(stride_w));
*output_h = std::ceil(static_cast<float>(input_h) / static_cast<float>(stride_h));
auto pad_h_all = ((*output_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - input_h);
......
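
For reference, the SAME-pad arithmetic in the hunk above can be checked in isolation. A minimal standalone sketch of the output-size and total-padding computation; all values are hypothetical, not taken from any model:

#include <cmath>
#include <cstdio>

// Mirrors the SAME-padding computation in Conv2D::ConvInferShape:
// the output size rounds up over the stride, then pad_h_all is whatever
// is needed for the dilated window to cover the input exactly.
int main() {
  int input_h = 224, kernel_h = 3, stride_h = 2, dilate_h = 1;
  int output_h = static_cast<int>(std::ceil(static_cast<float>(input_h) / static_cast<float>(stride_h)));
  int pad_h_all = (output_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - input_h;
  std::printf("output_h=%d pad_h_all=%d\n", output_h, pad_h_all);  // 112, 1
  return 0;
}

The pad total then gets split between the top and bottom pads (and likewise for width), as the Pooling hunk further down shows explicitly.
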
......@@ -23,7 +23,7 @@ int DepthToSpace::GetBlockSize() const { return this->primitive->value.AsDepthTo
int DepthToSpace::GetFormat() const { return this->primitive->value.AsDepthToSpace()->format; }
void DepthToSpace::SetBlockSize(int block_size) { this->primitive->value.AsDepthToSpace()->blockSize = block_size; }
void DepthToSpace::SetFormat(int format) { this->primitive->value.AsDepthToSpace()->format = format; }
void DepthToSpace::SetFormat(int format) { this->primitive->value.AsDepthToSpace()->format = (schema::Format)format; }
#else
......
......@@ -50,13 +50,12 @@ int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
if (!GetInferFlag()) {
return RET_OK;
}
auto fill_prim = this->primitive->value_as_Fill();
if (fill_prim == nullptr) {
MS_LOG(ERROR) << "Fill primitive is null!";
return RET_ERROR;
}
std::vector<int> output_shape;
(void)output_shape.insert(output_shape.begin(), fill_prim->dims()->begin(), fill_prim->dims()->end());
for (int i = 0; i < GetDims().size(); i++) {
output_shape.push_back(GetDims()[i]);
}
// (void)output_shape.insert(output_shape.begin(), GetDims().begin(), GetDims().end());
output->set_shape(output_shape);
return RET_OK;
}
......
......@@ -22,13 +22,13 @@ namespace lite {
bool FullConnection::GetHasBias() const { return this->primitive->value.AsFullConnection()->hasBias; }
int FullConnection::GetAxis() const { return this->primitive->value.AsFullConnection()->axis; }
bool FullConnection::GetUseAxis() const { return this->primitive->value.AsFullConnection()->useAxis; }
int FullConnection::GetActivationType() const { return this->primitive->value.AsFullConnection()->activationType(); }
int FullConnection::GetActivationType() const { return this->primitive->value.AsFullConnection()->activationType; }
void FullConnection::SetHasBias(bool has_bias) { this->primitive->value.AsFullConnection()->hasBias = has_bias; }
void FullConnection::SetAxis(int axis) { this->primitive->value.AsFullConnection()->axis = axis; }
void FullConnection::SetUseAxis(bool use_axis) { this->primitive->value.AsFullConnection()->useAxis = use_axis; }
void FullConnection::SetActivationType(int activationType) {
this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType)activationType;
this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType) activationType;
}
#else
......
......@@ -21,7 +21,9 @@ namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Mul::GetActivationType() const { return this->primitive->value.AsMul()->activationType; }
void Mul::SetActivationType(int activation_type) { this->primitive->value.AsMul()->activationType = activation_type; }
void Mul::SetActivationType(int activation_type) {
this->primitive->value.AsMul()->activationType = (schema::ActivationType) activation_type;
}
#else
......
......@@ -24,7 +24,9 @@ int Pad::GetPaddingMode() const { return this->primitive->value.AsPad()->padding
float Pad::GetConstantValue() const { return this->primitive->value.AsPad()->constantValue; }
void Pad::SetPaddings(const std::vector<int> &paddings) { this->primitive->value.AsPad()->paddings = paddings; }
void Pad::SetPaddingMode(int padding_mode) { this->primitive->value.AsPad()->paddingMode = padding_mode; }
void Pad::SetPaddingMode(int padding_mode) {
this->primitive->value.AsPad()->paddingMode = (schema::PaddingMode) padding_mode;
}
void Pad::SetConstantValue(float constant_value) { this->primitive->value.AsPad()->constantValue = constant_value; }
#else
......
......@@ -34,22 +34,22 @@ int Pooling::GetPadLeft() const { return this->primitive->value.AsPooling()->pad
int Pooling::GetPadRight() const { return this->primitive->value.AsPooling()->padRight; }
int Pooling::GetRoundMode() const { return this->primitive->value.AsPooling()->roundMode; }
void Pooling::SetFormat(int format) { this->primitive->value.AsPooling()->format = (schema::Format)format; }
void Pooling::SetFormat(int format) { this->primitive->value.AsPooling()->format = (schema::Format) format; }
void Pooling::SetPoolingMode(int pooling_mode) {
this->primitive->value.AsPooling()->poolingMode = (schema::PoolMode)pooling_mode;
this->primitive->value.AsPooling()->poolingMode = (schema::PoolMode) pooling_mode;
}
void Pooling::SetGlobal(bool global) { this->primitive->value.AsPooling()->global = global; }
void Pooling::SetWindowW(int window_w) { this->primitive->value.AsPooling()->windowW = window_w; }
void Pooling::SetWindowH(int window_h) { this->primitive->value.AsPooling()->windowH = window_h; }
void Pooling::SetStrideW(int stride_w) { this->primitive->value.AsPooling()->strideW = stride_w; }
void Pooling::SetStrideH(int stride_h) { this->primitive->value.AsPooling()->strideH = stride_h; }
void Pooling::SetPadMode(int pad_mode) { this->primitive->value.AsPooling()->padMode = (schema::PadMode)pad_mode; }
void Pooling::SetPadMode(int pad_mode) { this->primitive->value.AsPooling()->padMode = (schema::PadMode) pad_mode; }
void Pooling::SetPadUp(int pad_up) { this->primitive->value.AsPooling()->padUp = pad_up; }
void Pooling::SetPadDown(int pad_down) { this->primitive->value.AsPooling()->padDown = pad_down; }
void Pooling::SetPadLeft(int pad_left) { this->primitive->value.AsPooling()->padLeft = pad_left; }
void Pooling::SetPadRight(int pad_right) { this->primitive->value.AsPooling()->padRight = pad_right; }
void Pooling::SetRoundMode(int round_mode) {
this->primitive->value.AsPooling()->roundMode = (schema::RoundMode)round_mode;
this->primitive->value.AsPooling()->roundMode = (schema::RoundMode) round_mode;
}
#else
......@@ -82,13 +82,13 @@ void Pooling::SetPadLeft(int pad_left) {}
void Pooling::SetPadRight(int pad_right) {}
void Pooling::SetRoundMode(int round_mode) {}
#endif
int Pooling::PadUp() const { return this->pad_u_; }
int Pooling::PadDown() const { return this->pad_d_; }
int Pooling::PadLeft() const { return this->pad_l_; }
int Pooling::PadRight() const { return this->pad_r_; }
#endif
int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
......@@ -102,37 +102,37 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
}
int input_h = input->shape().at(1);
int input_w = input->shape().at(2);
auto pooling_prim = this->primitive->value_as_Pooling();
MS_ASSERT(pooling_prim != nullptr);
auto window_h = pooling_prim->windowH();
auto window_w = pooling_prim->windowW();
if (pooling_prim->global()) {
auto window_h = GetWindowH();
auto window_w = GetWindowW();
if (GetGlobal()) {
window_h = input_h;
window_w = input_w;
}
int output_h = 0;
int output_w = 0;
pad_l_ = pooling_prim->padLeft();
pad_u_ = pooling_prim->padUp();
pad_d_ = pooling_prim->padDown();
pad_r_ = pooling_prim->padRight();
if (pooling_prim->padMode() == schema::PadMode_SAME) {
output_w = std::ceil(static_cast<float>(input_w) / static_cast<float>(pooling_prim->strideW()));
output_h = std::ceil(static_cast<float>(input_h) / static_cast<float>(pooling_prim->strideH()));
auto pad_h_all = ((output_h - 1) * pooling_prim->strideH() + (window_h - 1) + 1 - input_h);
auto pad_w_all = ((output_w - 1) * pooling_prim->strideW() + (window_w - 1) + 1 - input_w);
pad_l_ = GetPadLeft();
pad_u_ = GetPadUp();
pad_d_ = GetPadDown();
pad_r_ = GetPadRight();
if (GetPadMode() == schema::PadMode_SAME) {
output_w = std::ceil(static_cast<float>(input_w) / static_cast<float>(GetStrideW()));
output_h = std::ceil(static_cast<float>(input_h) / static_cast<float>(GetStrideH()));
auto pad_h_all = ((output_h - 1) * GetStrideH() + (window_h - 1) + 1 - input_h);
auto pad_w_all = ((output_w - 1) * GetStrideW() + (window_w - 1) + 1 - input_w);
pad_u_ = pad_h_all / 2;
pad_d_ = pad_h_all - pad_u_;
pad_l_ = pad_w_all / 2;
pad_r_ = pad_w_all - pad_l_;
} else {
auto round_mode = pooling_prim->roundMode();
auto round_mode = (schema::RoundMode) GetRoundMode();
if (round_mode == schema::RoundMode_FLOOR) {
output_h = std::floor(static_cast<float>(input_h + pad_u_ + pad_d_ - window_h) / pooling_prim->strideH()) + 1;
output_w = std::floor(static_cast<float>(input_w + pad_l_ + pad_r_ - window_w) / pooling_prim->strideW()) + 1;
output_h = std::floor(static_cast<float>(input_h + pad_u_ + pad_d_ - window_h) / GetStrideH()) + 1;
output_w = std::floor(static_cast<float>(input_w + pad_l_ + pad_r_ - window_w) / GetStrideW()) + 1;
} else if (round_mode == schema::RoundMode_CEIL) {
output_h = std::ceil(static_cast<float>(input_h + pad_u_ + pad_d_ - window_h) / pooling_prim->strideH()) + 1;
output_w = std::ceil(static_cast<float>(input_w + pad_l_ + pad_r_ - window_w) / pooling_prim->strideW()) + 1;
output_h = std::ceil(static_cast<float>(input_h + pad_u_ + pad_d_ - window_h) / GetStrideH()) + 1;
output_w = std::ceil(static_cast<float>(input_w + pad_l_ + pad_r_ - window_w) / GetStrideW()) + 1;
} else {
MS_LOG(ERROR) << "unsupported round mode.";
}
......
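
The FLOOR/CEIL branch above is the only behavioral difference between the two rounding modes: FLOOR drops a final partial window, CEIL keeps it. A minimal sketch of both formulas with hypothetical values (pads taken as zero):

#include <cmath>
#include <cstdio>

int main() {
  int input_h = 7, window_h = 2, stride_h = 2;  // pads assumed 0
  int extent = input_h - window_h;  // stands in for input_h + pad_u_ + pad_d_ - window_h
  int floor_h = static_cast<int>(std::floor(static_cast<float>(extent) / stride_h)) + 1;
  int ceil_h = static_cast<int>(std::ceil(static_cast<float>(extent) / stride_h)) + 1;
  std::printf("floor=%d ceil=%d\n", floor_h, ceil_h);  // 3 4
  return 0;
}
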
......@@ -26,7 +26,7 @@ namespace lite {
int Reshape::GetFormat() const { return this->primitive->value.AsReshape()->format; }
std::vector<long> Reshape::GetShape() const { return this->primitive->value.AsReshape()->shape; }
void Reshape::SetFormat(int format) { this->primitive->value.AsReshape()->format = format; }
void Reshape::SetFormat(int format) { this->primitive->value.AsReshape()->format = (schema::Format) format; }
void Reshape::SetShape(const std::vector<long> &shape) { this->primitive->value.AsReshape()->shape = shape; }
#else
......@@ -75,7 +75,7 @@ int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector<int> *out_
}
return RET_OK;
}
template <typename T>
template<typename T>
void CalShape(const T *data, const std::vector<tensor::Tensor *> &inputs, std::vector<int> *out_shape, int shape_size) {
int input_count = inputs[0]->ElementsNum();
int index = 0;
......@@ -103,7 +103,7 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
if (!GetInferFlag()) {
return RET_OK;
}
auto reshape_prim = this->primitive->value_as_Reshape();
MS_ASSERT(reshape_prim != nullptr);
std::vector<int> out_shape;
if (inputs_.size() == kDoubleNum) {
......@@ -117,30 +117,38 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
case kNumberTypeInt8: {
auto data = reinterpret_cast<int8_t *>(shape_tensor->Data());
CalShape<int8_t>(data, inputs_, &out_shape, shape_size);
} break;
}
break;
case kNumberTypeInt32: {
auto data = reinterpret_cast<int32_t *>(shape_tensor->Data());
CalShape<int32_t>(data, inputs_, &out_shape, shape_size);
} break;
}
break;
case kNumberTypeInt64: {
auto data = reinterpret_cast<int64_t *>(shape_tensor->Data());
CalShape<int64_t>(data, inputs_, &out_shape, shape_size);
} break;
}
break;
case kNumberTypeFloat: {
auto data = reinterpret_cast<float *>(shape_tensor->Data());
CalShape<float>(data, inputs_, &out_shape, shape_size);
} break;
}
break;
case kNumberTypeUInt32: {
auto data = reinterpret_cast<uint32_t *>(shape_tensor->Data());
CalShape<uint32_t>(data, inputs_, &out_shape, shape_size);
} break;
}
break;
default: {
MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type();
return RET_INFER_ERR;
}
}
} else if (inputs_.size() == kSingleNum) {
std::copy(reshape_prim->shape()->begin(), reshape_prim->shape()->end(), std::back_inserter(out_shape));
for (int i = 0; i < GetShape().size(); ++i) {
out_shape.push_back(GetShape()[i]);
}
// std::copy(GetShape().begin(), GetShape().end(), std::back_inserter(out_shape));
} else {
MS_LOG(ERROR) << "inputs tensor size invalid.";
return RET_INFER_ERR;
......
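
The switch above exists only because the shape tensor's dtype varies by frontend; every branch does the same cast-and-collect through CalShape<T>. A reduced, hypothetical illustration of that step (the input_count/index locals in the earlier hunk suggest the real CalShape also resolves a -1 wildcard dimension, which this sketch omits):

#include <cstdint>
#include <vector>

// Hypothetical reduction of the cast-and-collect step in CalShape<T>.
template <typename T>
std::vector<int> ReadShape(const T *data, int shape_size) {
  std::vector<int> out;
  out.reserve(shape_size);
  for (int i = 0; i < shape_size; ++i) {
    out.push_back(static_cast<int>(data[i]));
  }
  return out;
}

// Usage: int64_t raw[] = {1, 224, 224, 3}; auto s = ReadShape(raw, 4);
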
......@@ -30,7 +30,7 @@ int SliceOp::GetFormat() const { return this->primitive->value.AsSlice()->format
std::vector<int> SliceOp::GetBegin() const { return this->primitive->value.AsSlice()->begin; }
std::vector<int> SliceOp::GetSize() const { return this->primitive->value.AsSlice()->size; }
void SliceOp::SetFormat(int format) { this->primitive->value.AsSlice()->format = format; }
void SliceOp::SetFormat(int format) { this->primitive->value.AsSlice()->format = (schema::Format)format; }
void SliceOp::SetBegin(const std::vector<int> &begin) { this->primitive->value.AsSlice()->begin = begin; }
void SliceOp::SetSize(const std::vector<int> &size) { this->primitive->value.AsSlice()->size = size; }
......
......@@ -24,7 +24,7 @@ int SpaceToDepth::GetBlockSize() const { return this->primitive->value.AsSpaceTo
int SpaceToDepth::GetFormat() const { return this->primitive->value.AsSpaceToDepth()->format; }
void SpaceToDepth::SetBlockSize(int block_size) { this->primitive->value.AsSpaceToDepth()->blockSize = block_size; }
void SpaceToDepth::SetFormat(int format) { this->primitive->value.AsSpaceToDepth()->format = format; }
void SpaceToDepth::SetFormat(int format) { this->primitive->value.AsSpaceToDepth()->format = (schema::Format)format; }
#else
......
......@@ -50,7 +50,6 @@ int Split::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto spilt_prim = this->primitive->value_as_Split();
MS_ASSERT(spilt_prim != nullptr);
if (inputs_.size() != kSplitInputNum) {
MS_LOG(ERROR) << "inputs number is not equal to " << kSplitInputNum;
......@@ -61,7 +60,7 @@ int Split::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
MS_LOG(ERROR) << "output null pointer dereferencing.";
return RET_ERROR;
}
int number_split = spilt_prim->numberSplit();
int number_split = GetNumberSplit();
if (static_cast<int>(outputs_.size()) != number_split) {
MS_LOG(ERROR) << "outputs number is not equal to " << number_split;
return RET_ERROR;
......@@ -73,10 +72,12 @@ int Split::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
if (!GetInferFlag()) {
return RET_OK;
}
int split_dim = spilt_prim->splitDim();
int split_dim = GetSplitDim();
std::vector<int> input_shape = input->shape();
std::vector<int> size_split;
size_split.insert(size_split.begin(), spilt_prim->sizeSplits()->begin(), spilt_prim->sizeSplits()->end());
for (int i = 0; i < GetSizeSplits().size(); ++i) {
size_split.push_back(GetSizeSplits()[i]);
}
for (int i = 0; i < number_split; ++i) {
std::vector<int> output_shape;
output_shape.insert(output_shape.begin(), input_shape.begin(), input_shape.end());
......
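
Putting the hunk's pieces together: each output of Split presumably copies the input shape and overwrites the split dimension with its own slice size (the loop body is truncated above). A hypothetical helper showing just that rule:

#include <vector>

// Each output keeps the input shape except along split_dim.
std::vector<std::vector<int>> SplitShapes(const std::vector<int> &input_shape, int split_dim,
                                          const std::vector<int> &size_split) {
  std::vector<std::vector<int>> shapes;
  for (int size : size_split) {
    std::vector<int> s = input_shape;
    s[split_dim] = size;
    shapes.push_back(s);
  }
  return shapes;
}

// Usage: SplitShapes({1, 4, 8}, 1, {1, 3}) yields {1, 1, 8} and {1, 3, 8}.
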
......@@ -24,6 +24,10 @@ std::vector<int> Tile::GetMultiples() const { return this->primitive->value.AsTi
void Tile::SetMultiples(const std::vector<int> &multiples) { this->primitive->value.AsTile()->multiples = multiples; }
std::vector<int> Tile::GetDims() const { return this->primitive->value.AsTile()->multiples; }
void Tile::SetDims(const std::vector<int> &dims) { this->primitive->value.AsTile()->dims = dims; }
#else
std::vector<int> Tile::GetMultiples() const {
......@@ -32,6 +36,13 @@ std::vector<int> Tile::GetMultiples() const {
}
void Tile::SetMultiples(const std::vector<int> &multiples) {}
std::vector<int> Tile::GetDims() const {
auto fb_vector = this->primitive->value_as_Tile()->dims();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void Tile::SetDims(const std::vector<int> &dims) {}
#endif
int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
......@@ -45,11 +56,14 @@ int Tile::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
if (!GetInferFlag()) {
return RET_OK;
}
auto tile_prim = this->primitive->value_as_Tile();
MS_ASSERT(tile_prim != nullptr);
std::vector<int> out_shape;
std::vector<int> multiples;
std::copy(tile_prim->multiples()->begin(), tile_prim->multiples()->end(), std::back_inserter(multiples));
for (int i = 0; i < GetMultiples().size(); ++i) {
multiples.push_back(GetMultiples()[i]);
}
// std::copy(GetMultiples().begin(), GetMultiples().end(), std::back_inserter(multiples));
for (size_t i = 0; i < input->shape().size(); ++i) {
int tmp = input->shape()[i] * multiples[i];
out_shape.push_back(tmp);
......
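
The loop above multiplies each input dimension by its multiple. The same rule as a hypothetical standalone helper:

#include <vector>

// Tile output rule from the hunk above: out[i] = in[i] * multiples[i].
std::vector<int> TileShape(const std::vector<int> &in_shape, const std::vector<int> &multiples) {
  std::vector<int> out;
  for (size_t i = 0; i < in_shape.size(); ++i) {
    out.push_back(in_shape[i] * multiples[i]);
  }
  return out;
}

// Usage: TileShape({2, 3}, {3, 2}) returns {6, 6}.
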
......@@ -37,6 +37,8 @@ class Tile : public PrimitiveC {
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetMultiples() const;
void SetMultiples(const std::vector<int> &multiples);
std::vector<int> GetDims() const;
void SetDims(const std::vector<int> &dims);
};
} // namespace lite
} // namespace mindspore
......
......@@ -52,14 +52,17 @@ int Transpose::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<ten
}
MS_ASSERT(inputs_.size() == kSingleNum);
MS_ASSERT(outputs_.size() == kSingleNum);
auto transpore_prim = this->primitive->value_as_Transpose();
int conjugate = transpore_prim->conjugate();
int conjugate = GetConjugate();
if (conjugate) {
MS_LOG(ERROR) << "Transpose conjugate is not support currently";
return RET_ERROR;
}
std::vector<int> perm;
perm.insert(perm.begin(), transpore_prim->perm()->begin(), transpore_prim->perm()->end());
for (int i = 0; i < GetPerm().size(); i++) {
perm.push_back(GetPerm()[i]);
}
// perm.insert(perm.begin(), GetPerm().begin(), GetPerm().end());
std::vector<int> in_shape = input->shape();
std::vector<int> out_shape;
out_shape.resize(perm.size());
......
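
The hunk resizes out_shape to perm.size(); the truncated remainder presumably fills it by the standard transpose rule, output dimension i taking the size of input dimension perm[i]. A hypothetical helper for that mapping:

#include <vector>

// Standard perm mapping: out[i] = in[perm[i]].
std::vector<int> PermuteShape(const std::vector<int> &in_shape, const std::vector<int> &perm) {
  std::vector<int> out(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    out[i] = in_shape[perm[i]];
  }
  return out;
}

// Usage: PermuteShape({1, 224, 224, 3}, {0, 3, 1, 2}) returns {1, 3, 224, 224}.
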
......@@ -988,7 +988,7 @@ OpParameter *PopulateSliceParameter(const mindspore::lite::PrimitiveC *primitive
}
slice_param->param_length_ = static_cast<int32_t>(param_begin.size());
for (int32_t i = 0; i < slice_param->param_length_; ++i) {
slice_param->begin_[i] = param_begin[1];
slice_param->begin_[i] = param_begin[i];
slice_param->size_[i] = param_size[i];
}
return reinterpret_cast<OpParameter *>(slice_param);
......
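
This last hunk is a genuine bug fix rather than a refactor: the old loop filled every begin_ entry from param_begin[1], so any slice whose begins were not all equal to the second element was populated incorrectly. A minimal reproduction with hypothetical values:

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> param_begin = {0, 5, 2};
  int wrong[3], right[3];
  for (int i = 0; i < 3; ++i) {
    wrong[i] = param_begin[1];  // old code: always 5
    right[i] = param_begin[i];  // fixed code: 0, 5, 2
  }
  std::printf("wrong={%d,%d,%d} right={%d,%d,%d}\n", wrong[0], wrong[1], wrong[2], right[0], right[1], right[2]);
  return 0;
}
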