提交 19425363 编写于 作者: H hangq

free meta_graph after compile graph

上级 9d8fb786
......@@ -61,5 +61,22 @@ std::vector<size_t> GetGraphOutputNodes(const schema::MetaGraph *meta_graph) {
}
return ret;
}
std::vector<size_t> GetLinkedPostNodeIdx(const schema::MetaGraph &graph, const size_t &tensor_idx) {
std::vector<size_t> post_node_idxes;
for (size_t i = 0; i < graph.nodes()->size(); i++) {
auto node = graph.nodes()->GetAs<schema::CNode>(i);
if (node == nullptr) {
continue;
}
auto node_input_idxes = node->inputIndex();
auto is_contain = std::any_of(node_input_idxes->begin(), node_input_idxes->end(),
[&](const uint32_t &node_input_idx) { return node_input_idx == tensor_idx; });
if (is_contain) {
post_node_idxes.emplace_back(i);
}
}
return post_node_idxes;
}
} // namespace lite
} // namespace mindspore
......@@ -34,215 +34,8 @@ std::vector<size_t> GetGraphInputNodes(const schema::MetaGraph *meta_graph);
std::vector<size_t> GetGraphOutputNodes(const schema::MetaGraph *meta_graph);
class OpNode {
public:
explicit OpNode(const NODE_ID &nodeId) : id(nodeId) {}
NODE_ID ID() { return id; };
void AddInEdge(NODE_ID nodeId) { inEdges.insert(nodeId); }
void AddOutEdge(NODE_ID nodeId) { outEdges.insert(nodeId); }
std::unordered_set<NODE_ID> GetAllInEdges() { return inEdges; }
std::unordered_set<NODE_ID> GetAllOutEdges() { return outEdges; }
protected:
NODE_ID id;
std::unordered_set<NODE_ID> inEdges;
std::unordered_set<NODE_ID> outEdges;
};
template <typename NODE_T>
class OpGraph {
public:
OpGraph() {}
~OpGraph();
int Build(const schema::MetaGraph *subGraphDef);
NODE_T *GetNode(NODE_ID nodeId);
NODE_T *AddNode(NODE_ID nodeId);
std::unordered_set<NODE_T *> GetInputNode();
std::unordered_set<NODE_T *> GetOutputNode();
void AddNodes(std::vector<NODE_T *> addNodes);
void DeleteNodes(std::vector<NODE_T *> deleteNodes);
void AddEdge(NODE_ID nodeId);
int AddEdge(NODE_ID srcId, NODE_ID dstId);
int AddEdge(const schema::CNode *srcNodeDef, const flatbuffers::Vector<flatbuffers::Offset<schema::CNode>> *opDefs);
std::unordered_map<NODE_T *, std::unordered_set<NODE_T *>> GetDepends();
protected:
std::unordered_map<NODE_ID, NODE_T *> nodes;
};
template <typename NODE_T>
int OpGraph<NODE_T>::Build(const schema::MetaGraph *subGraphDef) {
if (subGraphDef == nullptr) {
// MS_LOGE("subGraphDef is nullptr");
return RET_ERROR;
}
auto opDefs = subGraphDef->nodes();
uint32_t opCount = opDefs->size();
for (uint32_t i = 0; i < opCount; i++) {
auto opDef = opDefs->GetAs<schema::CNode>(i);
auto node = AddNode(std::string(opDef->name()->c_str()));
if (node == nullptr) {
// MS_LOGE("add srcNode failed,name %s", opDef->name()->c_str());
return RET_ERROR;
}
auto ret = AddEdge(opDef, opDefs);
if (ret != RET_OK) {
// MS_LOGE("%s add edge failed. ret:%d", opDef->name()->c_str(), ret);
return RET_ERROR;
}
}
return RET_OK;
}
template <typename NODE_T>
int OpGraph<NODE_T>::AddEdge(const schema::CNode *srcNodeDef,
const flatbuffers::Vector<flatbuffers::Offset<schema::CNode>> *nodeDefs) {
MS_ASSERT(srcNodeDef != nullptr);
MS_ASSERT(nodeDefs != nullptr);
NODE_ID srcId = std::string(srcNodeDef->name()->c_str());
uint32_t opCount = nodeDefs->size();
// for single op condition
AddNode(srcId);
for (auto index : *(srcNodeDef->outputIndex())) {
for (uint32_t i = 0; i < opCount; i++) {
auto dstNodeDef = nodeDefs->GetAs<schema::CNode>(i);
bool find = false;
auto inputIndex = dstNodeDef->inputIndex();
if (std::any_of(inputIndex->begin(), inputIndex->end(), [&index](int i) { return i == index; })) {
find = true;
}
if (!find) {
continue;
}
NODE_ID dstId = std::string(dstNodeDef->name()->c_str());
auto ret = AddEdge(srcId, dstId);
if (ret != RET_OK) {
return ret;
}
}
}
return RET_OK;
}
template <typename NODE_T>
int OpGraph<NODE_T>::AddEdge(NODE_ID srcId, NODE_ID dstId) {
auto srcNode = AddNode(srcId);
if (srcNode == nullptr) {
// MS_LOGE("add srcNode failed");
return RET_ERROR;
}
auto dstNode = AddNode(dstId);
if (dstNode == nullptr) {
// MS_LOGE("add dstNode failed");
return RET_ERROR;
}
srcNode->AddOutEdge(dstNode);
dstNode->AddInEdge(srcNode);
return RET_OK;
}
template <typename NODE_T>
NODE_T *OpGraph<NODE_T>::GetNode(NODE_ID nodeId) {
auto node = nodes.find(nodeId);
if (node == nodes.end()) {
return nullptr;
}
return node->second;
}
template <typename NODE_T>
NODE_T *OpGraph<NODE_T>::AddNode(NODE_ID nodeId) {
auto node = GetNode(nodeId);
if (node != nullptr) {
return node;
}
node = new (std::nothrow) NODE_T(nodeId);
if (node == nullptr) {
// MS_LOGE("new node failed");
return nullptr;
}
nodes[nodeId] = node;
return node;
}
template <typename NODE_T>
void OpGraph<NODE_T>::AddNodes(std::vector<NODE_T *> addNodes) {
for (auto node : addNodes) {
if (node == nullptr) {
return;
}
nodes[node->ID()] = node;
}
}
template <typename NODE_T>
void OpGraph<NODE_T>::DeleteNodes(std::vector<NODE_T *> deleteNodes) {
for (auto deletenode : deleteNodes) {
if (deletenode == nullptr) {
continue;
}
auto node = GetNode(deletenode->ID());
if (node == nullptr) {
continue;
}
nodes.erase(deletenode->ID());
}
}
template <typename NODE_T>
std::unordered_set<NODE_T *> OpGraph<NODE_T>::GetInputNode() {
std::unordered_set<NODE_T *> inputNodes;
for (const auto &iter : nodes) {
auto node = iter.second;
if (node->GetAllInEdges().empty()) {
inputNodes.insert(node);
}
}
return inputNodes;
}
template <typename NODE_T>
std::unordered_set<NODE_T *> OpGraph<NODE_T>::GetOutputNode() {
std::unordered_set<NODE_T *> outputNodes;
for (const auto &iter : nodes) {
auto node = iter.second;
if (node->GetAllOutEdges().empty()) {
outputNodes.insert(node);
}
}
return outputNodes;
}
template <typename NODE_T>
std::unordered_map<NODE_T *, std::unordered_set<NODE_T *>> OpGraph<NODE_T>::GetDepends() {
std::unordered_map<NODE_T *, std::unordered_set<NODE_T *>> depends;
for (auto nodeIter : nodes) {
depends[nodeIter.second] = nodeIter.second->GetAllInEdges();
}
return depends;
}
template <typename NODE_T>
OpGraph<NODE_T>::~OpGraph() {
for (auto iter : nodes) {
delete iter.second;
}
nodes.clear();
}
std::vector<size_t> GetLinkedPostNodeIdx(const schema::MetaGraph &graph, const size_t &tensor_idx);
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_COMMON_GRAPH_UTIL_H_
......@@ -32,10 +32,29 @@
namespace mindspore {
namespace lite {
static std::vector<schema::PrimitiveType> packed_op = {
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_MatMul};
// this method will not check whether tensor_idx is a weight tensor index, caller should ensure this.
static bool WeightTensorNeedCopy(const lite::Model *model, const uint32_t tensor_idx) {
MS_ASSERT(nullptr != model);
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(nullptr != meta_graph);
auto post_node_idxes = GetLinkedPostNodeIdx(*meta_graph, tensor_idx);
return std::none_of(post_node_idxes.begin(), post_node_idxes.end(), [&](const size_t &post_node_idx) {
auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(post_node_idx);
MS_ASSERT(cNode != nullptr);
return IsContain(packed_op, cNode->primitive()->value_type());
});
}
int LiteSession::ConvertTensors(const lite::Model *model) {
MS_ASSERT(nullptr != model);
auto meta_graph = model->GetMetaGraph();
MS_ASSERT(nullptr != meta_graph);
copyed_tensor_idxes_.clear();
uint32_t tensorCount = meta_graph->allTensors()->size();
for (uint32_t i = 0; i < tensorCount; i++) {
auto *srcTensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i);
......@@ -54,16 +73,30 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
}
}
int dataType = srcTensor->dataType();
auto *dstTensor = new tensor::Tensor(TypeId(dataType), shape, srcTensor->format(), srcTensor->nodeType());
auto *dstTensor =
new (std::nothrow) tensor::Tensor(TypeId(dataType), shape, srcTensor->format(), srcTensor->nodeType());
if (dstTensor == nullptr) {
MS_LOG(ERROR) << "new " << i << "th tensor failed";
return RET_NULL_PTR;
}
if (srcTensor->nodeType() == schema::NodeType_ValueNode && srcTensor->data() != nullptr &&
srcTensor->data()->size() > 0) {
if (shape.empty()) {
shape.push_back(1);
dstTensor->set_shape(shape);
}
MS_ASSERT(dstTensor != nullptr);
MS_ASSERT(dstTensor->Size() == srcTensor->data()->size());
// no copy data, do copy when call LiteKernel::Init
dstTensor->SetData(const_cast<unsigned char *>(srcTensor->data()->data()));
if (WeightTensorNeedCopy(model, i)) {
auto ret = dstTensor->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Malloc data for " << i << "th tensor failed";
return RET_ERROR;
}
memcpy(dstTensor->Data(), srcTensor->data()->data(), dstTensor->Size());
copyed_tensor_idxes_.emplace_back(i);
} else {
dstTensor->SetData(const_cast<unsigned char *>(srcTensor->data()->data()));
}
}
auto quant_params = srcTensor->quantParams();
if (quant_params != nullptr) {
......@@ -74,7 +107,6 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
dstTensor->AddQuantParam(quant_arg);
}
}
this->tensors_.emplace_back(dstTensor);
}
......@@ -240,6 +272,7 @@ int LiteSession::CompileGraph(Model *model) {
}
executor->Prepare(this->kernels_);
model->FreeMetaGraph();
return RET_OK;
}
......@@ -277,7 +310,10 @@ int LiteSession::Init(Context *context) {
}
#endif
executor = new Executor();
MS_ASSERT(nullptr != executor);
if (nullptr == executor) {
MS_LOG(ERROR) << "new Executor failed";
return RET_ERROR;
}
return RET_OK;
}
......@@ -288,9 +324,12 @@ void LiteSession::BindThread(bool if_bind) {
}
LiteSession::~LiteSession() {
for (auto *tensor : tensors_) {
// weight data can not be to free, we will free weight data when freeing meta_graph
if (tensor->TensorType() == schema::NodeType_ValueNode && !IsContain(this->inputs_, tensor)) {
for (size_t i = 0; i < tensors_.size(); i++) {
auto *tensor = tensors_.at(i);
MS_ASSERT(tensor != nullptr);
// data of weight tensor of node in packed_op can not be to free, we will free weight data when freeing meta_graph
if (tensor->TensorType() == schema::NodeType_ValueNode && !IsContain(this->inputs_, tensor) &&
!IsContain(copyed_tensor_idxes_, i)) {
tensor->SetData(nullptr);
}
delete tensor;
......
......@@ -87,6 +87,7 @@ class LiteSession : public session::LiteSession {
Context *context_ = nullptr;
std::vector<kernel::LiteKernel *> kernels_;
std::vector<tensor::Tensor *> tensors_;
std::vector<size_t> copyed_tensor_idxes_;
// graph input tensors
std::vector<tensor::Tensor *> inputs_;
// graph output tensors
......
......@@ -135,7 +135,7 @@ mindspore::lite::PrimitiveC *Model::GetOp(const std::string &name) const {
void Model::FreeMetaGraph() {
MS_ASSERT(nullptr != model_impl_);
return model_impl_->FreeMetaGraph();
model_impl_->FreeMetaGraph();
}
const schema::MetaGraph *Model::GetMetaGraph() const {
......
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/abs.h"
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
int Abs::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateAbs(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Abs, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore
......@@ -32,27 +32,9 @@ class Abs : public ArithmeticSelf {
Abs() = default;
explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else
explicit Abs(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto val_offset = schema::CreateAbs(fbb);
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Abs, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Abs() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
};
} // namespace lite
......
......@@ -55,7 +55,19 @@ int Activation::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr>
return RET_OK;
}
#else
int Activation::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Activation();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Activation return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateActivation(*fbb, attr->type(), attr->alpha());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Activation, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int Activation::GetType() const { return this->primitive_->value_as_Activation()->type(); }
float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); }
#endif
......
......@@ -30,34 +30,13 @@ class Activation : public PrimitiveC {
MS_DECLARE_PARENT(Activation, PrimitiveC);
Activation() = default;
explicit Activation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetType(int type);
void SetAlpha(float alpha);
#else
explicit Activation(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Activation();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateActivation(fbb, attr->type(), attr->alpha());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Activation, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Activation() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int GetType() const;
float GetAlpha() const;
......
......@@ -26,7 +26,19 @@ void ActivationGrad::SetType(int type) {
}
#else
int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_ActivationGrad();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_ActivationGrad return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateActivationGrad(*fbb, attr->type());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ActivationGrad, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); }
#endif
......
......@@ -33,30 +33,9 @@ class ActivationGrad : public PrimitiveC {
explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetType(int type);
#else
explicit ActivationGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_ActivationGrad();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateActivationGrad(fbb, attr->type());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ActivationGrad, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
ActivationGrad() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int GetType() const;
};
......
......@@ -50,7 +50,19 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs
}
#else
int Add::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Add();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Add return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateAdd(*fbb, attr->activationType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Add, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); }
#endif
......
......@@ -31,33 +31,12 @@ class Add : public Arithmetic {
MS_DECLARE_PARENT(Add, Arithmetic);
Add() = default;
explicit Add(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetActivationType(int activation_type);
#else
explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Add();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateAdd(fbb, attr->activationType());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Add, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Add() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int GetActivationType() const;
};
......
......@@ -24,7 +24,19 @@ int AddN::GetN() const { return this->primitive_->value.AsAddN()->N; }
void AddN::SetN(int n) { this->primitive_->value.AsAddN()->N = n; }
#else
int AddN::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_AddN();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_AddN return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateAddN(*fbb, attr->N());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_AddN, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); }
#endif
......
......@@ -33,30 +33,9 @@ class AddN : public PrimitiveC {
explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetN(int n);
#else
explicit AddN(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_AddN();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateAddN(fbb, attr->N());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_AddN, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
AddN() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetN() const;
......
......@@ -32,7 +32,20 @@ void ArgMax::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMax()->k
void ArgMax::SetAxisType(int axis_type) { this->primitive_->value.AsArgMax()->axisType = axis_type; }
#else
int ArgMax::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_ArgMax();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_ArgMax return nullptr";
return RET_ERROR;
}
auto val_offset =
schema::CreateArgMax(*fbb, attr->axis(), attr->outMaxValue(), attr->topK(), attr->keepDims(), attr->axisType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ArgMax, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int ArgMax::GetAxis() const { return this->primitive_->value_as_ArgMax()->axis(); }
bool ArgMax::GetOutMaxValue() const { return this->primitive_->value_as_ArgMax()->outMaxValue(); }
int ArgMax::GetTopK() const { return this->primitive_->value_as_ArgMax()->topK(); }
......
......@@ -37,31 +37,9 @@ class ArgMax : public PrimitiveC {
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
#else
explicit ArgMax(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_ArgMax();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateArgMax(fbb, attr->axis(), attr->outMaxValue(),
attr->topK(), attr->keepDims(), attr->axisType());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ArgMax, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
ArgMax() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;
......
......@@ -32,7 +32,20 @@ void ArgMin::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMin()->k
void ArgMin::SetAxisType(int axis_type) { this->primitive_->value.AsArgMin()->axisType = axis_type; }
#else
int ArgMin::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_ArgMin();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_ArgMin return nullptr";
return RET_ERROR;
}
auto val_offset =
schema::CreateArgMin(*fbb, attr->axis(), attr->outMaxValue(), attr->topK(), attr->keepDims(), attr->axisType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ArgMin, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int ArgMin::GetAxis() const { return this->primitive_->value_as_ArgMin()->axis(); }
bool ArgMin::GetOutMaxValue() const { return this->primitive_->value_as_ArgMin()->outMaxValue(); }
int ArgMin::GetTopK() const { return this->primitive_->value_as_ArgMin()->topK(); }
......
......@@ -37,31 +37,9 @@ class ArgMin : public PrimitiveC {
void SetKeepDims(bool keep_dims);
void SetAxisType(int axis_type);
#else
explicit ArgMin(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_ArgMin();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateArgMin(fbb, attr->axis(), attr->outMaxValue(),
attr->topK(), attr->keepDims(), attr->axisType());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ArgMin, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
ArgMin() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;
......
......@@ -32,7 +32,11 @@ class Arithmetic : public PrimitiveC {
Arithmetic() = default;
explicit Arithmetic(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {}
// explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {}
Arithmetic() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
return RET_ERROR;
}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
bool Broadcasting() { return this->broadcasting_; }
......
......@@ -29,7 +29,11 @@ class ArithmeticSelf : public PrimitiveC {
ArithmeticSelf() = default;
explicit ArithmeticSelf(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
explicit ArithmeticSelf(schema::Primitive *primitive) : PrimitiveC(primitive) {}
// explicit ArithmeticSelf(schema::Primitive *primitive) : PrimitiveC(primitive) {}
ArithmeticSelf() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
return RET_ERROR;
}
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
};
......
......@@ -49,7 +49,14 @@ int BatchNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &
}
#else
int BatchNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateBatchNorm(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchNorm, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); }
#endif
......
......@@ -31,30 +31,12 @@ class BatchNorm : public PrimitiveC {
MS_DECLARE_PARENT(BatchNorm, PrimitiveC);
BatchNorm() = default;
explicit BatchNorm(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetEpsilon(float epsilon);
#else
explicit BatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto val_offset = schema::CreateBatchNorm(fbb);
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BatchNorm, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
BatchNorm() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
float GetEpsilon() const;
};
......
......@@ -32,7 +32,31 @@ void BatchToSpace::SetBlockShape(const std::vector<int> &block_shape) {
void BatchToSpace::SetCrops(const std::vector<int> &crops) { this->primitive_->value.AsBatchToSpace()->crops = crops; }
#else
int BatchToSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_BatchToSpace();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_BatchToSpace return nullptr";
return RET_ERROR;
}
std::vector<int32_t> blockShape;
if (attr->blockShape() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->blockShape()->size()); i++) {
blockShape.push_back(attr->blockShape()->data()[i]);
}
}
std::vector<int32_t> crops;
if (attr->crops() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->crops()->size()); i++) {
crops.push_back(attr->crops()->data()[i]);
}
}
auto val_offset = schema::CreateBatchToSpaceDirect(*fbb, &blockShape, &crops);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchToSpace, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
std::vector<int> BatchToSpace::GetBlockShape() const {
auto fb_vector = this->primitive_->value_as_BatchToSpace()->blockShape();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
......
......@@ -35,39 +35,9 @@ class BatchToSpace : public PrimitiveC {
void SetBlockShape(const std::vector<int> &block_shape);
void SetCrops(const std::vector<int> &crops);
#else
explicit BatchToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_BatchToSpace();
MS_ASSERT(attr != nullptr);
auto blockShape = std::make_unique<std::vector<int32_t>>();
for (int i = 0; i < static_cast<int>(attr->blockShape()->size()); i++) {
blockShape->push_back(attr->blockShape()->data()[i]);
}
auto crops = std::make_unique<std::vector<int32_t>>();
for (int i = 0; i < static_cast<int>(attr->crops()->size()); i++) {
crops->push_back(attr->crops()->data()[i]);
}
auto val_offset = schema::CreateBatchToSpaceDirect(fbb, blockShape.release(), crops.release());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BatchToSpace, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
BatchToSpace() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetBlockShape() const;
......
......@@ -54,7 +54,25 @@ int BiasAdd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
}
#else
int BiasAdd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_BiasAdd();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_BiasAdd return nullptr";
return RET_ERROR;
}
std::vector<int32_t> axis;
if (attr->axis() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->axis()->size()); i++) {
axis.push_back(attr->axis()->data()[i]);
}
}
auto val_offset = schema::CreateBiasAddDirect(*fbb, &axis);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BiasAdd, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
std::vector<int> BiasAdd::GetAxis() const {
auto fb_vector = this->primitive_->value_as_BiasAdd()->axis();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
......
......@@ -32,38 +32,12 @@ class BiasAdd : public PrimitiveC {
MS_DECLARE_PARENT(BiasAdd, PrimitiveC);
BiasAdd() = default;
explicit BiasAdd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetAxis(const std::vector<int> &axis);
#else
explicit BiasAdd(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_BiasAdd();
MS_ASSERT(attr != nullptr);
auto axis = std::make_unique<std::vector<int32_t>>();
for (int i = 0; i < static_cast<int>(attr->axis()->size()); i++) {
axis->push_back(attr->axis()->data()[i]);
}
auto val_offset = schema::CreateBiasAddDirect(fbb, axis.release());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BiasAdd, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
BiasAdd() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
std::vector<int> GetAxis() const;
};
......
......@@ -24,7 +24,25 @@ std::vector<int> BiasGrad::GetAxis() const { return this->primitive_->value.AsBi
void BiasGrad::SetAxis(const std::vector<int> &axis) { this->primitive_->value.AsBiasGrad()->axis = axis; }
#else
int BiasGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_BiasGrad();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_BiasGrad return nullptr";
return RET_ERROR;
}
std::vector<int32_t> axis;
if (attr->axis() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->axis()->size()); i++) {
axis.push_back(attr->axis()->data()[i]);
}
}
auto val_offset = schema::CreateBiasGradDirect(*fbb, &axis);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BiasGrad, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
std::vector<int> BiasGrad::GetAxis() const {
auto fb_vector = this->primitive_->value_as_BiasGrad()->axis();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
......
......@@ -35,35 +35,9 @@ class BiasGrad : public PrimitiveC {
void SetAxis(const std::vector<int> &axis);
#else
explicit BiasGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_BiasGrad();
MS_ASSERT(attr != nullptr);
auto axis = std::make_unique<std::vector<int32_t>>();
for (int i = 0; i < static_cast<int>(attr->axis()->size()); i++) {
axis->push_back(attr->axis()->data()[i]);
}
auto val_offset = schema::CreateBiasGradDirect(fbb, axis.release());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BiasGrad, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
BiasGrad() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
std::vector<int> GetAxis() const;
};
......
......@@ -26,7 +26,19 @@ void BNGradInput::SetEps(float eps) { this->primitive_->value.AsBNGradInput()->e
void BNGradInput::SetChannels(int channels) { this->primitive_->value.AsBNGradInput()->channels = channels; }
#else
int BNGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_BNGradInput();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_BNGradInput return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateBNGradInput(*fbb, attr->eps(), attr->channels());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BNGradInput, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
float BNGradInput::GetEps() const { return this->primitive_->value_as_BNGradInput()->eps(); }
int BNGradInput::GetChannels() const { return this->primitive_->value_as_BNGradInput()->channels(); }
......
......@@ -34,30 +34,9 @@ class BNGradInput : public PrimitiveC {
void SetEps(float eps);
void SetChannels(int channels);
#else
explicit BNGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_BNGradInput();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateBNGradInput(fbb, attr->eps(), attr->channels());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BNGradInput, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
BNGradInput() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
float GetEps() const;
int GetChannels() const;
......
......@@ -26,7 +26,25 @@ void BroadcastTo::SetDstShape(const std::vector<int> &dst_shape) {
}
#else
int BroadcastTo::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_BroadcastTo();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_BroadcastTo return nullptr";
return RET_ERROR;
}
std::vector<int32_t> dst_shape;
if (attr->dst_shape() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->dst_shape()->size()); i++) {
dst_shape.push_back(attr->dst_shape()->data()[i]);
}
}
auto val_offset = schema::CreateBroadcastToDirect(*fbb, &dst_shape);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BroadcastTo, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
std::vector<int> BroadcastTo::GetDstShape() const {
auto fb_vector = this->primitive_->value_as_BroadcastTo()->dst_shape();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
......
......@@ -35,35 +35,9 @@ class BroadcastTo : public PrimitiveC {
void SetDstShape(const std::vector<int> &dst_shape);
#else
explicit BroadcastTo(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_BroadcastTo();
MS_ASSERT(attr != nullptr);
auto dst_shape = std::make_unique<std::vector<int32_t>>();
for (int i = 0; i < static_cast<int>(attr->dst_shape()->size()); i++) {
dst_shape->push_back(attr->dst_shape()->data()[i]);
}
auto val_offset = schema::CreateBroadcastToDirect(fbb, dst_shape.release());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_BroadcastTo, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
BroadcastTo() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetDstShape() const;
......
......@@ -26,7 +26,19 @@ void Cast::SetSrcT(int src_t) { this->primitive_->value.AsCast()->srcT = src_t;
void Cast::SetDstT(int dst_t) { this->primitive_->value.AsCast()->dstT = dst_t; }
#else
int Cast::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Cast();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Cast return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateCast(*fbb, attr->srcT(), attr->dstT());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Cast, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int Cast::GetSrcT() const { return this->primitive_->value_as_Cast()->srcT(); }
int Cast::GetDstT() const { return this->primitive_->value_as_Cast()->dstT(); }
......
......@@ -34,30 +34,9 @@ class Cast : public PrimitiveC {
void SetSrcT(int src_t);
void SetDstT(int dst_t);
#else
explicit Cast(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Cast();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateCast(fbb, attr->srcT(), attr->dstT());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Cast, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Cast() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetSrcT() const;
......
......@@ -32,26 +32,15 @@ class Ceil : public ArithmeticSelf {
Ceil() = default;
explicit Ceil(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else
explicit Ceil(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto val_offset = schema::CreateCeil(fbb);
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Ceil, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Ceil() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateCeil(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Ceil, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
};
......
......@@ -26,7 +26,19 @@ void Clip::SetMax(float max) { this->primitive_->value.AsClip()->max = max; }
void Clip::SetMin(float min) { this->primitive_->value.AsClip()->min = min; }
#else
int Clip::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Clip();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Clip return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateClip(*fbb, attr->max(), attr->min());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Clip, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
float Clip::GetMax() const { return this->primitive_->value_as_Clip()->max(); }
float Clip::GetMin() const { return this->primitive_->value_as_Clip()->min(); }
......
......@@ -34,30 +34,9 @@ class Clip : public PrimitiveC {
void SetMax(float max);
void SetMin(float min);
#else
explicit Clip(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Clip();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateClip(fbb, attr->max(), attr->min());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Clip, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Clip() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
float GetMax() const;
float GetMin() const;
......
......@@ -60,7 +60,19 @@ int Concat::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
}
#else
int Concat::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Concat();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Concat return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateConcat(*fbb, attr->axis(), attr->n());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Concat, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); }
int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); }
......
......@@ -31,34 +31,13 @@ class Concat : public PrimitiveC {
MS_DECLARE_PARENT(Concat, PrimitiveC);
Concat() = default;
explicit Concat(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetAxis(int axis);
void SetN(int n);
#else
explicit Concat(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Concat();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateConcat(fbb, attr->axis(), attr->n());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Concat, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Concat() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;
......
......@@ -30,7 +30,19 @@ float ConstantOfShape::GetValue() const { return this->primitive_->value.AsConst
void ConstantOfShape::SetValue(float value) { this->primitive_->value.AsConstantOfShape()->value = value; }
#else
int ConstantOfShape::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_ConstantOfShape();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_ConstantOfShape return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateConstantOfShape(*fbb, attr->value());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ConstantOfShape, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); }
#endif
......
......@@ -33,30 +33,9 @@ class ConstantOfShape : public PrimitiveC {
explicit ConstantOfShape(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetValue(float value);
#else
explicit ConstantOfShape(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_ConstantOfShape();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateConstantOfShape(fbb, attr->value());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ConstantOfShape, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
ConstantOfShape() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
float GetValue() const;
......
......@@ -338,7 +338,23 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
}
#else
int Conv2D::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Conv2D();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Conv2D return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateConv2D(
*fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(),
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Conv2D, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int Conv2D::GetFormat() const { return this->primitive_->value_as_Conv2D()->format(); }
int Conv2D::GetGroup() const { return this->primitive_->value_as_Conv2D()->group(); }
int Conv2D::GetChannelIn() const { return this->primitive_->value_as_Conv2D()->channelIn(); }
......
......@@ -34,7 +34,7 @@ class Conv2D : public PrimitiveC {
Conv2D() = default;
explicit Conv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetFormat(int format);
void SetGroup(int group);
void SetChannelIn(int channel_in);
......@@ -63,34 +63,9 @@ class Conv2D : public PrimitiveC {
#else
public:
explicit Conv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Conv2D();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateConv2D(fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(),
attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(),
attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(),
attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Conv2D, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Conv2D() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
public:
......
......@@ -68,7 +68,22 @@ void Conv2DGradFilter::SetActivationType(int activation_type) {
}
#else
int Conv2DGradFilter::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Conv2DGradFilter();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Conv2DGradFilter return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateConv2DGradFilter(
*fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(),
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Conv2DGradFilter, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int Conv2DGradFilter::GetFormat() const { return this->primitive_->value_as_Conv2DGradFilter()->format(); }
int Conv2DGradFilter::GetGroup() const { return this->primitive_->value_as_Conv2DGradFilter()->group(); }
int Conv2DGradFilter::GetChannelIn() const { return this->primitive_->value_as_Conv2DGradFilter()->channelIn(); }
......
......@@ -49,35 +49,9 @@ class Conv2DGradFilter : public PrimitiveC {
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
#else
explicit Conv2DGradFilter(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Conv2DGradFilter();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateConv2DGradFilter(fbb, attr->format(), attr->group(),
attr->channelIn(), attr->channelOut(),
attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(),
attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(),
attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Conv2DGradFilter, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Conv2DGradFilter() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int GetFormat() const;
int GetGroup() const;
......
......@@ -66,7 +66,22 @@ void Conv2DGradInput::SetActivationType(int activation_type) {
}
#else
int Conv2DGradInput::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Conv2DGradInput();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Conv2DGradInput return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateConv2DGradInput(
*fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(),
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Conv2DGradInput, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int Conv2DGradInput::GetFormat() const { return this->primitive_->value_as_Conv2DGradInput()->format(); }
int Conv2DGradInput::GetGroup() const { return this->primitive_->value_as_Conv2DGradInput()->group(); }
int Conv2DGradInput::GetChannelIn() const { return this->primitive_->value_as_Conv2DGradInput()->channelIn(); }
......
......@@ -49,35 +49,9 @@ class Conv2DGradInput : public PrimitiveC {
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
#else
explicit Conv2DGradInput(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Conv2DGradInput();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateConv2DGradInput(fbb, attr->format(), attr->group(),
attr->channelIn(), attr->channelOut(),
attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(),
attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(),
attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Conv2DGradInput, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Conv2DGradInput() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int GetFormat() const;
int GetGroup() const;
......
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/cos.h"
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
int Cos::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateCos(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Cos, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore
......@@ -21,7 +21,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#include "src/ops/arithmetic_self.h"
namespace mindspore {
namespace lite {
......@@ -31,27 +31,9 @@ class Cos : public ArithmeticSelf {
Cos() = default;
explicit Cos(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else
explicit Cos(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto val_offset = schema::CreateCos(fbb);
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Cos, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Cos() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
};
} // namespace lite
......
......@@ -26,7 +26,25 @@ void Crop::SetAxis(int64_t axis) { this->primitive_->value.AsCrop()->axis = axis
void Crop::SetOffsets(const std::vector<int64_t> &offsets) { this->primitive_->value.AsCrop()->offsets = offsets; }
#else
int Crop::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Crop();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Crop return nullptr";
return RET_ERROR;
}
std::vector<int64_t> offsets;
if (attr->offsets() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->offsets()->size()); i++) {
offsets.push_back(attr->offsets()->data()[i]);
}
}
auto val_offset = schema::CreateCropDirect(*fbb, attr->axis(), &offsets);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Crop, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int64_t Crop::GetAxis() const { return this->primitive_->value_as_Crop()->axis(); }
std::vector<int64_t> Crop::GetOffsets() const {
auto fb_vector = this->primitive_->value_as_Crop()->offsets();
......
......@@ -35,35 +35,9 @@ class Crop : public PrimitiveC {
void SetAxis(int64_t axis);
void SetOffsets(const std::vector<int64_t> &offsets);
#else
explicit Crop(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Crop();
MS_ASSERT(attr != nullptr);
auto offsets = std::make_unique<std::vector<int64_t>>();
for (int i = 0; i < static_cast<int>(attr->offsets()->size()); i++) {
offsets->push_back(attr->offsets()->data()[i]);
}
auto val_offset = schema::CreateCropDirect(fbb, attr->axis(), offsets.release());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Crop, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Crop() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int64_t GetAxis() const;
......
......@@ -58,7 +58,22 @@ void DeConv2D::SetActivationType(int activation_type) {
}
#else
int DeConv2D::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_DeConv2D();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_DeConv2D return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateDeConv2D(
*fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(), attr->kernelW(), attr->kernelH(),
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DeConv2D, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int DeConv2D::GetFormat() const { return this->primitive_->value_as_DeConv2D()->format(); }
int DeConv2D::GetGroup() const { return this->primitive_->value_as_DeConv2D()->group(); }
int DeConv2D::GetChannelIn() const { return this->primitive_->value_as_DeConv2D()->channelIn(); }
......
......@@ -49,34 +49,9 @@ class DeConv2D : public PrimitiveC {
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
#else
explicit DeConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_DeConv2D();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateDeConv2D(fbb, attr->format(), attr->group(), attr->channelIn(), attr->channelOut(),
attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(),
attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(),
attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DeConv2D, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
DeConv2D() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetFormat() const;
......
......@@ -70,7 +70,24 @@ void DeDepthwiseConv2D::SetActivationType(int activation_type) {
}
#else
int DeDepthwiseConv2D::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_DeDepthwiseConv2D();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_DeDepthwiseConv2D return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateDeDepthwiseConv2D(
*fbb, attr->format(), attr->channelIn(), attr->channelMultiplier(), attr->kernelW(), attr->kernelH(),
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DeDepthwiseConv2D, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int DeDepthwiseConv2D::GetFormat() const { return this->primitive_->value_as_DeDepthwiseConv2D()->format(); }
int DeDepthwiseConv2D::GetChannelIn() const { return this->primitive_->value_as_DeDepthwiseConv2D()->channelIn(); }
int DeDepthwiseConv2D::GetChannelMultiplier() const {
......
......@@ -48,34 +48,9 @@ class DeDepthwiseConv2D : public PrimitiveC {
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
#else
explicit DeDepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_DeDepthwiseConv2D();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateDeDepthwiseConv2D(fbb, attr->format(), attr->channelIn(), attr->channelMultiplier(),
attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(),
attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(),
attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DeDepthwiseConv2D, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
DeDepthwiseConv2D() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetFormat() const;
......
......@@ -26,7 +26,19 @@ void DepthToSpace::SetBlockSize(int block_size) { this->primitive_->value.AsDept
void DepthToSpace::SetFormat(int format) { this->primitive_->value.AsDepthToSpace()->format = (schema::Format)format; }
#else
int DepthToSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_DepthToSpace();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_DepthToSpace return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateDepthToSpace(*fbb, attr->blockSize(), attr->format());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DepthToSpace, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int DepthToSpace::GetBlockSize() const { return this->primitive_->value_as_DepthToSpace()->blockSize(); }
int DepthToSpace::GetFormat() const { return this->primitive_->value_as_DepthToSpace()->format(); }
......
......@@ -34,30 +34,9 @@ class DepthToSpace : public PrimitiveC {
void SetBlockSize(int block_size);
void SetFormat(int format);
#else
explicit DepthToSpace(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_DepthToSpace();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateDepthToSpace(fbb, attr->blockSize(), attr->format());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DepthToSpace, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
DepthToSpace() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetBlockSize() const;
......
......@@ -232,7 +232,22 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNode
}
#else
int DepthwiseConv2D::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_DepthwiseConv2D();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_DepthwiseConv2D return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateDepthwiseConv2D(
*fbb, attr->format(), attr->channelIn(), attr->channelMultiplier(), attr->kernelW(), attr->kernelH(),
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(), attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DepthwiseConv2D, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int DepthwiseConv2D::GetFormat() const { return this->primitive_->value_as_DepthwiseConv2D()->format(); }
int DepthwiseConv2D::GetChannelIn() const { return this->primitive_->value_as_DepthwiseConv2D()->channelIn(); }
int DepthwiseConv2D::GetChannelMultiplier() const {
......
......@@ -33,7 +33,7 @@ class DepthwiseConv2D : public PrimitiveC {
DepthwiseConv2D() = default;
explicit DepthwiseConv2D(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetFormat(int format);
void SetChannelIn(int channel_in);
void SetChannelMultiplier(int channel_multiplier);
......@@ -58,35 +58,9 @@ class DepthwiseConv2D : public PrimitiveC {
#else
public:
explicit DepthwiseConv2D(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_DepthwiseConv2D();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateDepthwiseConv2D(fbb, attr->format(),
attr->channelIn(), attr->channelMultiplier(),
attr->kernelW(), attr->kernelH(), attr->strideW(), attr->strideH(),
attr->padMode(), attr->padUp(), attr->padDown(), attr->padLeft(),
attr->padRight(), attr->dilateW(), attr->dilateH(),
attr->hasBias(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DepthwiseConv2D, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
DepthwiseConv2D() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
public:
......
......@@ -28,9 +28,9 @@ class Dequant : public PrimitiveC {
MS_DECLARE_PARENT(Dequant, PrimitiveC);
Dequant() = default;
explicit Dequant(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
explicit Dequant(schema::Primitive *primitive) : PrimitiveC(primitive) {}
Dequant() = default;
#endif
};
} // namespace lite
......
......@@ -88,7 +88,22 @@ void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) {
}
#else
int DetectionPostProcess::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_DetectionPostProcess();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_DetectionPostProcess return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateDetectionPostProcess(
*fbb, attr->format(), attr->inputSize(), attr->hScale(), attr->wScale(), attr->xScale(), attr->yScale(),
attr->NmsIouThreshold(), attr->NmsScoreThreshold(), attr->MaxDetections(), attr->DetectionsPreClass(),
attr->MaxClassesPreDetection(), attr->NumClasses(), attr->UseRegularNms());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DetectionPostProcess, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int DetectionPostProcess::GetFormat() const { return this->primitive_->value_as_DetectionPostProcess()->format(); }
int DetectionPostProcess::GetInputSize() const {
return this->primitive_->value_as_DetectionPostProcess()->inputSize();
......
......@@ -45,36 +45,9 @@ class DetectionPostProcess : public PrimitiveC {
void SetNumClasses(int64_t num_classes);
void SetUseRegularNms(bool use_regular_nms);
#else
explicit DetectionPostProcess(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_DetectionPostProcess();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateDetectionPostProcess(fbb, attr->format(), attr->inputSize(),
attr->hScale(), attr->wScale(),
attr->xScale(), attr->yScale(),
attr->NmsIouThreshold(), attr->NmsScoreThreshold(),
attr->MaxDetections(), attr->DetectionsPreClass(),
attr->MaxClassesPreDetection(), attr->NumClasses(),
attr->UseRegularNms());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_DetectionPostProcess, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
DetectionPostProcess() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int GetFormat() const;
int GetInputSize() const;
......
......@@ -26,7 +26,19 @@ void Div::SetActivationType(int activation_type) {
}
#else
int Div::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Div();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Div return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateDiv(*fbb, attr->activationType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Div, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int Div::GetActivationType() const { return this->primitive_->value_as_Div()->activationType(); }
#endif
......
......@@ -34,30 +34,9 @@ class Div : public Arithmetic {
void SetActivationType(int activation_type);
#else
explicit Div(schema::Primitive *primitive) : Arithmetic(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Div();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateDiv(fbb, attr->activationType());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Div, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Div() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int GetActivationType() const;
};
......
......@@ -24,7 +24,19 @@ float Dropout::GetRatio() const { return this->primitive_->value.AsDropout()->ra
void Dropout::SetRatio(float ratio) { this->primitive_->value.AsDropout()->ratio = ratio; }
#else
int Dropout::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Dropout();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Dropout return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateDropout(*fbb, attr->ratio());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Dropout, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
float Dropout::GetRatio() const { return this->primitive_->value_as_Dropout()->ratio(); }
#endif
......
......@@ -34,30 +34,9 @@ class Dropout : public PrimitiveC {
void SetRatio(float ratio);
#else
explicit Dropout(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Dropout();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateDropout(fbb, attr->ratio());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Dropout, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Dropout() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
float GetRatio() const;
};
......
......@@ -24,7 +24,19 @@ int Eltwise::GetMode() const { return this->primitive_->value.AsEltwise()->mode;
void Eltwise::SetMode(int mode) { this->primitive_->value.AsEltwise()->mode = (schema::EltwiseMode)mode; }
#else
int Eltwise::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Eltwise();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Eltwise return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateEltwise(*fbb, attr->mode());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Eltwise, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int Eltwise::GetMode() const { return this->primitive_->value_as_Eltwise()->mode(); }
#endif
......
......@@ -34,30 +34,9 @@ class Eltwise : public PrimitiveC {
void SetMode(int mode);
#else
explicit Eltwise(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Eltwise();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateEltwise(fbb, attr->mode());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Eltwise, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Eltwise() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int GetMode() const;
};
......
......@@ -24,7 +24,19 @@ float Elu::GetAlpha() const { return this->primitive_->value.AsElu()->alpha; }
void Elu::SetAlpha(float alpha) { this->primitive_->value.AsElu()->alpha = alpha; }
#else
int Elu::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Elu();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Elu return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateElu(*fbb, attr->alpha());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Elu, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
float Elu::GetAlpha() const { return this->primitive_->value_as_Elu()->alpha(); }
#endif
......
......@@ -34,30 +34,9 @@ class Elu : public PrimitiveC {
void SetAlpha(float alpha);
#else
explicit Elu(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Elu();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateElu(fbb, attr->alpha());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Elu, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Elu() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
float GetAlpha() const;
};
......
......@@ -24,7 +24,21 @@ float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value.AsEmb
void EmbeddingLookup::SetMaxNorm(float max_norm) { this->primitive_->value.AsEmbeddingLookup()->maxNorm = max_norm; }
#else
int EmbeddingLookup::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_EmbeddingLookup();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_EmbeddingLookup return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateEmbeddingLookup(*fbb, attr->maxNorm());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_EmbeddingLookup, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value_as_EmbeddingLookup()->maxNorm(); }
#endif
......
......@@ -34,30 +34,9 @@ class EmbeddingLookup : public PrimitiveC {
void SetMaxNorm(float max_norm);
#else
explicit EmbeddingLookup(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_EmbeddingLookup();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateEmbeddingLookup(fbb, attr->maxNorm());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_EmbeddingLookup, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
EmbeddingLookup() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
float GetMaxNorm() const;
......
......@@ -38,7 +38,32 @@ void EmbeddingLookupSparse::SetMaxNortm(float max_nortm) {
}
#else
int EmbeddingLookupSparse::UnPackToFlatBuilder(const schema::Primitive *primitive,
flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_EmbeddingLookupSparse();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_EmbeddingLookupSparse return nullptr";
return RET_ERROR;
}
std::vector<int32_t> spIds;
if (attr->spIds() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->spIds()->size()); i++) {
spIds.push_back(attr->spIds()->data()[i]);
}
}
std::vector<float> spWeights;
if (attr->spWeights() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->spWeights()->size()); i++) {
spWeights.push_back(attr->spWeights()->data()[i]);
}
}
auto val_offset = schema::CreateEmbeddingLookupSparseDirect(*fbb, &spIds, &spWeights);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_EmbeddingLookupSparse, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
std::vector<int> EmbeddingLookupSparse::GetSpIds() const {
auto fb_vector = this->primitive_->value_as_EmbeddingLookupSparse()->spIds();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
......
......@@ -36,39 +36,9 @@ class EmbeddingLookupSparse : public PrimitiveC {
void SetSpWeights(const std::vector<float> &sp_weights);
void SetMaxNortm(float max_nortm);
#else
explicit EmbeddingLookupSparse(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_EmbeddingLookupSparse();
MS_ASSERT(attr != nullptr);
auto spIds = std::make_unique<std::vector<int32_t>>();
for (int i = 0; i < static_cast<int>(attr->spIds()->size()); i++) {
spIds->push_back(attr->spIds()->data()[i]);
}
auto spWeights = std::make_unique<std::vector<float>>();
for (int i = 0; i < static_cast<int>(attr->spWeights()->size()); i++) {
spWeights->push_back(attr->spWeights()->data()[i]);
}
auto val_offset = schema:: CreateEmbeddingLookupSparseDirect(fbb, spIds.release(), spWeights.release());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_EmbeddingLookupSparse, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
EmbeddingLookupSparse() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
std::vector<int> GetSpIds() const;
std::vector<float> GetSpWeights() const;
......
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/equal.h"
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateEqual(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Equal, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore
......@@ -32,27 +32,9 @@ class Equal : public Arithmetic {
Equal() = default;
explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else
explicit Equal(schema::Primitive *primitive) : Arithmetic(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto val_offset = schema::CreateEqual(fbb);
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Equal, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Equal() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
};
} // namespace lite
......
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/exp.h"
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
int Exp::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateExp(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Exp, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore
......@@ -32,27 +32,9 @@ class Exp : public ArithmeticSelf {
Exp() = default;
explicit Exp(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else
explicit Exp(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto val_offset = schema::CreateExp(fbb);
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Exp, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Exp() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
};
} // namespace lite
......
......@@ -24,7 +24,20 @@ int ExpandDims::GetDim() const { return this->primitive_->value.AsExpandDims()->
void ExpandDims::SetDim(int dim) { this->primitive_->value.AsExpandDims()->dim = dim; }
#else
int ExpandDims::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_ExpandDims();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_ExpandDims return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateExpandDims(*fbb, attr->dim());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_ExpandDims, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int ExpandDims::GetDim() const { return this->primitive_->value_as_ExpandDims()->dim(); }
#endif
......
......@@ -34,30 +34,9 @@ class ExpandDims : public PrimitiveC {
void SetDim(int dim);
#else
explicit ExpandDims(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_ExpandDims();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateExpandDims(fbb, attr->dim());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_ExpandDims, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
ExpandDims() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetDim() const;
......
......@@ -32,7 +32,21 @@ void FakeQuantWithMinMaxVars::SetNumBits(int num_bits) {
}
#else
int FakeQuantWithMinMaxVars::UnPackToFlatBuilder(const schema::Primitive *primitive,
flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_FakeQuantWithMinMaxVars();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_FakeQuantWithMinMaxVars return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateFakeQuantWithMinMaxVars(*fbb, attr->narrowRange(), attr->numBits());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FakeQuantWithMinMaxVars, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
bool FakeQuantWithMinMaxVars::GetNarrowRange() const {
return this->primitive_->value_as_FakeQuantWithMinMaxVars()->narrowRange();
}
......
......@@ -34,31 +34,9 @@ class FakeQuantWithMinMaxVars : public PrimitiveC {
void SetNarrowRange(bool narrow_range);
void SetNumBits(int num_bits);
#else
explicit FakeQuantWithMinMaxVars(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_FakeQuantWithMinMaxVars();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateFakeQuantWithMinMaxVars(fbb, attr->narrowRange(), attr->numBits());
auto prim_offset = schema::CreatePrimitive(fbb,
schema::PrimitiveType_FakeQuantWithMinMaxVars, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
FakeQuantWithMinMaxVars() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
bool GetNarrowRange() const;
int GetNumBits() const;
......
......@@ -24,7 +24,25 @@ std::vector<int> Fill::GetDims() const { return this->primitive_->value.AsFill()
void Fill::SetDims(const std::vector<int> &dims) { this->primitive_->value.AsFill()->dims = dims; }
#else
int Fill::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Fill();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Fill return nullptr";
return RET_ERROR;
}
std::vector<int32_t> dims;
if (attr->dims() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->dims()->size()); i++) {
dims.push_back(attr->dims()->data()[i]);
}
}
auto val_offset = schema::CreateFillDirect(*fbb, &dims);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Fill, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
std::vector<int> Fill::GetDims() const {
auto fb_vector = this->primitive_->value_as_Fill()->dims();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
......
......@@ -35,35 +35,9 @@ class Fill : public PrimitiveC {
void SetDims(const std::vector<int> &dims);
#else
explicit Fill(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Fill();
MS_ASSERT(attr != nullptr);
auto dims = std::make_unique<std::vector<int32_t>>();
for (int i = 0; i < static_cast<int>(attr->dims()->size()); i++) {
dims->push_back(attr->dims()->data()[i]);
}
auto val_offset = schema::CreateFillDirect(fbb, dims.release());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Fill, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Fill() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
std::vector<int> GetDims() const;
......
......@@ -77,6 +77,15 @@ int Flatten::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
}
return RET_OK;
}
#else
int Flatten::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateFlatten(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Flatten, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore
......@@ -31,32 +31,13 @@ class Flatten : public PrimitiveC {
MS_DECLARE_PARENT(Flatten, PrimitiveC);
Flatten() = default;
explicit Flatten(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
explicit Flatten(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto val_offset = schema::CreateFlatten(fbb);
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Flatten, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Flatten() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs);
};
} // namespace lite
} // namespace mindspore
......
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/floor.h"
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
int Floor::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateFloor(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Floor, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore
......@@ -21,7 +21,7 @@
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#include "src/ops/arithmetic_self.h"
namespace mindspore {
namespace lite {
......@@ -32,27 +32,9 @@ class Floor : public ArithmeticSelf {
Floor() = default;
explicit Floor(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#else
explicit Floor(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto val_offset = schema::CreateFloor(fbb);
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Floor, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Floor() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
};
} // namespace lite
......
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/floor_div.h"
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
int FloorDiv::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateFloor(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Floor, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore
......@@ -32,27 +32,9 @@ class FloorDiv : public Arithmetic {
FloorDiv() = default;
explicit FloorDiv(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else
explicit FloorDiv(schema::Primitive *primitive) : Arithmetic(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto val_offset = schema::CreateFloorDiv(fbb);
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_FloorDiv, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
FloorDiv() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
};
} // namespace lite
......
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/floor_mod.h"
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
int FloorMod::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto val_offset = schema::CreateFloorMod(*fbb);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FloorMod, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore
......@@ -32,27 +32,9 @@ class FloorMod : public Arithmetic {
FloorMod() = default;
explicit FloorMod(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#else
explicit FloorMod(schema::Primitive *primitive) : Arithmetic(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto val_offset = schema::CreateFloorMod(fbb);
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_FloorMod, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
FloorMod() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
};
} // namespace lite
......
......@@ -31,7 +31,21 @@ void FullConnection::SetActivationType(int activationType) {
this->primitive_->value.AsFullConnection()->activationType = (schema::ActivationType)activationType;
}
#else
int FullConnection::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_FullConnection();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_FullConnection return nullptr";
return RET_ERROR;
}
auto val_offset =
schema::CreateFullConnection(*fbb, attr->hasBias(), attr->axis(), attr->useAxis(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FullConnection, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
bool FullConnection::GetHasBias() const { return this->primitive_->value_as_FullConnection()->hasBias(); }
int FullConnection::GetAxis() const { return this->primitive_->value_as_FullConnection()->axis(); }
bool FullConnection::GetUseAxis() const { return this->primitive_->value_as_FullConnection()->useAxis(); }
......
......@@ -36,31 +36,9 @@ class FullConnection : public PrimitiveC {
void SetUseAxis(bool use_axis);
void SetActivationType(int activationType);
#else
explicit FullConnection(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_FullConnection();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateFullConnection(fbb, attr->hasBias(), attr->axis(),
attr->useAxis(), attr->activationType());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_FullConnection, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
FullConnection() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
bool GetHasBias() const;
......
......@@ -28,7 +28,20 @@ void FusedBatchNorm::SetMomentum(float momentum) { this->primitive_->value.AsFus
void FusedBatchNorm::SetSpatial(int spatial) { this->primitive_->value.AsFusedBatchNorm()->spatial = spatial; }
#else
int FusedBatchNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_FusedBatchNorm();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_FusedBatchNorm return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateFusedBatchNorm(*fbb, attr->epsilon(), attr->momentum(), attr->spatial());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FusedBatchNorm, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
float FusedBatchNorm::GetEpsilon() const { return this->primitive_->value_as_FusedBatchNorm()->epsilon(); }
float FusedBatchNorm::GetMomentum() const { return this->primitive_->value_as_FusedBatchNorm()->momentum(); }
int FusedBatchNorm::GetSpatial() const { return this->primitive_->value_as_FusedBatchNorm()->spatial(); }
......
......@@ -35,30 +35,9 @@ class FusedBatchNorm : public PrimitiveC {
void SetMomentum(float momentum);
void SetSpatial(int spatial);
#else
explicit FusedBatchNorm(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_FusedBatchNorm();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateFusedBatchNorm(fbb, attr->epsilon(), attr->momentum(), attr->spatial());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_FusedBatchNorm, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
FusedBatchNorm() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
float GetEpsilon() const;
float GetMomentum() const;
......
......@@ -29,7 +29,20 @@ void Gather::SetAxis(int axis) { this->primitive_->value.AsGather()->axis = axis
void Gather::SetBatchDims(int batch_dims) { this->primitive_->value.AsGather()->batchDims = batch_dims; }
#else
int Gather::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Gather();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Gather return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateGather(*fbb, attr->axis(), attr->batchDims());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Gather, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int Gather::GetAxis() const { return this->primitive_->value_as_Gather()->axis(); }
int Gather::GetBatchDims() const { return this->primitive_->value_as_Gather()->batchDims(); }
......
......@@ -34,30 +34,9 @@ class Gather : public PrimitiveC {
void SetAxis(int axis);
void SetBatchDims(int batch_dims);
#else
explicit Gather(schema::Primitive *primitive) : PrimitiveC(primitive) {}
schema::Primitive *Init(schema::Primitive *primitive) {
flatbuffers::FlatBufferBuilder fbb(1024);
auto attr = primitive->value_as_Gather();
MS_ASSERT(attr != nullptr);
auto val_offset = schema::CreateGather(fbb, attr->axis(), attr->batchDims());
auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_Gather, val_offset.o);
fbb.Finish(prim_offset);
auto buf = fbb.GetBufferPointer();
MS_ASSERT(buf != nullptr);
auto buf_bak = new char[fbb.GetSize()];
memcpy(buf_bak, buf, fbb.GetSize());
auto root = flatbuffers::GetRoot<schema::Primitive>(buf_bak);
auto prim = const_cast<schema::Primitive *>(root);
Gather() = default;
delete[] buf_bak;
fbb.Clear();
return prim;
}
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;
......
......@@ -24,7 +24,20 @@ int GatherNd::GetBatchDims() const { return this->primitive_->value.AsGatherNd()
void GatherNd::SetBatchDims(int batch_dims) { this->primitive_->value.AsGatherNd()->batchDims = batch_dims; }
#else
int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_GatherNd();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_GatherNd return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateGatherNd(*fbb, attr->batchDims());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_GatherNd, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); }
#endif
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册