提交 b7f6ed3d 编写于 作者：Liu Yiqun

Optimize the InferShape of fc, gru and lookup_table.

上级 14e04200
...@@ -47,9 +47,11 @@ value_type DDimLite::count(int start, int end) const { ...@@ -47,9 +47,11 @@ value_type DDimLite::count(int start, int end) const {
// Returns a new DDimLite holding the dimensions [start, end) of this one.
// Both bounds are clamped to the valid range [0, size()]; an empty
// DDimLite is returned when the clamped range is empty or inverted.
DDimLite DDimLite::Slice(int start, int end) const {
  start = std::max(start, 0);
  end = std::min(end, static_cast<int>(data_.size()));
  // Guard against an inverted range: otherwise end - start is negative
  // and wraps to a huge value when converted to size_t in the ctor.
  end = std::max(end, start);
  DDimLite new_dim(end - start);
  for (int i = start; i < end; ++i) {
    new_dim[i - start] = data_[i];
  }
  return new_dim;
}
std::string DDimLite::repr() const { std::string DDimLite::repr() const {
......
...@@ -73,10 +73,7 @@ class DDimLite { ...@@ -73,10 +73,7 @@ class DDimLite {
DDimLite() = default; DDimLite() = default;
explicit DDimLite(const std::vector<value_type> &x) { ConstructFrom(x); } explicit DDimLite(const std::vector<value_type> &x) { ConstructFrom(x); }
explicit DDimLite(const value_type *arr, size_t size) { explicit DDimLite(size_t size) { data_.resize(size); }
data_.resize(size);
memcpy(data_.data(), arr, data_.size() * sizeof(value_type));
}
void ConstructFrom(const std::vector<value_type> &x) { void ConstructFrom(const std::vector<value_type> &x) {
data_.resize(x.size()); data_.resize(x.size());
......
...@@ -83,12 +83,12 @@ class FCFunctor { ...@@ -83,12 +83,12 @@ class FCFunctor {
// NOTE: here need to mutable_data for temporary Tensor X1 and Y1, // NOTE: here need to mutable_data for temporary Tensor X1 and Y1,
// the overhead is unmeasured. // the overhead is unmeasured.
lite::Tensor X1; Tensor X1;
X1.Resize({M * KK}); X1.Resize(std::vector<int64_t>({M * KK}));
T* X1_data = X1.mutable_data<T>(); T* X1_data = X1.mutable_data<T>();
lite::Tensor Y1; Tensor Y1;
Y1.Resize({M * (N + 4)}); Y1.Resize(std::vector<int64_t>({M * NN}));
Y1_data = Y1.mutable_data<T>(); Y1_data = Y1.mutable_data<T>();
auto parallel_memcpy_x = [&](int64_t begin, int64_t end) { auto parallel_memcpy_x = [&](int64_t begin, int64_t end) {
...@@ -115,7 +115,7 @@ class FCFunctor { ...@@ -115,7 +115,7 @@ class FCFunctor {
if (!B) { if (!B) {
auto parallel_memcpy_y = [&](int64_t begin, int64_t end) { auto parallel_memcpy_y = [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; i++) { for (int64_t i = begin; i < end; i++) {
memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(T)); memcpy(Y + i * N, Y1_data + i * NN, N * sizeof(T));
} }
}; };
lite::x86::RunParallelFor(0, M, parallel_memcpy_y); lite::x86::RunParallelFor(0, M, parallel_memcpy_y);
......
...@@ -49,16 +49,17 @@ bool FcOpLite::CheckShape() const { ...@@ -49,16 +49,17 @@ bool FcOpLite::CheckShape() const {
} }
bool FcOpLite::InferShape() const { bool FcOpLite::InferShape() const {
const auto input_dims = param_.input->dims(); const auto &input_dims = param_.input->dims();
const auto w_dims = param_.w->dims(); const auto &w_dims = param_.w->dims();
int in_num_col_dims = param_.in_num_col_dims;
// Set output dims // Set output dims
std::vector<int64_t> output_dims(param_.in_num_col_dims + 1, 0); DDim output_dims(in_num_col_dims + 1);
for (int i = 0; i < param_.in_num_col_dims; ++i) { for (int i = 0; i < in_num_col_dims; ++i) {
output_dims[i] = input_dims[i]; output_dims[i] = input_dims[i];
} }
output_dims.back() = w_dims[1]; output_dims[in_num_col_dims] = w_dims[1];
param_.output->Resize(lite::DDim(output_dims)); param_.output->Resize(output_dims);
// share LoD // share LoD
// param_.output->set_lod(param_.input->lod()); // param_.output->set_lod(param_.input->lod());
......
...@@ -52,15 +52,17 @@ bool GRUOpLite::CheckShape() const { ...@@ -52,15 +52,17 @@ bool GRUOpLite::CheckShape() const {
} }
bool GRUOpLite::InferShape() const { bool GRUOpLite::InferShape() const {
auto input_dims = param_.input->dims(); auto &input_dims = param_.input->dims();
auto weight_dims = param_.weight->dims(); auto &weight_dims = param_.weight->dims();
int frame_size = weight_dims[0]; int frame_size = weight_dims[0];
auto batch_size = input_dims[0]; auto batch_size = input_dims[0];
param_.batch_gate->Resize(input_dims); param_.batch_gate->Resize(input_dims);
param_.batch_reset_hidden_prev->Resize(lite::DDim({batch_size, frame_size}));
param_.batch_hidden->Resize(lite::DDim({batch_size, frame_size})); auto out_dims = DDim({batch_size, frame_size});
param_.hidden->Resize(lite::DDim({batch_size, frame_size})); param_.batch_reset_hidden_prev->Resize(out_dims);
param_.batch_hidden->Resize(out_dims);
param_.hidden->Resize(out_dims);
*(param_.hidden->mutable_lod()) = param_.input->lod(); *(param_.hidden->mutable_lod()) = param_.input->lod();
return true; return true;
......
...@@ -37,17 +37,14 @@ bool LookupTableOpLite::CheckShape() const { ...@@ -37,17 +37,14 @@ bool LookupTableOpLite::CheckShape() const {
} }
bool LookupTableOpLite::InferShape() const { bool LookupTableOpLite::InferShape() const {
auto table_dims = param_.W->dims(); auto &table_dims = param_.W->dims();
auto ids_dims = param_.Ids->dims(); auto &ids_dims = param_.Ids->dims();
auto out_dims = ids_dims;
int ids_rank = ids_dims.size(); int ids_rank = ids_dims.size();
out_dims[ids_rank - 1] = table_dims[1];
std::vector<int64_t> out_dims; param_.Out->Resize(out_dims);
for (int i = 0; i < ids_rank - 1; ++i) {
out_dims.push_back(ids_dims[i]);
}
out_dims.push_back(table_dims[1]);
param_.Out->Resize(lite::DDim{out_dims});
param_.Out->set_lod(param_.Ids->lod()); param_.Out->set_lod(param_.Ids->lod());
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册