未验证 提交 bdff2240 编写于 作者: 石晓伟 提交者: GitHub

speed up vector assignment, test=develop (#4098)

* speed up vector assignment, test=develop

* update op_desc with c-string, test=develop
上级 04b2c9fa
...@@ -83,9 +83,9 @@ class VectorView { ...@@ -83,9 +83,9 @@ class VectorView {
operator std::vector<T>() const { operator std::vector<T>() const {
VLOG(5) << "Copying elements out of VectorView will damage performance."; VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<T> tmp; std::vector<T> tmp;
tmp.reserve(size()); tmp.resize(size());
for (size_t i = 0; i < size(); ++i) { for (size_t i = 0; i < size(); ++i) {
tmp.push_back(cvec_->operator[](i)); tmp[i] = cvec_->operator[](i);
} }
return tmp; return tmp;
} }
......
...@@ -30,13 +30,13 @@ class BlockDescView : public BlockDescAPI { ...@@ -30,13 +30,13 @@ class BlockDescView : public BlockDescAPI {
public: public:
explicit BlockDescView(proto::BlockDesc const* desc) : desc_(desc) { explicit BlockDescView(proto::BlockDesc const* desc) : desc_(desc) {
CHECK(desc_); CHECK(desc_);
vars_.reserve(VarsSize()); vars_.resize(VarsSize());
ops_.reserve(OpsSize()); ops_.resize(OpsSize());
for (size_t idx = 0; idx < VarsSize(); ++idx) { for (size_t idx = 0; idx < VarsSize(); ++idx) {
vars_.push_back(VarDescView(desc_->vars()->Get(idx))); vars_[idx] = VarDescView(desc_->vars()->Get(idx));
} }
for (size_t idx = 0; idx < OpsSize(); ++idx) { for (size_t idx = 0; idx < OpsSize(); ++idx) {
ops_.push_back(OpDescView(desc_->ops()->Get(idx))); ops_[idx] = OpDescView(desc_->ops()->Get(idx));
} }
} }
...@@ -76,7 +76,7 @@ class BlockDescView : public BlockDescAPI { ...@@ -76,7 +76,7 @@ class BlockDescView : public BlockDescAPI {
return desc_->forward_block_idx(); return desc_->forward_block_idx();
} }
BlockDescView() { NotImplemented(); } BlockDescView() = default;
private: private:
proto::BlockDesc const* desc_; // not_own proto::BlockDesc const* desc_; // not_own
......
...@@ -19,8 +19,8 @@ namespace lite { ...@@ -19,8 +19,8 @@ namespace lite {
namespace fbs { namespace fbs {
template <> template <>
std::string OpDescView::GetAttr<std::string>(const std::string& name) const { std::string OpDescView::GetAttr<std::string>(const char* name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); const auto& it = desc_->attrs()->LookupByKey(name);
if (!it->s()) { if (!it->s()) {
return std::string(); return std::string();
} }
...@@ -28,56 +28,48 @@ std::string OpDescView::GetAttr<std::string>(const std::string& name) const { ...@@ -28,56 +28,48 @@ std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
} }
template <> template <>
std::string OpDescView::GetAttr<std::string>(size_t idx) const { std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
const auto& it = desc_->attrs()->Get(idx); return GetAttr<std::string>(name.c_str());
if (!it->s()) {
return std::string();
}
return it->s()->str();
} }
template <> template <>
lite::VectorView<std::string, Flatbuffers> lite::VectorView<std::string, Flatbuffers>
OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const { OpDescView::GetAttr<std::vector<std::string>>(const char* name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); const auto& it = desc_->attrs()->LookupByKey(name);
CHECK(it) << "Attr " << name << "does not exist."; CHECK(it) << "Attr " << name << "does not exist.";
return VectorView<std::string>(it->strings()); return VectorView<std::string>(it->strings());
} }
template <> template <>
VectorView<std::string, Flatbuffers> lite::VectorView<std::string, Flatbuffers>
OpDescView::GetAttr<std::vector<std::string>>(size_t idx) const { OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const {
const auto& it = desc_->attrs()->Get(idx); return GetAttr<std::vector<std::string>>(name.c_str());
CHECK(it) << "Attr " << idx << "does not exist.";
return VectorView<std::string>(it->strings());
} }
#define GET_ATTR_IMPL(T, fb_f__) \ #define GET_ATTR_IMPL(T, fb_f__) \
template <> \ template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \ typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
const std::string& name) const { \ const char* name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \ const auto& it = desc_->attrs()->LookupByKey(name); \
return it->fb_f__(); \ return it->fb_f__(); \
} \ } \
template <> \ template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \ typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
size_t idx) const { \ const std::string& name) const { \
const auto& it = desc_->attrs()->Get(idx); \ return GetAttr<T>(name.c_str()); \
return it->fb_f__(); \
} }
#define GET_ATTRS_IMPL(T, fb_f__) \ #define GET_ATTRS_IMPL(T, fb_f__) \
template <> \ template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \ typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
const std::string& name) const { \ const char* name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \ const auto& it = desc_->attrs()->LookupByKey(name); \
return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \ return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
} \ } \
template <> \ template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \ typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
size_t idx) const { \ const std::string& name) const { \
const auto& it = desc_->attrs()->Get(idx); \ return GetAttr<T>(name.c_str()); \
return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
} }
GET_ATTR_IMPL(int32_t, i); GET_ATTR_IMPL(int32_t, i);
......
...@@ -36,57 +36,68 @@ class OpDescView : public OpDescAPI { ...@@ -36,57 +36,68 @@ class OpDescView : public OpDescAPI {
std::string Type() const override { return desc_->type()->str(); } std::string Type() const override { return desc_->type()->str(); }
// Get the arguments of parameter called `param` std::vector<std::string> Input(const char* param) const {
std::vector<std::string> Input(const std::string& param) const override { const auto& var = desc_->inputs()->LookupByKey(param);
const auto& var = desc_->inputs()->LookupByKey(param.c_str());
std::vector<std::string> args_vec; std::vector<std::string> args_vec;
if (var->arguments()) { if (var && var->arguments()) {
args_vec.reserve(var->arguments()->size()); args_vec.resize(var->arguments()->size());
for (const auto& in : *var->arguments()) { for (size_t i = 0; i < var->arguments()->size(); ++i) {
args_vec.push_back(in->str()); args_vec[i] = (*var->arguments())[i]->str();
} }
} }
return args_vec; return args_vec;
} }
std::vector<std::string> Input(const std::string& param) const override {
return Input(param.c_str());
}
std::vector<std::string> InputArgumentNames() const override { std::vector<std::string> InputArgumentNames() const override {
const auto& vars = desc_->inputs(); const auto& vars = desc_->inputs();
std::vector<std::string> input_names_vec; std::vector<std::string> input_names_vec;
if (vars) { if (vars) {
input_names_vec.reserve(vars->size()); input_names_vec.resize(vars->size());
for (const auto& in : *vars) { for (size_t i = 0; i < vars->size(); ++i) {
input_names_vec.push_back(in->parameter()->str()); input_names_vec[i] = (*vars)[i]->parameter()->str();
} }
} }
return input_names_vec; return input_names_vec;
} }
std::vector<std::string> Output(const std::string& param) const override { std::vector<std::string> Output(const char* param) const {
const auto& var = desc_->outputs()->LookupByKey(param.c_str()); const auto& var = desc_->outputs()->LookupByKey(param);
std::vector<std::string> args_vec; std::vector<std::string> args_vec;
if (var && var->arguments()) { if (var && var->arguments()) {
args_vec.reserve(var->arguments()->size()); args_vec.resize(var->arguments()->size());
for (const auto& out : *var->arguments()) { for (size_t i = 0; i < var->arguments()->size(); ++i) {
args_vec.push_back(out->str()); args_vec[i] = (*var->arguments())[i]->str();
} }
} }
return args_vec; return args_vec;
} }
std::vector<std::string> Output(const std::string& param) const override {
return Output(param.c_str());
}
std::vector<std::string> OutputArgumentNames() const override { std::vector<std::string> OutputArgumentNames() const override {
const auto& vars = desc_->outputs(); const auto& vars = desc_->outputs();
std::vector<std::string> output_names_vec; std::vector<std::string> output_names_vec;
if (vars) { if (vars) {
output_names_vec.reserve(vars->size()); output_names_vec.resize(vars->size());
for (const auto& out : *vars) { for (size_t i = 0; i < vars->size(); ++i) {
output_names_vec.push_back(out->parameter()->str()); output_names_vec[i] = (*vars)[i]->parameter()->str();
} }
} }
return output_names_vec; return output_names_vec;
} }
bool HasAttr(const char* name) const {
return desc_->attrs()->LookupByKey(name) != nullptr;
}
bool HasAttr(const std::string& name) const override { bool HasAttr(const std::string& name) const override {
return desc_->attrs()->LookupByKey(name.c_str()) != nullptr; return HasAttr(name.c_str());
} }
size_t AttrsSize() const { return desc_->attrs()->size(); } size_t AttrsSize() const { return desc_->attrs()->size(); }
...@@ -95,25 +106,23 @@ class OpDescView : public OpDescAPI { ...@@ -95,25 +106,23 @@ class OpDescView : public OpDescAPI {
return desc_->attrs()->Get(idx)->name()->str(); return desc_->attrs()->Get(idx)->name()->str();
} }
OpDescAPI::AttrType GetAttrType(const std::string& name) const override { OpDescAPI::AttrType GetAttrType(const char* name) const {
const auto& attr = desc_->attrs()->LookupByKey(name.c_str()); const auto& attr = desc_->attrs()->LookupByKey(name);
CHECK(attr) << "Can not find attr: " << name; CHECK(attr) << "Can not find attr: " << name;
return ConvertAttrType(attr->type()); return ConvertAttrType(attr->type());
} }
OpDescAPI::AttrType GetAttrType(size_t idx) const { OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
const auto& attr = desc_->attrs()->Get(idx); return GetAttrType(name.c_str());
CHECK(attr);
return ConvertAttrType(attr->type());
} }
std::vector<std::string> AttrNames() const override { std::vector<std::string> AttrNames() const override {
const auto& attrs = desc_->attrs(); const auto& attrs = desc_->attrs();
std::vector<std::string> attr_names_vec; std::vector<std::string> attr_names_vec;
if (attrs) { if (attrs) {
attr_names_vec.reserve(attrs->size()); attr_names_vec.resize(attrs->size());
for (const auto& attr : *attrs) { for (size_t i = 0; i < attrs->size(); ++i) {
attr_names_vec.push_back(attr->name()->str()); attr_names_vec[i] = (*attrs)[i]->name()->str();
} }
} }
return attr_names_vec; return attr_names_vec;
...@@ -121,10 +130,11 @@ class OpDescView : public OpDescAPI { ...@@ -121,10 +130,11 @@ class OpDescView : public OpDescAPI {
template <typename T> template <typename T>
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr( typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(
const std::string& name) const; const char* name) const;
template <typename T> template <typename T>
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(size_t idx) const; typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(
const std::string& name) const;
private: private:
proto::OpDesc const* desc_; proto::OpDesc const* desc_;
...@@ -138,7 +148,7 @@ class OpDescView : public OpDescAPI { ...@@ -138,7 +148,7 @@ class OpDescView : public OpDescAPI {
// caused by different building options. // caused by different building options.
public: public:
OpDescView() { NotImplemented(); } OpDescView() = default;
bool HasInput(const std::string& param) const { bool HasInput(const std::string& param) const {
return desc_->inputs()->LookupByKey(param.c_str()) != nullptr; return desc_->inputs()->LookupByKey(param.c_str()) != nullptr;
} }
......
...@@ -42,9 +42,9 @@ class ParamDescView : public ParamDescReadAPI { ...@@ -42,9 +42,9 @@ class ParamDescView : public ParamDescReadAPI {
std::vector<int64_t> Dim() const override { std::vector<int64_t> Dim() const override {
const auto& dims = tensor_desc_->dim(); const auto& dims = tensor_desc_->dim();
std::vector<int64_t> dims_vec; std::vector<int64_t> dims_vec;
dims_vec.reserve(dims->size()); dims_vec.resize(dims->size());
for (const auto& dim : *dims) { for (size_t i = 0; i < dims->size(); ++i) {
dims_vec.push_back(dim); dims_vec[i] = dims->operator[](i);
} }
return dims_vec; return dims_vec;
} }
...@@ -57,7 +57,7 @@ class ParamDescView : public ParamDescReadAPI { ...@@ -57,7 +57,7 @@ class ParamDescView : public ParamDescReadAPI {
size_t byte_size() const override { return tensor_desc_->data()->size(); } size_t byte_size() const override { return tensor_desc_->data()->size(); }
ParamDescView() = delete; ParamDescView() = default;
private: private:
proto::ParamDesc const* desc_; proto::ParamDesc const* desc_;
...@@ -87,9 +87,9 @@ class CombinedParamsDescView : public CombinedParamsDescReadAPI { ...@@ -87,9 +87,9 @@ class CombinedParamsDescView : public CombinedParamsDescReadAPI {
void InitParams() { void InitParams() {
desc_ = proto::GetCombinedParamsDesc(buf_.data()); desc_ = proto::GetCombinedParamsDesc(buf_.data());
size_t params_size = desc_->params()->size(); size_t params_size = desc_->params()->size();
params_.reserve(params_size); params_.resize(params_size);
for (size_t idx = 0; idx < params_size; ++idx) { for (size_t idx = 0; idx < params_size; ++idx) {
params_.push_back(ParamDescView(desc_->params()->Get(idx))); params_[idx] = ParamDescView(desc_->params()->Get(idx));
} }
} }
......
...@@ -48,9 +48,9 @@ class ProgramDescView : public ProgramDescAPI { ...@@ -48,9 +48,9 @@ class ProgramDescView : public ProgramDescAPI {
void InitProgramDesc() { void InitProgramDesc() {
desc_ = proto::GetProgramDesc(buf_.data()); desc_ = proto::GetProgramDesc(buf_.data());
blocks_.reserve(BlocksSize()); blocks_.resize(BlocksSize());
for (size_t idx = 0; idx < BlocksSize(); ++idx) { for (size_t idx = 0; idx < BlocksSize(); ++idx) {
blocks_.push_back(BlockDescView(desc_->blocks()->Get(idx))); blocks_[idx] = BlockDescView(desc_->blocks()->Get(idx));
} }
} }
......
...@@ -42,9 +42,9 @@ class VarDescView : public VarDescAPI { ...@@ -42,9 +42,9 @@ class VarDescView : public VarDescAPI {
CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR); CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
const auto& dims = desc_->type()->lod_tensor()->tensor()->dims(); const auto& dims = desc_->type()->lod_tensor()->tensor()->dims();
std::vector<int64_t> dims_vec; std::vector<int64_t> dims_vec;
dims_vec.reserve(dims->size()); dims_vec.resize(dims->size());
for (const auto& dim : *dims) { for (size_t i = 0; i < dims->size(); ++i) {
dims_vec.push_back(dim); dims_vec[i] = dims->operator[](i);
} }
return dims_vec; return dims_vec;
} }
...@@ -66,7 +66,7 @@ class VarDescView : public VarDescAPI { ...@@ -66,7 +66,7 @@ class VarDescView : public VarDescAPI {
// caused by different building options. // caused by different building options.
public: public:
VarDescView() { NotImplemented(); } VarDescView() = default;
void SetDataType(Type data_type) { NotImplemented(); } void SetDataType(Type data_type) { NotImplemented(); }
void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); } void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); }
......
...@@ -127,9 +127,9 @@ class VectorView<std::string, Flatbuffers> { ...@@ -127,9 +127,9 @@ class VectorView<std::string, Flatbuffers> {
operator std::vector<std::string>() const { operator std::vector<std::string>() const {
VLOG(5) << "Copying elements out of VectorView will damage performance."; VLOG(5) << "Copying elements out of VectorView will damage performance.";
std::vector<std::string> tmp; std::vector<std::string> tmp;
tmp.reserve(size()); tmp.resize(size());
for (size_t i = 0; i < size(); ++i) { for (size_t i = 0; i < size(); ++i) {
tmp.push_back(cvec_->operator[](i)->str()); tmp[i] = cvec_->operator[](i)->str();
} }
return tmp; return tmp;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册