未验证 提交 bdff2240 编写于 作者: 石晓伟 提交者: GitHub

speed up vector assignment, test=develop (#4098)

* speed up vector assignment, test=develop

* update op_desc with c-string, test=develop
上级 04b2c9fa
......@@ -83,9 +83,9 @@ class VectorView {
operator std::vector<T>() const {
VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<T> tmp;
tmp.reserve(size());
tmp.resize(size());
for (size_t i = 0; i < size(); ++i) {
tmp.push_back(cvec_->operator[](i));
tmp[i] = cvec_->operator[](i);
}
return tmp;
}
......
......@@ -30,13 +30,13 @@ class BlockDescView : public BlockDescAPI {
public:
explicit BlockDescView(proto::BlockDesc const* desc) : desc_(desc) {
CHECK(desc_);
vars_.reserve(VarsSize());
ops_.reserve(OpsSize());
vars_.resize(VarsSize());
ops_.resize(OpsSize());
for (size_t idx = 0; idx < VarsSize(); ++idx) {
vars_.push_back(VarDescView(desc_->vars()->Get(idx)));
vars_[idx] = VarDescView(desc_->vars()->Get(idx));
}
for (size_t idx = 0; idx < OpsSize(); ++idx) {
ops_.push_back(OpDescView(desc_->ops()->Get(idx)));
ops_[idx] = OpDescView(desc_->ops()->Get(idx));
}
}
......@@ -76,7 +76,7 @@ class BlockDescView : public BlockDescAPI {
return desc_->forward_block_idx();
}
BlockDescView() { NotImplemented(); }
BlockDescView() = default;
private:
proto::BlockDesc const* desc_; // not_own
......
......@@ -19,8 +19,8 @@ namespace lite {
namespace fbs {
template <>
std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str());
std::string OpDescView::GetAttr<std::string>(const char* name) const {
const auto& it = desc_->attrs()->LookupByKey(name);
if (!it->s()) {
return std::string();
}
......@@ -28,56 +28,48 @@ std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
}
template <>
std::string OpDescView::GetAttr<std::string>(size_t idx) const {
const auto& it = desc_->attrs()->Get(idx);
if (!it->s()) {
return std::string();
}
return it->s()->str();
std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
return GetAttr<std::string>(name.c_str());
}
template <>
lite::VectorView<std::string, Flatbuffers>
OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str());
OpDescView::GetAttr<std::vector<std::string>>(const char* name) const {
const auto& it = desc_->attrs()->LookupByKey(name);
CHECK(it) << "Attr " << name << "does not exist.";
return VectorView<std::string>(it->strings());
}
template <>
VectorView<std::string, Flatbuffers>
OpDescView::GetAttr<std::vector<std::string>>(size_t idx) const {
const auto& it = desc_->attrs()->Get(idx);
CHECK(it) << "Attr " << idx << "does not exist.";
return VectorView<std::string>(it->strings());
lite::VectorView<std::string, Flatbuffers>
OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const {
return GetAttr<std::vector<std::string>>(name.c_str());
}
#define GET_ATTR_IMPL(T, fb_f__) \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
const std::string& name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \
const char* name) const { \
const auto& it = desc_->attrs()->LookupByKey(name); \
return it->fb_f__(); \
} \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
size_t idx) const { \
const auto& it = desc_->attrs()->Get(idx); \
return it->fb_f__(); \
const std::string& name) const { \
return GetAttr<T>(name.c_str()); \
}
#define GET_ATTRS_IMPL(T, fb_f__) \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
const std::string& name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \
const char* name) const { \
const auto& it = desc_->attrs()->LookupByKey(name); \
return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
} \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
size_t idx) const { \
const auto& it = desc_->attrs()->Get(idx); \
return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
const std::string& name) const { \
return GetAttr<T>(name.c_str()); \
}
GET_ATTR_IMPL(int32_t, i);
......
......@@ -36,57 +36,68 @@ class OpDescView : public OpDescAPI {
std::string Type() const override { return desc_->type()->str(); }
// Get the arguments of parameter called `param`
std::vector<std::string> Input(const std::string& param) const override {
const auto& var = desc_->inputs()->LookupByKey(param.c_str());
std::vector<std::string> Input(const char* param) const {
const auto& var = desc_->inputs()->LookupByKey(param);
std::vector<std::string> args_vec;
if (var->arguments()) {
args_vec.reserve(var->arguments()->size());
for (const auto& in : *var->arguments()) {
args_vec.push_back(in->str());
if (var && var->arguments()) {
args_vec.resize(var->arguments()->size());
for (size_t i = 0; i < var->arguments()->size(); ++i) {
args_vec[i] = (*var->arguments())[i]->str();
}
}
return args_vec;
}
std::vector<std::string> Input(const std::string& param) const override {
return Input(param.c_str());
}
std::vector<std::string> InputArgumentNames() const override {
const auto& vars = desc_->inputs();
std::vector<std::string> input_names_vec;
if (vars) {
input_names_vec.reserve(vars->size());
for (const auto& in : *vars) {
input_names_vec.push_back(in->parameter()->str());
input_names_vec.resize(vars->size());
for (size_t i = 0; i < vars->size(); ++i) {
input_names_vec[i] = (*vars)[i]->parameter()->str();
}
}
return input_names_vec;
}
std::vector<std::string> Output(const std::string& param) const override {
const auto& var = desc_->outputs()->LookupByKey(param.c_str());
std::vector<std::string> Output(const char* param) const {
const auto& var = desc_->outputs()->LookupByKey(param);
std::vector<std::string> args_vec;
if (var && var->arguments()) {
args_vec.reserve(var->arguments()->size());
for (const auto& out : *var->arguments()) {
args_vec.push_back(out->str());
args_vec.resize(var->arguments()->size());
for (size_t i = 0; i < var->arguments()->size(); ++i) {
args_vec[i] = (*var->arguments())[i]->str();
}
}
return args_vec;
}
std::vector<std::string> Output(const std::string& param) const override {
return Output(param.c_str());
}
std::vector<std::string> OutputArgumentNames() const override {
const auto& vars = desc_->outputs();
std::vector<std::string> output_names_vec;
if (vars) {
output_names_vec.reserve(vars->size());
for (const auto& out : *vars) {
output_names_vec.push_back(out->parameter()->str());
output_names_vec.resize(vars->size());
for (size_t i = 0; i < vars->size(); ++i) {
output_names_vec[i] = (*vars)[i]->parameter()->str();
}
}
return output_names_vec;
}
bool HasAttr(const char* name) const {
return desc_->attrs()->LookupByKey(name) != nullptr;
}
bool HasAttr(const std::string& name) const override {
return desc_->attrs()->LookupByKey(name.c_str()) != nullptr;
return HasAttr(name.c_str());
}
size_t AttrsSize() const { return desc_->attrs()->size(); }
......@@ -95,25 +106,23 @@ class OpDescView : public OpDescAPI {
return desc_->attrs()->Get(idx)->name()->str();
}
OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
const auto& attr = desc_->attrs()->LookupByKey(name.c_str());
OpDescAPI::AttrType GetAttrType(const char* name) const {
const auto& attr = desc_->attrs()->LookupByKey(name);
CHECK(attr) << "Can not find attr: " << name;
return ConvertAttrType(attr->type());
}
OpDescAPI::AttrType GetAttrType(size_t idx) const {
const auto& attr = desc_->attrs()->Get(idx);
CHECK(attr);
return ConvertAttrType(attr->type());
OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
return GetAttrType(name.c_str());
}
std::vector<std::string> AttrNames() const override {
const auto& attrs = desc_->attrs();
std::vector<std::string> attr_names_vec;
if (attrs) {
attr_names_vec.reserve(attrs->size());
for (const auto& attr : *attrs) {
attr_names_vec.push_back(attr->name()->str());
attr_names_vec.resize(attrs->size());
for (size_t i = 0; i < attrs->size(); ++i) {
attr_names_vec[i] = (*attrs)[i]->name()->str();
}
}
return attr_names_vec;
......@@ -121,10 +130,11 @@ class OpDescView : public OpDescAPI {
template <typename T>
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(
const std::string& name) const;
const char* name) const;
template <typename T>
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(size_t idx) const;
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(
const std::string& name) const;
private:
proto::OpDesc const* desc_;
......@@ -138,7 +148,7 @@ class OpDescView : public OpDescAPI {
// caused by different building options.
public:
OpDescView() { NotImplemented(); }
OpDescView() = default;
bool HasInput(const std::string& param) const {
return desc_->inputs()->LookupByKey(param.c_str()) != nullptr;
}
......
......@@ -42,9 +42,9 @@ class ParamDescView : public ParamDescReadAPI {
std::vector<int64_t> Dim() const override {
const auto& dims = tensor_desc_->dim();
std::vector<int64_t> dims_vec;
dims_vec.reserve(dims->size());
for (const auto& dim : *dims) {
dims_vec.push_back(dim);
dims_vec.resize(dims->size());
for (size_t i = 0; i < dims->size(); ++i) {
dims_vec[i] = dims->operator[](i);
}
return dims_vec;
}
......@@ -57,7 +57,7 @@ class ParamDescView : public ParamDescReadAPI {
size_t byte_size() const override { return tensor_desc_->data()->size(); }
ParamDescView() = delete;
ParamDescView() = default;
private:
proto::ParamDesc const* desc_;
......@@ -87,9 +87,9 @@ class CombinedParamsDescView : public CombinedParamsDescReadAPI {
void InitParams() {
desc_ = proto::GetCombinedParamsDesc(buf_.data());
size_t params_size = desc_->params()->size();
params_.reserve(params_size);
params_.resize(params_size);
for (size_t idx = 0; idx < params_size; ++idx) {
params_.push_back(ParamDescView(desc_->params()->Get(idx)));
params_[idx] = ParamDescView(desc_->params()->Get(idx));
}
}
......
......@@ -48,9 +48,9 @@ class ProgramDescView : public ProgramDescAPI {
void InitProgramDesc() {
desc_ = proto::GetProgramDesc(buf_.data());
blocks_.reserve(BlocksSize());
blocks_.resize(BlocksSize());
for (size_t idx = 0; idx < BlocksSize(); ++idx) {
blocks_.push_back(BlockDescView(desc_->blocks()->Get(idx)));
blocks_[idx] = BlockDescView(desc_->blocks()->Get(idx));
}
}
......
......@@ -42,9 +42,9 @@ class VarDescView : public VarDescAPI {
CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
const auto& dims = desc_->type()->lod_tensor()->tensor()->dims();
std::vector<int64_t> dims_vec;
dims_vec.reserve(dims->size());
for (const auto& dim : *dims) {
dims_vec.push_back(dim);
dims_vec.resize(dims->size());
for (size_t i = 0; i < dims->size(); ++i) {
dims_vec[i] = dims->operator[](i);
}
return dims_vec;
}
......@@ -66,7 +66,7 @@ class VarDescView : public VarDescAPI {
// caused by different building options.
public:
VarDescView() { NotImplemented(); }
VarDescView() = default;
void SetDataType(Type data_type) { NotImplemented(); }
void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); }
......
......@@ -127,9 +127,9 @@ class VectorView<std::string, Flatbuffers> {
operator std::vector<std::string>() const {
VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<std::string> tmp;
tmp.reserve(size());
tmp.resize(size());
for (size_t i = 0; i < size(); ++i) {
tmp.push_back(cvec_->operator[](i)->str());
tmp[i] = cvec_->operator[](i)->str();
}
return tmp;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册