提交 e5b563e6 编写于 作者: S superjomn

update

上级 cf6b65a1
...@@ -2,6 +2,7 @@ cc_library(executor_lite SRCS executor.cc) ...@@ -2,6 +2,7 @@ cc_library(executor_lite SRCS executor.cc)
cc_library(op_lite SRCS op_lite.cc) cc_library(op_lite SRCS op_lite.cc)
cc_library(memory_lite SRCS memory.cc) cc_library(memory_lite SRCS memory.cc)
cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite) cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite)
cc_library(variable_lite SRCS variable.cc)
cc_library(op_registry_lite SRCS op_registry.cc) cc_library(op_registry_lite SRCS op_registry.cc)
add_subdirectory(x86) add_subdirectory(x86)
......
...@@ -55,7 +55,8 @@ class OpContext final { ...@@ -55,7 +55,8 @@ class OpContext final {
explicit OpContext(TargetType target) explicit OpContext(TargetType target)
: targets_(std::vector<TargetType>({target})) {} : targets_(std::vector<TargetType>({target})) {}
// @param target valid target. // @param target valid target.
explicit OpContext(const std::vector<TargetType>& target) : targets_(target) {} explicit OpContext(const std::vector<TargetType>& target)
: targets_(target) {}
const std::vector<TargetType>& target() const { return targets_; } const std::vector<TargetType>& target() const { return targets_; }
......
cc_library(model_parser SRCS model_parser.cc) cc_library(model_parser_lite SRCS model_parser.cc)
cc_library(runtime_lite SRCS runtime.cc)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/model_parser/runtime.h"
#include "runtime.h"
#include <glog/logging.h>
namespace paddle {
namespace lite {
void VarDesc::Parse(const framework::proto::VarDesc& desc) {
name = desc.name();
this->persistable = desc.persistable();
type.Parse(desc.type());
}
void OpDesc::Parse(const framework::proto::OpDesc& desc) {
op_type = desc.type();
// prepare inputs
for (const auto& input : desc.inputs()) {
for (const auto& arg : input.arguments()) {
inputs[input.parameter()].push_back(arg);
}
}
// prepare outputs
for (const auto& output : desc.inputs()) {
for (const auto& arg : output.arguments()) {
inputs[output.parameter()].push_back(arg);
}
}
// prepare attributes
for (const auto& attr : desc.attrs()) {
switch (static_cast<int>(attr.type())) {
case framework::proto::AttrType::INT:
attrs[attr.name()] = attr.i();
break;
case framework::proto::AttrType::FLOAT:
attrs[attr.name()] = attr.f();
break;
case framework::proto::AttrType::STRING:
attrs[attr.name()] = attr.s();
break;
case framework::proto::AttrType::INTS:
attrs[attr.name()] = attr.ints();
break;
case framework::proto::AttrType::FLOATS:
attrs[attr.name()] = attr.floats();
break;
case framework::proto::AttrType::STRINGS:
attrs[attr.name()] = attr.strings();
break;
case framework::proto::AttrType::BOOLEAN:
attrs[attr.name()] = attr.b();
break;
case framework::proto::AttrType::BOOLEANS:
attrs[attr.name()] = attr.bools();
break;
case framework::proto::AttrType::LONG:
attrs[attr.name()] = attr.l();
break;
case framework::proto::AttrType::LONGS:
attrs[attr.name()] = attr.longs();
break;
case framework::proto::AttrType::BLOCK:
attrs[attr.name()] = attr.block_idx();
break;
case framework::proto::AttrType::BLOCKS:
attrs[attr.name()] = attr.blocks_idx();
break;
default:
LOG(ERROR) << "unknown attribute type found";
}
}
}
void BlockDesc::Parse(const framework::proto::BlockDesc& desc) {
idx = desc.idx();
parent_idx = desc.parent_idx();
}
void VarType::Parse(const framework::proto::VarType& proto) {
switch (static_cast<int>(proto.type())) {
case framework::proto::VarType_Type::VarType_Type_LOD_TENSOR:
desc = LoDTensorDesc(proto.lod_tensor());
break;
case framework::proto::VarType_Type::VarType_Type_LOD_TENSOR_ARRAY:
desc = LoDTensorArrayDesc(proto.tensor_array());
break;
default:
LOG(ERROR) << "no valid var type found";
return;
}
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
// We define the runtime data structure for framework.proto to support some
// other model format such as JSON if needed.
using proto_type_t = framework::proto::VarType::Type;
class TensorDesc {
public:
proto_type_t data_type;
std::vector<int64_t> dims;
TensorDesc() = default;
explicit TensorDesc(const framework::proto::VarType_TensorDesc& proto) {
Parse(proto);
}
void Parse(const framework::proto::VarType_TensorDesc& proto) {
data_type = proto.data_type();
for (auto& d : proto.dims()) dims.push_back(d);
}
};
class LoDTensorDesc {
public:
TensorDesc tensor;
int lod_level{-1};
LoDTensorDesc(const framework::proto::VarType_LoDTensorDesc& proto) {
Parse(proto);
}
void Parse(const framework::proto::VarType_LoDTensorDesc& proto) {
tensor.Parse(proto.tensor());
lod_level = proto.lod_level();
}
};
class LoDTensorArrayDesc {
public:
TensorDesc tensor;
int lod_level{-1};
LoDTensorArrayDesc(
const framework::proto::VarType_LoDTensorArrayDesc& proto) {
Parse(proto);
}
void Parse(const framework::proto::VarType_LoDTensorArrayDesc& proto) {
tensor.Parse(proto.tensor());
lod_level = proto.lod_level();
}
};
class VarType {
public:
framework::proto::VarType::Type type;
any desc;
void Parse(const framework::proto::VarType& proto);
};
class VarDesc {
public:
void Parse(const framework::proto::VarDesc& desc);
std::string name;
VarType type;
bool persistable{false};
};
class OpDesc {
public:
void Parse(const framework::proto::OpDesc& desc);
std::string op_type;
std::map<std::string, std::vector<std::string>> inputs;
std::map<std::string, std::vector<std::string>> outputs;
std::map<std::string, any> attrs;
};
class BlockDesc {
public:
void Parse(const framework::proto::BlockDesc& desc);
int idx{-1};
int parent_idx{-1};
int forward_block_idx{-1};
std::map<std::string, VarDesc> vars;
std::vector<OpDesc> ops;
};
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/scope.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
class Scope final {
public:
Scope() {}
~Scope();
Scope& NewScope() const;
Variable* Var(std::string* name = nullptr);
Variable* FindVar(const std::string& name) const;
Variable* FindLocalVar(const std::string& name) const;
const Scope* parent() const { return parent_; }
private:
// Scope in `kids_` are owned by this class.
mutable std::list<Scope*> kids_;
const Scope* parent_{nullptr};
};
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/variable.h"
namespace paddle {
namespace lite {} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
class Variable {
public:
template <typename T>
T& Get() {
return blob_;
}
template <typename T>
T* GetMutable() {
return any_cast<T>(&blob_);
}
private:
any blob_;
};
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册