提交 6c1eb3c2 编写于 作者: Y yeyunpeng

Add MS_DECLARE_PARENT and UnPackAttr judge null and change setter position

上级 80d570f0
......@@ -19,11 +19,6 @@
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/arithmetic_self.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif
#ifndef LITE_MINDSPORE_LITE_C_OPS_ABS_H_
#define LITE_MINDSPORE_LITE_C_OPS_ABS_H_
......@@ -33,6 +28,7 @@ namespace lite {
class Abs : public ArithmeticSelf {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Abs, ArithmeticSelf);
Abs() = default;
explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else
......
......@@ -27,7 +27,18 @@ void Activation::SetType(int type) { this->primitive_->value.AsActivation()->typ
void Activation::SetAlpha(float alpha) { this->primitive_->value.AsActivation()->alpha = alpha; }
int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Activation;
}
if (this->primitive_->value.type != schema::PrimitiveType_Activation) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
auto attr = std::make_unique<schema::ActivationT>();
if (prim.name() == "ReLU") {
attr->type = schema::ActivationType_RELU;
......@@ -36,18 +47,17 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
} else if (prim.name() == "ReLU6") {
attr->type = schema::ActivationType_RELU6;
}
this->primitive_->value.type = schema::PrimitiveType_Activation;
this->primitive_->value.value = attr.release();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
return RET_OK;
}
#else
int Activation::GetType() const { return this->primitive_->value_as_Activation()->type(); }
float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); }
void Activation::SetType(int type) {}
void Activation::SetAlpha(float alpha) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -27,16 +27,17 @@ namespace lite {
class Activation : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Activation, PrimitiveC);
Activation() = default;
explicit Activation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void SetType(int type);
void SetAlpha(float alpha);
#else
explicit Activation(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int GetType() const;
float GetAlpha() const;
void SetType(int type);
void SetAlpha(float alpha);
};
} // namespace lite
} // namespace mindspore
......
......@@ -29,7 +29,6 @@ void ActivationGrad::SetType(int type) {
int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); }
void ActivationGrad::SetType(int type) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,13 +28,14 @@ namespace lite {
class ActivationGrad : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ActivationGrad, PrimitiveC);
ActivationGrad() = default;
explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetType(int type);
#else
explicit ActivationGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int GetType() const;
void SetType(int type);
};
} // namespace lite
} // namespace mindspore
......
......@@ -36,7 +36,7 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
this->primitive_->value.type = schema::PrimitiveType_Add;
}
if (this->primitive_->value.type != schema::PrimitiveType_Add) {
MS_LOG(ERROR) << "Primitive type should be add";
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
......@@ -53,7 +53,6 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); }
void Add::SetActivationType(int activation_type) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -22,25 +22,21 @@
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/arithmetic.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif
namespace mindspore {
namespace lite {
class Add : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Add, Arithmetic);
Add() = default;
explicit Add(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void SetActivationType(int activation_type);
#else
explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {}
#endif
int GetActivationType() const;
void SetActivationType(int activation_type);
};
} // namespace lite
} // namespace mindspore
......
......@@ -27,7 +27,6 @@ void AddN::SetN(int n) { this->primitive_->value.AsAddN()->N = n; }
int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); }
void AddN::SetN(int n) {}
#endif
namespace {
......
......@@ -28,14 +28,15 @@ namespace lite {
class AddN : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(AddN, PrimitiveC);
AddN() = default;
explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetN(int n);
#else
explicit AddN(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetN() const;
void SetN(int n);
};
} // namespace lite
} // namespace mindspore
......
......@@ -39,11 +39,6 @@ int ArgMax::GetTopK() const { return this->primitive_->value_as_ArgMax()->topK()
bool ArgMax::GetKeepDims() const { return this->primitive_->value_as_ArgMax()->keepDims(); }
int ArgMax::GetAxisType() const { return this->primitive_->value_as_ArgMax()->axisType(); }
void ArgMax::SetAxis(int axis) {}
void ArgMax::SetOutMaxValue(bool out_max_value) {}
void ArgMax::SetTopK(int top_k) {}
void ArgMax::SetKeepDims(bool keep_dims) {}
void ArgMax::SetAxisType(int axis_type) {}
#endif
int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
......
......@@ -28,8 +28,14 @@ namespace lite {
class ArgMax : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ArgMax, PrimitiveC);
ArgMax() = default;
explicit ArgMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(int axis);
void SetOutMaxValue(bool out_max_value);
void SetTopK(int top_k);
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
#else
explicit ArgMax(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
......@@ -39,11 +45,6 @@ class ArgMax : public PrimitiveC {
int GetTopK() const;
bool GetKeepDims() const;
int GetAxisType() const;
void SetAxis(int axis);
void SetOutMaxValue(bool out_max_value);
void SetTopK(int top_k);
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
};
} // namespace lite
} // namespace mindspore
......
......@@ -39,11 +39,6 @@ int ArgMin::GetTopK() const { return this->primitive_->value_as_ArgMin()->topK()
bool ArgMin::GetKeepDims() const { return this->primitive_->value_as_ArgMin()->keepDims(); }
int ArgMin::GetAxisType() const { return this->primitive_->value_as_ArgMin()->axisType(); }
void ArgMin::SetAxis(int axis) {}
void ArgMin::SetOutMaxValue(bool out_max_value) {}
void ArgMin::SetTopK(int top_k) {}
void ArgMin::SetKeepDims(bool keep_dims) {}
void ArgMin::SetAxisType(int axis_type) {}
#endif
int ArgMin::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
......
......@@ -28,8 +28,14 @@ namespace lite {
class ArgMin : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ArgMin, PrimitiveC);
ArgMin() = default;
explicit ArgMin(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(int axis);
void SetOutMaxValue(bool out_max_value);
void SetTopK(int top_k);
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
#else
explicit ArgMin(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
......@@ -39,11 +45,6 @@ class ArgMin : public PrimitiveC {
int GetTopK() const;
bool GetKeepDims() const;
int GetAxisType() const;
void SetAxis(int axis);
void SetOutMaxValue(bool out_max_value);
void SetTopK(int top_k);
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
};
} // namespace lite
} // namespace mindspore
......
......@@ -28,6 +28,7 @@ namespace lite {
class Arithmetic : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Arithmetic, PrimitiveC);
Arithmetic() = default;
explicit Arithmetic(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
......
......@@ -25,6 +25,7 @@ namespace lite {
class ArithmeticSelf : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ArithmeticSelf, PrimitiveC);
ArithmeticSelf() = default;
explicit ArithmeticSelf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
......
......@@ -24,11 +24,27 @@ float BatchNorm::GetEpsilon() const { return this->primitive_->value.AsBatchNorm
void BatchNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsBatchNorm()->epsilon = epsilon; }
int BatchNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
auto attr = std::make_unique<schema::FusedBatchNormT>();
attr->epsilon = GetValue<float>(prim.GetAttr("epsilon"));
this->primitive_->value.type = schema::PrimitiveType_FusedBatchNorm;
this->primitive_->value.value = attr.release();
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_FusedBatchNorm;
}
if (this->primitive_->value.type != schema::PrimitiveType_FusedBatchNorm) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::FusedBatchNormT();
attr->epsilon = GetValue<float>(prim.GetAttr("epsilon"));
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}
......@@ -36,7 +52,6 @@ int BatchNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &
float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); }
void BatchNorm::SetEpsilon(float epsilon) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,14 +28,15 @@ namespace lite {
class BatchNorm : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BatchNorm, PrimitiveC);
BatchNorm() = default;
explicit BatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void SetEpsilon(float epsilon);
#else
explicit BatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
float GetEpsilon() const;
void SetEpsilon(float epsilon);
};
} // namespace lite
} // namespace mindspore
......
......@@ -42,8 +42,6 @@ std::vector<int> BatchToSpace::GetCrops() const {
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {}
void BatchToSpace::SetCrops(const std::vector<int> &crops) {}
#endif
namespace {
constexpr int kBatchToSpaceOutputNum = 1;
......
......@@ -28,16 +28,17 @@ namespace lite {
class BatchToSpace : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BatchToSpace, PrimitiveC);
BatchToSpace() = default;
explicit BatchToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetBlockShape(const std::vector<int> &block_shape);
void SetCrops(const std::vector<int> &crops);
#else
explicit BatchToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetBlockShape() const;
std::vector<int> GetCrops() const;
void SetBlockShape(const std::vector<int> &block_shape);
void SetCrops(const std::vector<int> &crops);
};
} // namespace lite
} // namespace mindspore
......
......@@ -25,12 +25,31 @@ std::vector<int> BiasAdd::GetAxis() const { return this->primitive_->value.AsBia
void BiasAdd::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasAdd()->axis = axis; }
int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
auto attr = std::make_unique<schema::BiasAddT>();
attr->axis = {0};
this->primitive_->value.type = schema::PrimitiveType_BiasAdd;
this->primitive_->value.value = attr.release();
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_BiasAdd;
}
if (this->primitive_->value.type != schema::PrimitiveType_BiasAdd) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::BiasAddT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
attr->axis = {0};
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
......@@ -41,7 +60,6 @@ std::vector<int> BiasAdd::GetAxis() const {
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void BiasAdd::SetAxis(const std::vector<int> &axis) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,14 +28,15 @@ namespace lite {
class BiasAdd : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BiasAdd, PrimitiveC);
BiasAdd() = default;
explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void SetAxis(const std::vector<int> &axis);
#else
explicit BiasAdd(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
std::vector<int> GetAxis() const;
void SetAxis(const std::vector<int> &axis);
};
} // namespace lite
} // namespace mindspore
......
......@@ -30,7 +30,6 @@ std::vector<int> BiasGrad::GetAxis() const {
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void BiasGrad::SetAxis(const std::vector<int> &axis) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,13 +28,15 @@ namespace lite {
class BiasGrad : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BiasGrad, PrimitiveC);
BiasGrad() = default;
explicit BiasGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(const std::vector<int> &axis);
#else
explicit BiasGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
std::vector<int> GetAxis() const;
void SetAxis(const std::vector<int> &axis);
};
} // namespace lite
} // namespace mindspore
......
......@@ -30,8 +30,6 @@ void BNGradInput::SetChannels(int channels) { this->primitive_->value.AsBNGradIn
float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); }
int BNGradInput::GetChannels() const { return this->primitive_->value_as_BNGradInput()->channels(); }
void BNGradInput::SetEps(float eps) {}
void BNGradInput::SetChannels(int channels) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,15 +28,16 @@ namespace lite {
class BNGradInput : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BNGradInput, PrimitiveC);
BNGradInput() = default;
explicit BNGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetEps(float eps);
void SetChannels(int channels);
#else
explicit BNGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
float GetEps() const;
int GetChannels() const;
void SetEps(float eps);
void SetChannels(int channels);
};
} // namespace lite
} // namespace mindspore
......
......@@ -32,7 +32,6 @@ std::vector<int> BroadcastTo::GetDstShape() const {
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void BroadcastTo::SetDstShape(const std::vector<int> &dst_shape) {}
#endif
namespace {
constexpr int kBroadcastToInputNum = 1;
......
......@@ -28,14 +28,16 @@ namespace lite {
class BroadcastTo : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BroadcastTo, PrimitiveC);
BroadcastTo() = default;
explicit BroadcastTo(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetDstShape(const std::vector<int> &dst_shape);
#else
explicit BroadcastTo(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetDstShape() const;
void SetDstShape(const std::vector<int> &dst_shape);
};
} // namespace lite
} // namespace mindspore
......
......@@ -29,7 +29,6 @@ void CaffePReLU::SetChannelShared(bool channel_shared) {
bool CaffePReLU::GetChannelShared() const { return this->primitive_->value_as_CaffePReLU()->channelShared(); }
void CaffePReLU::SetChannelShared(bool channel_shared) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,13 +28,15 @@ namespace lite {
class CaffePReLU : public Activation {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(CaffePReLU, Activation);
CaffePReLU() = default;
explicit CaffePReLU(schema::PrimitiveT *primitive) : Activation(primitive) {}
void SetChannelShared(bool channel_shared);
#else
explicit CaffePReLU(schema::Primitive *primitive) : Activation(primitive) {}
#endif
bool GetChannelShared() const;
void SetChannelShared(bool channel_shared);
};
} // namespace lite
} // namespace mindspore
......
......@@ -30,8 +30,6 @@ void Cast::SetDstT(int dst_t) { this->primitive_->value.AsCast()->dstT = dst_t;
int Cast::GetSrcT() const { return this->primitive_->value_as_Cast()->srcT(); }
int Cast::GetDstT() const { return this->primitive_->value_as_Cast()->dstT(); }
void Cast::SetSrcT(int src_t) {}
void Cast::SetDstT(int dst_t) {}
#endif
int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
......
......@@ -28,16 +28,17 @@ namespace lite {
class Cast : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Cast, PrimitiveC);
Cast() = default;
explicit Cast(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetSrcT(int src_t);
void SetDstT(int dst_t);
#else
explicit Cast(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetSrcT() const;
int GetDstT() const;
void SetSrcT(int src_t);
void SetDstT(int dst_t);
};
} // namespace lite
} // namespace mindspore
......
......@@ -20,14 +20,15 @@
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/arithmetic_self.h"
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class Ceil : public ArithmeticSelf {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Ceil, ArithmeticSelf);
Ceil() = default;
explicit Ceil(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else
......
......@@ -30,8 +30,6 @@ void Clip::SetMin(float min) { this->primitive_->value.AsClip()->min = min; }
float Clip::GetMax() const { return this->primitive_->value_as_Clip()->max(); }
float Clip::GetMin() const { return this->primitive_->value_as_Clip()->min(); }
void Clip::SetMax(float max) {}
void Clip::SetMin(float min) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,15 +28,16 @@ namespace lite {
class Clip : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Clip, PrimitiveC);
Clip() = default;
explicit Clip(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetMax(float max);
void SetMin(float min);
#else
explicit Clip(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
float GetMax() const;
float GetMin() const;
void SetMax(float max);
void SetMin(float min);
};
} // namespace lite
} // namespace mindspore
......
......@@ -30,12 +30,32 @@ void Concat::SetAxis(int axis) { this->primitive_->value.AsConcat()->axis = axis
void Concat::SetN(int n) { this->primitive_->value.AsConcat()->n = n; }
int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
auto attr = std::make_unique<schema::ConcatT>();
auto prim_axis = GetValue<int>(prim.GetAttr("axis"));
attr->axis = prim_axis;
this->primitive_->value.type = schema::PrimitiveType_Concat;
this->primitive_->value.value = attr.release();
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Concat;
}
if (this->primitive_->value.type != schema::PrimitiveType_Concat) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::ConcatT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
auto prim_axis = GetValue<int>(prim.GetAttr("axis"));
attr->axis = prim_axis;
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
......@@ -44,8 +64,6 @@ int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); }
int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); }
void Concat::SetAxis(int axis) {}
void Concat::SetN(int n) {}
#endif
namespace {
......
......@@ -28,17 +28,18 @@ namespace lite {
class Concat : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Concat, PrimitiveC);
Concat() = default;
explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void SetAxis(int axis);
void SetN(int n);
#else
explicit Concat(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;
int GetN() const;
void SetAxis(int axis);
void SetN(int n);
};
} // namespace lite
} // namespace mindspore
......
......@@ -33,7 +33,6 @@ void ConstantOfShape::SetValue(float value) { this->primitive_->value.AsConstant
float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); }
void ConstantOfShape::SetValue(float value) {}
#endif
int ConstantOfShape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
......
......@@ -28,14 +28,15 @@ namespace lite {
class ConstantOfShape : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ConstantOfShape, PrimitiveC);
ConstantOfShape() = default;
explicit ConstantOfShape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetValue(float value);
#else
explicit ConstantOfShape(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
float GetValue() const;
void SetValue(float value);
};
} // namespace lite
} // namespace mindspore
......
......@@ -19,7 +19,6 @@
#include <memory>
#include "include/errorcode.h"
#include "utils/log_adapter.h"
#include "src/ir/tensor.h"
#ifdef PRIMITIVE_WRITEABLE
#include "tools/converter/quantizer/quantize_util.h"
#endif
......@@ -309,8 +308,18 @@ void Conv2D::PopulaterQuantParam(const Primitive &prim,
}
int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Conv2D;
}
if (this->primitive_->value.type != schema::PrimitiveType_Conv2D) {
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
int group = GetValue<int>(prim.GetAttr("group"));
if (group > 1) {
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
......@@ -348,23 +357,6 @@ int Conv2D::GetDilateH() const { return this->primitive_->value_as_Conv2D()->dil
bool Conv2D::GetHasBias() const { return this->primitive_->value_as_Conv2D()->hasBias(); }
int Conv2D::GetActivationType() const { return this->primitive_->value_as_Conv2D()->activationType(); }
void Conv2D::SetFormat(int format) {}
void Conv2D::SetGroup(int group) {}
void Conv2D::SetChannelIn(int channel_in) {}
void Conv2D::SetChannelOut(int channel_out) {}
void Conv2D::SetKernelW(int kernel_w) {}
void Conv2D::SetKernelH(int kernel_h) {}
void Conv2D::SetStrideW(int stride_w) {}
void Conv2D::SetStrideH(int stride_h) {}
void Conv2D::SetPadMode(int pad_mode) {}
void Conv2D::SetPadUp(int pad_up) {}
void Conv2D::SetPadDown(int pad_down) {}
void Conv2D::SetPadLeft(int pad_left) {}
void Conv2D::SetPadRight(int pad_right) {}
void Conv2D::SetDilateW(int dilate_w) {}
void Conv2D::SetDilateH(int dilate_h) {}
void Conv2D::SetHasBias(bool has_bias) {}
void Conv2D::SetActivationType(int activation_type) {}
#endif
void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output_w) {
MS_ASSERT(this->primitive_ != nullptr);
......
......@@ -28,12 +28,30 @@ namespace mindspore {
namespace lite {
class Conv2D : public PrimitiveC {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Conv2D, PrimitiveC);
public:
Conv2D() = default;
explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
private:
void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
......@@ -72,23 +90,6 @@ class Conv2D : public PrimitiveC {
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
protected:
void ConvInferShape(int input_h, int input_w, int *output_h, int *output_w);
......
......@@ -89,23 +89,6 @@ int Conv2DGradFilter::GetActivationType() const {
return this->primitive_->value_as_Conv2DGradFilter()->activationType();
}
void Conv2DGradFilter::SetFormat(int format) {}
void Conv2DGradFilter::SetGroup(int group) {}
void Conv2DGradFilter::SetChannelIn(int channel_in) {}
void Conv2DGradFilter::SetChannelOut(int channel_out) {}
void Conv2DGradFilter::SetKernelW(int kernel_w) {}
void Conv2DGradFilter::SetKernelH(int kernel_h) {}
void Conv2DGradFilter::SetStrideW(int stride_w) {}
void Conv2DGradFilter::SetStrideH(int stride_h) {}
void Conv2DGradFilter::SetPadMode(int pad_mode) {}
void Conv2DGradFilter::SetPadUp(int pad_up) {}
void Conv2DGradFilter::SetPadDown(int pad_down) {}
void Conv2DGradFilter::SetPadLeft(int pad_left) {}
void Conv2DGradFilter::SetPadRight(int pad_right) {}
void Conv2DGradFilter::SetDilateW(int dilate_w) {}
void Conv2DGradFilter::SetDilateH(int dilate_h) {}
void Conv2DGradFilter::SetHasBias(bool has_bias) {}
void Conv2DGradFilter::SetActivationType(int activation_type) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,8 +28,26 @@ namespace lite {
class Conv2DGradFilter : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Conv2DGradFilter, PrimitiveC);
Conv2DGradFilter() = default;
explicit Conv2DGradFilter(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
#else
explicit Conv2DGradFilter(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
......@@ -50,23 +68,6 @@ class Conv2DGradFilter : public PrimitiveC {
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
};
} // namespace lite
} // namespace mindspore
......
......@@ -87,23 +87,6 @@ int Conv2DGradInput::GetActivationType() const {
return this->primitive_->value_as_Conv2DGradInput()->activationType();
}
void Conv2DGradInput::SetFormat(int format) {}
void Conv2DGradInput::SetGroup(int group) {}
void Conv2DGradInput::SetChannelIn(int channel_in) {}
void Conv2DGradInput::SetChannelOut(int channel_out) {}
void Conv2DGradInput::SetKernelW(int kernel_w) {}
void Conv2DGradInput::SetKernelH(int kernel_h) {}
void Conv2DGradInput::SetStrideW(int stride_w) {}
void Conv2DGradInput::SetStrideH(int stride_h) {}
void Conv2DGradInput::SetPadMode(int pad_mode) {}
void Conv2DGradInput::SetPadUp(int pad_up) {}
void Conv2DGradInput::SetPadDown(int pad_down) {}
void Conv2DGradInput::SetPadLeft(int pad_left) {}
void Conv2DGradInput::SetPadRight(int pad_right) {}
void Conv2DGradInput::SetDilateW(int dilate_w) {}
void Conv2DGradInput::SetDilateH(int dilate_h) {}
void Conv2DGradInput::SetHasBias(bool has_bias) {}
void Conv2DGradInput::SetActivationType(int activation_type) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,8 +28,26 @@ namespace lite {
class Conv2DGradInput : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Conv2DGradInput, PrimitiveC);
Conv2DGradInput() = default;
explicit Conv2DGradInput(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
#else
explicit Conv2DGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
......@@ -50,23 +68,6 @@ class Conv2DGradInput : public PrimitiveC {
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
};
} // namespace lite
} // namespace mindspore
......
......@@ -33,8 +33,6 @@ std::vector<int64_t> Crop::GetOffsets() const {
return std::vector<int64_t>(fb_vector->begin(), fb_vector->end());
}
void Crop::SetAxis(int64_t axis) {}
void Crop::SetOffsets(const std::vector<int64_t> &offsets) {}
#endif
namespace {
constexpr int kCropOutputNum = 1;
......
......@@ -28,16 +28,17 @@ namespace lite {
class Crop : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Crop, PrimitiveC);
Crop() = default;
explicit Crop(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(int64_t axis);
void SetOffsets(const std::vector<int64_t> &offsets);
#else
explicit Crop(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int64_t GetAxis() const;
std::vector<int64_t> GetOffsets() const;
void SetAxis(int64_t axis);
void SetOffsets(const std::vector<int64_t> &offsets);
};
} // namespace lite
} // namespace mindspore
......
......@@ -77,23 +77,6 @@ int DeConv2D::GetDilateH() const { return this->primitive_->value_as_DeConv2D()-
bool DeConv2D::GetHasBias() const { return this->primitive_->value_as_DeConv2D()->hasBias(); }
int DeConv2D::GetActivationType() const { return this->primitive_->value_as_DeConv2D()->activationType(); }
void DeConv2D::SetFormat(int format) {}
void DeConv2D::SetGroup(int group) {}
void DeConv2D::SetChannelIn(int channel_in) {}
void DeConv2D::SetChannelOut(int channel_out) {}
void DeConv2D::SetKernelW(int kernel_w) {}
void DeConv2D::SetKernelH(int kernel_h) {}
void DeConv2D::SetStrideW(int stride_w) {}
void DeConv2D::SetStrideH(int stride_h) {}
void DeConv2D::SetPadMode(int pad_mode) {}
void DeConv2D::SetPadUp(int pad_up) {}
void DeConv2D::SetPadDown(int pad_down) {}
void DeConv2D::SetPadLeft(int pad_left) {}
void DeConv2D::SetPadRight(int pad_right) {}
void DeConv2D::SetDilateW(int dilate_w) {}
void DeConv2D::SetDilateH(int dilate_h) {}
void DeConv2D::SetHasBias(bool has_bias) {}
void DeConv2D::SetActivationType(int activation_type) {}
#endif
int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
......
......@@ -28,8 +28,26 @@ namespace lite {
class DeConv2D : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(DeConv2D, PrimitiveC);
DeConv2D() = default;
explicit DeConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
#else
explicit DeConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
......@@ -51,23 +69,6 @@ class DeConv2D : public PrimitiveC {
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
void SetChannelOut(int channel_out);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
int PadUp() const { return this->pad_u_; }
int PadDown() const { return this->pad_d_; }
......
......@@ -92,22 +92,6 @@ int DeDepthwiseConv2D::GetActivationType() const {
return this->primitive_->value_as_DeDepthwiseConv2D()->activationType();
}
void DeDepthwiseConv2D::SetFormat(int format) {}
void DeDepthwiseConv2D::SetChannelIn(int channel_in) {}
void DeDepthwiseConv2D::SetChannelMultiplier(int channel_multiplier) {}
void DeDepthwiseConv2D::SetKernelW(int kernel_w) {}
void DeDepthwiseConv2D::SetKernelH(int kernel_h) {}
void DeDepthwiseConv2D::SetStrideW(int stride_w) {}
void DeDepthwiseConv2D::SetStrideH(int stride_h) {}
void DeDepthwiseConv2D::SetPadMode(int pad_mode) {}
void DeDepthwiseConv2D::SetPadUp(int pad_up) {}
void DeDepthwiseConv2D::SetPadDown(int pad_down) {}
void DeDepthwiseConv2D::SetPadLeft(int pad_left) {}
void DeDepthwiseConv2D::SetPadRight(int pad_right) {}
void DeDepthwiseConv2D::SetDilateW(int dilate_w) {}
void DeDepthwiseConv2D::SetDilateH(int dilate_h) {}
void DeDepthwiseConv2D::SetHasBias(bool has_bias) {}
void DeDepthwiseConv2D::SetActivationType(int activation_type) {}
#endif
int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
std::vector<lite::tensor::Tensor *> outputs_) {
......
......@@ -28,8 +28,25 @@ namespace lite {
class DeDepthwiseConv2D : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(DeDepthwiseConv2D, PrimitiveC);
DeDepthwiseConv2D() = default;
explicit DeDepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetFormat(int format);
void SetChannelIn(int channel_in);
void SetChannelMultiplier(int channel_multiplier);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
#else
explicit DeDepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
......@@ -50,22 +67,6 @@ class DeDepthwiseConv2D : public PrimitiveC {
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
void SetFormat(int format);
void SetChannelIn(int channel_in);
void SetChannelMultiplier(int channel_multiplier);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
int PadUp() const { return this->pad_u_; }
int PadDown() const { return this->pad_d_; }
......
......@@ -30,8 +30,6 @@ void DepthToSpace::SetFormat(int format) { this->primitive_->value.AsDepthToSpac
int DepthToSpace::GetBlockSize() const { return this->primitive_->value_as_DepthToSpace()->blockSize(); }
int DepthToSpace::GetFormat() const { return this->primitive_->value_as_DepthToSpace()->format(); }
void DepthToSpace::SetBlockSize(int block_size) {}
void DepthToSpace::SetFormat(int format) {}
#endif
namespace {
constexpr int kDepthToSpaceOutputNum = 1;
......
......@@ -28,16 +28,17 @@ namespace lite {
class DepthToSpace : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(DepthToSpace, PrimitiveC);
DepthToSpace() = default;
explicit DepthToSpace(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetBlockSize(int block_size);
void SetFormat(int format);
#else
explicit DepthToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetBlockSize() const;
int GetFormat() const;
void SetBlockSize(int block_size);
void SetFormat(int format);
};
} // namespace lite
} // namespace mindspore
......
......@@ -254,22 +254,6 @@ int DepthwiseConv2D::GetActivationType() const {
return this->primitive_->value_as_DepthwiseConv2D()->activationType();
}
void DepthwiseConv2D::SetFormat(int format) {}
void DepthwiseConv2D::SetChannelIn(int channel_in) {}
void DepthwiseConv2D::SetChannelMultiplier(int channel_multiplier) {}
void DepthwiseConv2D::SetKernelW(int kernel_w) {}
void DepthwiseConv2D::SetKernelH(int kernel_h) {}
void DepthwiseConv2D::SetStrideW(int stride_w) {}
void DepthwiseConv2D::SetStrideH(int stride_h) {}
void DepthwiseConv2D::SetPadMode(int pad_mode) {}
void DepthwiseConv2D::SetPadUp(int pad_up) {}
void DepthwiseConv2D::SetPadDown(int pad_down) {}
void DepthwiseConv2D::SetPadLeft(int pad_left) {}
void DepthwiseConv2D::SetPadRight(int pad_right) {}
void DepthwiseConv2D::SetDilateW(int dilate_w) {}
void DepthwiseConv2D::SetDilateH(int dilate_h) {}
void DepthwiseConv2D::SetHasBias(bool has_bias) {}
void DepthwiseConv2D::SetActivationType(int activation_type) {}
#endif
int DepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
std::vector<lite::tensor::Tensor *> outputs_) {
......
......@@ -27,12 +27,29 @@ namespace mindspore {
namespace lite {
class DepthwiseConv2D : public PrimitiveC {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(DepthwiseConv2D, PrimitiveC);
public:
DepthwiseConv2D() = default;
explicit DepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
void SetFormat(int format);
void SetChannelIn(int channel_in);
void SetChannelMultiplier(int channel_multiplier);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
private:
void PopulaterQuantParam(const Primitive &prim, std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
......@@ -62,22 +79,6 @@ class DepthwiseConv2D : public PrimitiveC {
int GetDilateH() const;
bool GetHasBias() const;
int GetActivationType() const;
void SetFormat(int format);
void SetChannelIn(int channel_in);
void SetChannelMultiplier(int channel_multiplier);
void SetKernelW(int kernel_w);
void SetKernelH(int kernel_h);
void SetStrideW(int stride_w);
void SetStrideH(int stride_h);
void SetPadMode(int pad_mode);
void SetPadUp(int pad_up);
void SetPadDown(int pad_down);
void SetPadLeft(int pad_left);
void SetPadRight(int pad_right);
void SetDilateW(int dilate_w);
void SetDilateH(int dilate_h);
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
int PadUp() const { return this->pad_u_; }
int PadDown() const { return this->pad_d_; }
......
......@@ -21,10 +21,30 @@ namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Dequant::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
auto attr = std::make_unique<schema::OnnxInt8DequantizeT>();
this->primitive_->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
this->primitive_->value.value = attr.release();
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_OnnxInt8Dequantize;
}
if (this->primitive_->value.type != schema::PrimitiveType_OnnxInt8Dequantize) {
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow)(schema::OnnxInt8DequantizeT);
if (attr == nullptr) {
MS_LOG(ERROR) << "attr is nullptr";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#endif
......
......@@ -25,6 +25,7 @@ namespace lite {
class Dequant : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Dequant, PrimitiveC);
Dequant() = default;
explicit Dequant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
......
......@@ -119,19 +119,6 @@ bool DetectionPostProcess::GetUseRegularNms() const {
return this->primitive_->value_as_DetectionPostProcess()->UseRegularNms();
}
void DetectionPostProcess::SetFormat(int format) {}
void DetectionPostProcess::SetInputSize(int input_size) {}
void DetectionPostProcess::SetHScale(float h_scale) {}
void DetectionPostProcess::SetWScale(float w_scale) {}
void DetectionPostProcess::SetXScale(float x_scale) {}
void DetectionPostProcess::SetYScale(float y_scale) {}
void DetectionPostProcess::SetNmsIouThreshold(float nms_iou_threshold) {}
void DetectionPostProcess::SetNmsScoreThreshold(float nms_score_threshold) {}
void DetectionPostProcess::SetMaxDetections(int64_t max_detections) {}
void DetectionPostProcess::SetDetectionsPreClass(int64_t detections_pre_class) {}
void DetectionPostProcess::SetMaxClassesPreDetection(int64_t max_classes_pre_detection) {}
void DetectionPostProcess::SetNumClasses(int64_t num_classes) {}
void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,8 +28,22 @@ namespace lite {
class DetectionPostProcess : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(DetectionPostProcess, PrimitiveC);
DetectionPostProcess() = default;
explicit DetectionPostProcess(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetFormat(int format);
void SetInputSize(int input_size);
void SetHScale(float h_scale);
void SetWScale(float w_scale);
void SetXScale(float x_scale);
void SetYScale(float y_scale);
void SetNmsIouThreshold(float nms_iou_threshold);
void SetNmsScoreThreshold(float nms_score_threshold);
void SetMaxDetections(int64_t max_detections);
void SetDetectionsPreClass(int64_t detections_pre_class);
void SetMaxClassesPreDetection(int64_t max_classes_pre_detection);
void SetNumClasses(int64_t num_classes);
void SetUseRegularNms(bool use_regular_nms);
#else
explicit DetectionPostProcess(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
......@@ -46,19 +60,6 @@ class DetectionPostProcess : public PrimitiveC {
int64_t GetMaxClassesPreDetection() const;
int64_t GetNumClasses() const;
bool GetUseRegularNms() const;
void SetFormat(int format);
void SetInputSize(int input_size);
void SetHScale(float h_scale);
void SetWScale(float w_scale);
void SetXScale(float x_scale);
void SetYScale(float y_scale);
void SetNmsIouThreshold(float nms_iou_threshold);
void SetNmsScoreThreshold(float nms_score_threshold);
void SetMaxDetections(int64_t max_detections);
void SetDetectionsPreClass(int64_t detections_pre_class);
void SetMaxClassesPreDetection(int64_t max_classes_pre_detection);
void SetNumClasses(int64_t num_classes);
void SetUseRegularNms(bool use_regular_nms);
};
} // namespace lite
} // namespace mindspore
......
......@@ -29,7 +29,6 @@ void Div::SetActivationType(int activation_type) {
int Div::GetActivationType() const { return this->primitive_->value_as_Div()->activationType(); }
void Div::SetActivationType(int activation_type) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,13 +28,15 @@ namespace lite {
class Div : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Div, Arithmetic);
Div() = default;
explicit Div(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
void SetActivationType(int activation_type);
#else
explicit Div(schema::Primitive *primitive) : Arithmetic(primitive) {}
#endif
int GetActivationType() const;
void SetActivationType(int activation_type);
};
} // namespace lite
} // namespace mindspore
......
......@@ -27,7 +27,6 @@ void Dropout::SetRatio(float ratio) { this->primitive_->value.AsDropout()->ratio
float Dropout::GetRatio() const { return this->primitive_->value_as_Dropout()->ratio(); }
void Dropout::SetRatio(float ratio) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -20,21 +20,23 @@
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#include "ir/dtype/type_id.h"
namespace mindspore {
namespace lite {
class Dropout : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Dropout, PrimitiveC);
Dropout() = default;
explicit Dropout(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetRatio(float ratio);
#else
explicit Dropout(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
float GetRatio() const;
void SetRatio(float ratio);
};
} // namespace lite
} // namespace mindspore
......
......@@ -27,7 +27,6 @@ void Eltwise::SetMode(int mode) { this->primitive_->value.AsEltwise()->mode = (s
int Eltwise::GetMode() const { return this->primitive_->value_as_Eltwise()->mode(); }
void Eltwise::SetMode(int mode) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,13 +28,15 @@ namespace lite {
class Eltwise : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Eltwise, PrimitiveC);
Eltwise() = default;
explicit Eltwise(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetMode(int mode);
#else
explicit Eltwise(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int GetMode() const;
void SetMode(int mode);
};
} // namespace lite
} // namespace mindspore
......
......@@ -27,7 +27,6 @@ void Elu::SetAlpha(float alpha) { this->primitive_->value.AsElu()->alpha = alpha
float Elu::GetAlpha() const { return this->primitive_->value_as_Elu()->alpha(); }
void Elu::SetAlpha(float alpha) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,13 +28,15 @@ namespace lite {
class Elu : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Elu, PrimitiveC);
Elu() = default;
explicit Elu(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAlpha(float alpha);
#else
explicit Elu(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
float GetAlpha() const;
void SetAlpha(float alpha);
};
} // namespace lite
} // namespace mindspore
......
......@@ -27,7 +27,6 @@ void EmbeddingLookup::SetMaxNorm(float max_norm) { this->primitive_->value.AsEmb
float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value_as_EmbeddingLookup()->maxNorm(); }
void EmbeddingLookup::SetMaxNorm(float max_norm) {}
#endif
int EmbeddingLookup::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
......
......@@ -28,14 +28,16 @@ namespace lite {
class EmbeddingLookup : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(EmbeddingLookup, PrimitiveC);
EmbeddingLookup() = default;
explicit EmbeddingLookup(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetMaxNorm(float max_norm);
#else
explicit EmbeddingLookup(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
float GetMaxNorm() const;
void SetMaxNorm(float max_norm);
};
} // namespace lite
} // namespace mindspore
......
......@@ -51,9 +51,6 @@ float EmbeddingLookupSparse::GetMaxNortm() const {
return this->primitive_->value_as_EmbeddingLookupSparse()->maxNortm();
}
void EmbeddingLookupSparse::SetSpIds(const std::vector<int> &sp_ids) {}
void EmbeddingLookupSparse::SetSpWeights(const std::vector<float> &sp_weights) {}
void EmbeddingLookupSparse::SetMaxNortm(float max_nortm) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,17 +28,18 @@ namespace lite {
class EmbeddingLookupSparse : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(EmbeddingLookupSparse, PrimitiveC);
EmbeddingLookupSparse() = default;
explicit EmbeddingLookupSparse(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetSpIds(const std::vector<int> &sp_ids);
void SetSpWeights(const std::vector<float> &sp_weights);
void SetMaxNortm(float max_nortm);
#else
explicit EmbeddingLookupSparse(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
std::vector<int> GetSpIds() const;
std::vector<float> GetSpWeights() const;
float GetMaxNortm() const;
void SetSpIds(const std::vector<int> &sp_ids);
void SetSpWeights(const std::vector<float> &sp_weights);
void SetMaxNortm(float max_nortm);
};
} // namespace lite
} // namespace mindspore
......
......@@ -28,6 +28,7 @@ namespace lite {
class Equal : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Equal, PrimitiveC);
Equal() = default;
explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else
......
......@@ -28,6 +28,7 @@ namespace lite {
class Exp : public ArithmeticSelf {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Exp, ArithmeticSelf);
Exp() = default;
explicit Exp(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else
......
......@@ -27,7 +27,6 @@ void ExpandDims::SetDim(int dim) { this->primitive_->value.AsExpandDims()->dim =
int ExpandDims::GetDim() const { return this->primitive_->value_as_ExpandDims()->dim(); }
void ExpandDims::SetDim(int dim) {}
#endif
int ExpandDims::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
......
......@@ -28,14 +28,16 @@ namespace lite {
class ExpandDims : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ExpandDims, PrimitiveC);
ExpandDims() = default;
explicit ExpandDims(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetDim(int dim);
#else
explicit ExpandDims(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetDim() const;
void SetDim(int dim);
};
} // namespace lite
} // namespace mindspore
......
......@@ -40,8 +40,6 @@ int FakeQuantWithMinMaxVars::GetNumBits() const {
return this->primitive_->value_as_FakeQuantWithMinMaxVars()->numBits();
}
void FakeQuantWithMinMaxVars::SetNarrowRange(bool narrow_range) {}
void FakeQuantWithMinMaxVars::SetNumBits(int num_bits) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,15 +28,16 @@ namespace lite {
class FakeQuantWithMinMaxVars : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(FakeQuantWithMinMaxVars, PrimitiveC);
FakeQuantWithMinMaxVars() = default;
explicit FakeQuantWithMinMaxVars(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetNarrowRange(bool narrow_range);
void SetNumBits(int num_bits);
#else
explicit FakeQuantWithMinMaxVars(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
bool GetNarrowRange() const;
int GetNumBits() const;
void SetNarrowRange(bool narrow_range);
void SetNumBits(int num_bits);
};
} // namespace lite
} // namespace mindspore
......
......@@ -30,7 +30,6 @@ std::vector<int> Fill::GetDims() const {
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
void Fill::SetDims(const std::vector<int> &dims) {}
#endif
int Fill::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
......
......@@ -28,14 +28,16 @@ namespace lite {
class Fill : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Fill, PrimitiveC);
Fill() = default;
explicit Fill(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetDims(const std::vector<int> &dims);
#else
explicit Fill(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetDims() const;
void SetDims(const std::vector<int> &dims);
};
} // namespace lite
} // namespace mindspore
......
......@@ -51,10 +51,30 @@ int Flatten::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
}
#ifdef PRIMITIVE_WRITEABLE
int Flatten::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
this->primitive_ = new (schema::PrimitiveT);
auto attr = std::make_unique<schema::FlattenT>();
this->primitive_->value.type = schema::PrimitiveType_Flatten;
this->primitive_->value.value = attr.release();
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Flatten;
}
if (this->primitive_->value.type != schema::PrimitiveType_Flatten) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::FlattenT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#endif
......
......@@ -28,6 +28,7 @@ namespace lite {
class Flatten : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Flatten, PrimitiveC);
Flatten() = default;
explicit Flatten(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
......
......@@ -28,6 +28,7 @@ namespace lite {
class Floor : public ArithmeticSelf {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Floor, ArithmeticSelf);
Floor() = default;
explicit Floor(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else
......
......@@ -28,6 +28,7 @@ namespace lite {
class FloorDiv : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(FloorDiv, Arithmetic);
FloorDiv() = default;
explicit FloorDiv(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else
......
......@@ -28,6 +28,7 @@ namespace lite {
class FloorMod : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(FloorMod, Arithmetic);
FloorMod() = default;
explicit FloorMod(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else
......
......@@ -37,10 +37,6 @@ int FullConnection::GetAxis() const { return this->primitive_->value_as_FullConn
bool FullConnection::GetUseAxis() const { return this->primitive_->value_as_FullConnection()->useAxis(); }
int FullConnection::GetActivationType() const { return this->primitive_->value_as_FullConnection()->activationType(); }
void FullConnection::SetHasBias(bool has_bias) {}
void FullConnection::SetAxis(int axis) {}
void FullConnection::SetUseAxis(bool use_axis) {}
void FullConnection::SetActivationType(int activationType) {}
#endif
int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
std::vector<lite::tensor::Tensor *> outputs_) {
......
......@@ -28,8 +28,13 @@ namespace lite {
class FullConnection : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(FullConnection, PrimitiveC);
FullConnection() = default;
explicit FullConnection(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetHasBias(bool has_bias);
void SetAxis(int axis);
void SetUseAxis(bool use_axis);
void SetActivationType(int activationType);
#else
explicit FullConnection(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
......@@ -38,10 +43,6 @@ class FullConnection : public PrimitiveC {
int GetAxis() const;
bool GetUseAxis() const;
int GetActivationType() const;
void SetHasBias(bool has_bias);
void SetAxis(int axis);
void SetUseAxis(bool use_axis);
void SetActivationType(int activationType);
};
} // namespace lite
} // namespace mindspore
......
......@@ -33,9 +33,6 @@ float FusedBatchNorm::GetEpsilon() const { return this->primitive_->value_as_Fus
float FusedBatchNorm::GetMomentum() const { return this->primitive_->value_as_FusedBatchNorm()->momentum(); }
int FusedBatchNorm::GetSpatial() const { return this->primitive_->value_as_FusedBatchNorm()->spatial(); }
void FusedBatchNorm::SetEpsilon(float epsilon) {}
void FusedBatchNorm::SetMomentum(float momentum) {}
void FusedBatchNorm::SetSpatial(int spatial) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,17 +28,18 @@ namespace lite {
class FusedBatchNorm : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(FusedBatchNorm, PrimitiveC);
FusedBatchNorm() = default;
explicit FusedBatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetEpsilon(float epsilon);
void SetMomentum(float momentum);
void SetSpatial(int spatial);
#else
explicit FusedBatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
float GetEpsilon() const;
float GetMomentum() const;
int GetSpatial() const;
void SetEpsilon(float epsilon);
void SetMomentum(float momentum);
void SetSpatial(int spatial);
};
} // namespace lite
} // namespace mindspore
......
......@@ -33,8 +33,6 @@ void Gather::SetBatchDims(int batch_dims) { this->primitive_->value.AsGather()->
int Gather::GetAxis() const { return this->primitive_->value_as_Gather()->axis(); }
int Gather::GetBatchDims() const { return this->primitive_->value_as_Gather()->batchDims(); }
void Gather::SetAxis(int axis) {}
void Gather::SetBatchDims(int batch_dims) {}
#endif
int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
......
......@@ -28,16 +28,17 @@ namespace lite {
class Gather : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Gather, PrimitiveC);
Gather() = default;
explicit Gather(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(int axis);
void SetBatchDims(int batch_dims);
#else
explicit Gather(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;
int GetBatchDims() const;
void SetAxis(int axis);
void SetBatchDims(int batch_dims);
};
} // namespace lite
} // namespace mindspore
......
......@@ -27,7 +27,6 @@ void GatherNd::SetBatchDims(int batch_dims) { this->primitive_->value.AsGatherNd
int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); }
void GatherNd::SetBatchDims(int batch_dims) {}
#endif
int GatherNd::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
......
......@@ -28,14 +28,16 @@ namespace lite {
class GatherNd : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(GatherNd, PrimitiveC);
GatherNd() = default;
explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetBatchDims(int batch_dims);
#else
explicit GatherNd(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetBatchDims() const;
void SetBatchDims(int batch_dims);
};
} // namespace lite
} // namespace mindspore
......
......@@ -27,6 +27,7 @@ namespace lite {
class Greater : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Greater, Arithmetic);
Greater() = default;
explicit Greater(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else
......
......@@ -28,6 +28,7 @@ namespace lite {
class GreaterEqual : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(GreaterEqual, Arithmetic);
GreaterEqual() = default;
explicit GreaterEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else
......
......@@ -33,8 +33,6 @@ std::vector<int> L2Norm::GetAxis() const {
}
float L2Norm::GetEpsilon() const { return this->primitive_->value_as_L2Norm()->epsilon(); }
void L2Norm::SetAxis(const std::vector<int> &axis) {}
void L2Norm::SetEpsilon(float epsilon) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,15 +28,16 @@ namespace lite {
class L2Norm : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(L2Norm, PrimitiveC);
L2Norm() = default;
explicit L2Norm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetAxis(const std::vector<int> &axis);
void SetEpsilon(float epsilon);
#else
explicit L2Norm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
std::vector<int> GetAxis() const;
float GetEpsilon() const;
void SetAxis(const std::vector<int> &axis);
void SetEpsilon(float epsilon);
};
} // namespace lite
} // namespace mindspore
......
......@@ -29,7 +29,6 @@ void LeakyReLU::SetNegativeSlope(float negative_slope) {
float LeakyReLU::GetNegativeSlope() const { return this->primitive_->value_as_LeakyReLU()->negativeSlope(); }
void LeakyReLU::SetNegativeSlope(float negative_slope) {}
#endif
} // namespace lite
} // namespace mindspore
......@@ -28,13 +28,15 @@ namespace lite {
class LeakyReLU : public PrimitiveC {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(LeakyReLU, PrimitiveC);
LeakyReLU() = default;
explicit LeakyReLU(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetNegativeSlope(float negative_slope);
#else
explicit LeakyReLU(schema::Primitive *primitive) : PrimitiveC(primitive) {}
#endif
float GetNegativeSlope() const;
void SetNegativeSlope(float negative_slope);
};
} // namespace lite
} // namespace mindspore
......
......@@ -28,6 +28,7 @@ namespace lite {
class Less : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Less, Arithmetic);
Less() = default;
explicit Less(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else
......
......@@ -28,6 +28,7 @@ namespace lite {
class LessEqual : public Arithmetic {
public:
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(LessEqual, Arithmetic);
LessEqual() = default;
explicit LessEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else
......
......@@ -60,10 +60,6 @@ float LocalResponseNormalization::GetBeta() const {
return this->primitive_->value_as_LocalResponseNormalization()->beta();
}
void LocalResponseNormalization::SetDepthRadius(int depth_radius) {}
void LocalResponseNormalization::SetBias(float bias) {}
void LocalResponseNormalization::SetAlpha(float alpha) {}
void LocalResponseNormalization::SetBeta(float beta) {}
#endif
} // namespace lite
} // namespace mindspore
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册