提交 118c1441 编写于 作者: L Liu Yiqun

Save the workspace.

上级 05df5200
...@@ -137,6 +137,7 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) { ...@@ -137,6 +137,7 @@ void RuntimeProgram::UpdateVarsOfProgram(cpp::ProgramDesc* desc) {
void RuntimeProgram::Run() { void RuntimeProgram::Run() {
for (auto& inst : instructions_) { for (auto& inst : instructions_) {
// LOG(INFO) << "Run op: " << inst.op()->op_info()->Type();
if (inst.is_feed_fetch_op()) continue; if (inst.is_feed_fetch_op()) continue;
inst.Run(); inst.Run();
#ifdef LITE_WITH_PROFILE #ifdef LITE_WITH_PROFILE
......
...@@ -42,26 +42,47 @@ template <typename ValueType, size_t DimLength> ...@@ -42,26 +42,47 @@ template <typename ValueType, size_t DimLength>
class DimVector { class DimVector {
public: public:
DimVector() { DimVector() {
memset(arr_, 0, DimLength * sizeof(ValueType)); // data_ = new ValueType[DimLength];
// data_ = static_cast<ValueType *>(malloc(DimLength *
// sizeof(ValueType)));
data_.resize(DimLength);
// memset(data_, 0, DimLength * sizeof(ValueType));
size_ = 0; size_ = 0;
} }
~DimVector() {
// if (data_) {
// delete[] data_;
// free(data_);
// }
}
size_t size() const { return size_; } size_t size() const { return size_; }
void resize(size_t new_size) { void resize(size_t new_size) {
CHECK_LE(new_size, DimLength) CHECK_LE(new_size, DimLength)
<< "Expected the number of dimentations <= " << DimLength << "Expected the number of dimentations <= " << DimLength
<< ", received " << new_size << "."; << ", received " << new_size << ".";
// if (new_size != size_) {
// delete[] data_;
// data_ = nullptr;
// }
size_ = new_size; size_ = new_size;
} }
ValueType *data() { return arr_; } ValueType *mutable_data() {
const ValueType *data() const { return arr_; } // if (!data_ && size_ > 0U) {
// data_ = new ValueType[size_];
// }
return data_.data();
}
const ValueType *data() const { return data_.data(); }
ValueType operator[](int offset) const { return arr_[offset]; } ValueType operator[](int offset) const { return data_[offset]; }
ValueType &operator[](int offset) { return arr_[offset]; } ValueType &operator[](int offset) { return data_[offset]; }
private: private:
ValueType arr_[DimLength]; // ValueType data_[DimLength];
// ValueType* data_{nullptr};
std::vector<ValueType> data_;
size_t size_{0}; size_t size_{0};
}; };
...@@ -78,7 +99,7 @@ class DDimLite { ...@@ -78,7 +99,7 @@ class DDimLite {
void ConstructFrom(const std::vector<value_type> &x) { void ConstructFrom(const std::vector<value_type> &x) {
data_.resize(x.size()); data_.resize(x.size());
memcpy(data_.data(), x.data(), x.size() * sizeof(value_type)); memcpy(data_.mutable_data(), x.data(), x.size() * sizeof(value_type));
} }
value_type operator[](int offset) const { return data_[offset]; } value_type operator[](int offset) const { return data_[offset]; }
...@@ -127,7 +148,9 @@ class DDimLite { ...@@ -127,7 +148,9 @@ class DDimLite {
DDimLite &operator=(const DDimLite &a) { DDimLite &operator=(const DDimLite &a) {
this->data_.resize(a.size()); this->data_.resize(a.size());
memcpy(this->data_.data(), a.data_.data(), a.size() * sizeof(value_type)); memcpy(this->data_.mutable_data(),
a.data_.data(),
a.size() * sizeof(value_type));
return *this; return *this;
} }
...@@ -176,10 +199,19 @@ class TensorLite { ...@@ -176,10 +199,19 @@ class TensorLite {
offset_); offset_);
} }
void Resize(const DDimLite &ddim) { dims_ = ddim; } void Resize(const DDimLite &ddim) {
void Resize(const std::vector<int64_t> &x) { dims_ = DDimLite(x); } dims_ = ddim;
// LOG(INFO) << "Set dims: " << dims_ << " for tensor " << this;
}
void Resize(const std::vector<int64_t> &x) {
dims_ = DDimLite(x);
// LOG(INFO) << "Set dims: " << dims_ << " for tensor " << this;
}
const DDimLite &dims() const { return dims_; } const DDimLite &dims() const {
// LOG(INFO) << "Get dims: " << dims_ << " for tensor " << this;
return dims_;
}
int64_t numel() const { return dims_.production(); } int64_t numel() const { return dims_.production(); }
const LoD &lod() const { return lod_; } const LoD &lod() const { return lod_; }
...@@ -216,6 +248,9 @@ class TensorLite { ...@@ -216,6 +248,9 @@ class TensorLite {
} }
memory_size_ = dims_.production() * sizeof(T); memory_size_ = dims_.production() * sizeof(T);
buffer_->ResetLazy(target_, memory_size_); buffer_->ResetLazy(target_, memory_size_);
// char *ptr = static_cast<char *>(buffer_->data()) + offset_;
// LOG(INFO) << "mutable_data for tensor " << this << ": " << ptr << ",
// memory_size: " << memory_size_;
return reinterpret_cast<R *>(static_cast<char *>(buffer_->data()) + return reinterpret_cast<R *>(static_cast<char *>(buffer_->data()) +
offset_); offset_);
} }
......
...@@ -34,9 +34,11 @@ class LookupTableCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -34,9 +34,11 @@ class LookupTableCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto *output_t = param.Out; auto *output_t = param.Out;
int64_t padding_idx = param.padding_idx; int64_t padding_idx = param.padding_idx;
auto *ids = ids_t->data<int64_t>(); auto *ids = ids_t->data<int64_t>();
// LOG(INFO) << "ids->dims: " << ids_t->dims();
int64_t ids_numel = ids_t->dims().production(); int64_t ids_numel = ids_t->dims().production();
auto *table_t = param.W; auto *table_t = param.W;
// LOG(INFO) << "W->dims: " << table_t->dims();
int64_t row_number = table_t->dims()[0]; int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1]; int64_t row_width = table_t->dims()[1];
......
...@@ -36,9 +36,9 @@ class SequenceReshapeCompute ...@@ -36,9 +36,9 @@ class SequenceReshapeCompute
auto* out = param.output; auto* out = param.output;
int out_width = param.new_dim; int out_width = param.new_dim;
auto in_dims = in->dims(); const auto& in_dims = in->dims();
// LOG(INFO) << "in_dims: " << in_dims;
int64_t in_width = in_dims[1]; int64_t in_width = in_dims[1];
// LOG(INFO)<<"sequence_reshape in tensor:"<<*in;
auto& in_lod = in->lod(); auto& in_lod = in->lod();
CHECK_EQ(in_lod.size(), 1UL); CHECK_EQ(in_lod.size(), 1UL);
......
...@@ -38,12 +38,15 @@ bool LookupTableOpLite::CheckShape() const { ...@@ -38,12 +38,15 @@ bool LookupTableOpLite::CheckShape() const {
bool LookupTableOpLite::InferShape() const { bool LookupTableOpLite::InferShape() const {
const auto &table_dims = param_.W->dims(); const auto &table_dims = param_.W->dims();
// LOG(INFO) << "table_dims: " << table_dims;
const auto &ids_dims = param_.Ids->dims(); const auto &ids_dims = param_.Ids->dims();
// LOG(INFO) << "ids_dims: " << ids_dims;
auto out_dims = ids_dims; auto out_dims = ids_dims;
int ids_rank = ids_dims.size(); int ids_rank = ids_dims.size();
out_dims[ids_rank - 1] = table_dims[1]; out_dims[ids_rank - 1] = table_dims[1];
// LOG(INFO) << "out_dims: " << out_dims;
param_.Out->Resize(out_dims); param_.Out->Resize(out_dims);
param_.Out->set_lod(param_.Ids->lod()); param_.Out->set_lod(param_.Ids->lod());
return true; return true;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册