提交 5f98d800 编写于 作者: W wangguibao

AsyncExecutor bugfix: Tensor change to LoDTensor

上级 f6a877bc
......@@ -33,11 +33,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) {
if (name == use_slots_[i]) {
if (use_slots_is_dense_[i]) {
feed_vec_[i] = MixTensor(var->GetMutable<Tensor>());
} else {
feed_vec_[i] = MixTensor(var->GetMutable<LoDTensor>());
}
feed_vec_[i] = var->GetMutable<LoDTensor>();
}
}
}
......@@ -350,34 +346,21 @@ void MultiSlotDataFeed::PutToFeedVec(
int total_instance = static_cast<int>(offset.back());
if (type[0] == 'f') { // float
const auto& feasign = ins_vec[i].GetFloatData();
if (feed_vec_[i].IsDense()) {
int size_in_each_batch = total_instance / batch_size_;
float* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<float>(
{batch_size_, size_in_each_batch}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
} else {
float* tensor_ptr = feed_vec_[i].GetLoDTensor()->mutable_data<float>(
float* tensor_ptr = feed_vec_[i]->mutable_data<float>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
LoD data_lod{offset};
feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
}
} else if (type[0] == 'u') { // uint64
// no uint64_t type in paddlepaddle
const auto& feasign = ins_vec[i].GetUint64Data();
if (feed_vec_[i].IsDense()) {
int size_in_each_batch = total_instance / batch_size_;
int64_t* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<int64_t>(
{batch_size_, size_in_each_batch}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
} else {
int64_t* tensor_ptr =
feed_vec_[i].GetLoDTensor()->mutable_data<int64_t>(
int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
{total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
LoD data_lod{offset};
feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
}
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
int dim = total_instance / batch_size_;
feed_vec_[i]->Resize({batch_size_, dim});
}
}
}
......
......@@ -30,35 +30,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
// Pack Tensor type and LoDTensor type into MixTensor type, in order
// to record either Tensor or LoDTensor information at the same time.
class MixTensor {
public:
MixTensor() {}
explicit MixTensor(LoDTensor* lodtensor) {
is_dense_ = false;
lodtensor_ = lodtensor;
}
explicit MixTensor(Tensor* tensor) {
is_dense_ = true;
tensor_ = tensor;
}
bool IsDense() { return is_dense_; }
LoDTensor* GetLoDTensor() {
PADDLE_ENFORCE(!is_dense_, "Let a dense var return a LoDTensor ptr.");
return lodtensor_;
}
Tensor* GetTensor() {
PADDLE_ENFORCE(is_dense_, "Let a sparse var return a Tensor ptr.");
return tensor_;
}
private:
bool is_dense_;
LoDTensor* lodtensor_;
Tensor* tensor_;
};
// DataFeed is the base virtual class for all ohther DataFeeds.
// It is used to read files and parse the data for subsequent trainer.
// Example:
......@@ -133,7 +104,7 @@ class DataFeed {
use_slots_index_; // -1: not used; >=0: the index of use_slots_
// The data read by DataFeed will be stored here
std::vector<MixTensor> feed_vec_;
std::vector<LoDTensor*> feed_vec_;
// the batch size defined by user
int default_batch_size_;
......
......@@ -152,21 +152,15 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
const auto& multi_slot_desc = data_feed_desc.multi_slot_desc();
std::map<std::string, const paddle::framework::LoDTensor*>
lodtensor_targets;
std::map<std::string, const paddle::framework::Tensor*> tensor_targets;
for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
const auto& slot = multi_slot_desc.slots(i);
if (slot.is_used()) {
const auto& name = slot.name();
readers[idx]->AddFeedVar(scope->Var(name), name);
if (slot.is_dense()) {
tensor_targets[name] =
&scope->FindVar(name)->Get<paddle::framework::Tensor>();
} else {
lodtensor_targets[name] =
&scope->FindVar(name)->Get<paddle::framework::LoDTensor>();
}
}
}
readers[idx]->Start();
while (readers[idx]->Next()) {
int index = 0;
......@@ -175,8 +169,9 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
if (!slot.is_used()) {
continue;
}
const paddle::framework::LoDTensor* tens =
lodtensor_targets[slot.name()];
if (slot.is_dense()) { // dense branch
const paddle::framework::Tensor* tens = tensor_targets[slot.name()];
if (slot.type() == "uint64") {
const int64_t* data = tens->data<int64_t>();
int batch_size = tens->dims()[0];
......@@ -202,8 +197,6 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
PADDLE_THROW("Error type in proto file.");
}
} else { // sparse branch
const paddle::framework::LoDTensor* tens =
lodtensor_targets[slot.name()];
if (slot.type() == "uint64") {
const int64_t* data = tens->data<int64_t>();
for (size_t i = 0; i < tens->NumElements(); ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册