Created by: gfwm2013
使fetch能够支持LoDTensorArray
26 27 28 class FetchVar : public boost::variant<LoDTensor, LoDTensorArray> { 29 private: 30 using FetchVarBase = boost::variant<LoDTensor, LoDTensorArray>; 31 32 public: 33 FetchVar() = default; 34 FetchVar(const LoDTensor &lod_tensor) : FetchVarBase(lod_tensor) {} // NOLINT 35 FetchVar(const LoDTensorArray &lod_tensor_array) // NOLINT 36 : FetchVarBase(lod_tensor_array) {} // NOLINT 37 }; 38 39 // using FetchVar = boost::variant<LoDTensor, LoDTensorArray>; 40 using FetchVarList = std::vector<FetchVar>; 41 42 struct DataIsLoDTensor : public boost::static_visitor<bool> { 901 911 use_program_cache=use_program_cache) 902 912 903 913 program._compile(scope, self.place) 914 915 if _has_lod_tensor_array(program._program, fetch_list): 916 raise ValueError("Currently, The type of item in fetch_list should be" \ 82 fetch_list->resize(col + 1); 83 } 84 if (fetch_var->IsType<framework::LoDTensor>()) { 85 auto &src_item = fetch_var->Get<framework::LoDTensor>(); 86 auto *dst_item = 87 &(boost::get<framework::LoDTensor>(fetch_list->at(col))); 88 TransDataLayout(src_item, fetch_var_name, dst_item); 89 } else if (fetch_var->IsType<framework::LoDTensorArray>()) { 90 auto &src_item = fetch_var->Get<framework::LoDTensorArray>(); 91 auto &item = fetch_list->at(col); 92 std::vector<framework::LoDTensor> temp; 93 temp.resize(src_item.size()); 94 item = temp; 95 framework::LoDTensorArray *dst_item = 96 &(boost::get<framework::LoDTensorArray>(item)); 97 for (size_t i = 0; i < src_item.size(); i++) { 67 fetch_var_name == framework::GradVarName("Filter") 68 ? framework::DataLayout::kNCHW 69 : paddle::platform::get_cur_paddle_data_layout(), 70 src_item, &out, platform::CPUPlace()); 71 TensorCopySync(out, platform::CPUPlace(), &dst_item); 72 } else { 73 TensorCopySync(src_item, platform::CPUPlace(), &dst_item); 71 if (out_var->IsType<framework::FeedFetchList>()) { 72 auto *fetch_list = out_var->GetMutable<framework::FeedFetchList>(); 73 if (col >= fetch_list->size()) { 74 fetch_list->resize(col + 1); 74 75 } 75 #else 76 TensorCopySync(src_item, platform::CPUPlace(), &dst_item); 77 #endif 76 auto *dst_item = &(fetch_list->at(col)); Created by: zhhsplendid
Using auto when the type is obvious, otherwise you should specify the type in C++
For example,
auto &src_item = fetch_var->Get<framework::FeedFetchType>()
is very clear thatsrc_item
is a reference offramework::FeedFetchType
, so that's good.But not
auto *dst_item = &(fetch_list->at(col));
78 79 } else { 79 // Not copy, if the src tensor is empty. 80 dst_item.clear(); 81 dst_item.Resize({0}); 80 auto *fetch_list = out_var->GetMutable<framework::FetchVarList>(); 81 if (col >= fetch_list->size()) { 82 fetch_list->resize(col + 1); 83 } 84 if (fetch_var->IsType<framework::LoDTensor>()) { 85 auto &src_item = fetch_var->Get<framework::LoDTensor>(); 86 auto *dst_item = 87 &(boost::get<framework::LoDTensor>(fetch_list->at(col))); 88 TransDataLayout(src_item, fetch_var_name, dst_item); 89 } else if (fetch_var->IsType<framework::LoDTensorArray>()) { 90 auto &src_item = fetch_var->Get<framework::LoDTensorArray>(); 91 auto &item = fetch_list->at(col); 81 dst_item.Resize({0}); 80 auto *fetch_list = out_var->GetMutable<framework::FetchVarList>(); 81 if (col >= fetch_list->size()) { 82 fetch_list->resize(col + 1); 83 } 84 if (fetch_var->IsType<framework::LoDTensor>()) { 85 auto &src_item = fetch_var->Get<framework::LoDTensor>(); 86 auto *dst_item = 87 &(boost::get<framework::LoDTensor>(fetch_list->at(col))); 88 TransDataLayout(src_item, fetch_var_name, dst_item); 89 } else if (fetch_var->IsType<framework::LoDTensorArray>()) { 90 auto &src_item = fetch_var->Get<framework::LoDTensorArray>(); 91 auto &item = fetch_list->at(col); 92 std::vector<framework::LoDTensor> temp; 93 temp.resize(src_item.size()); 94 item = temp;