提交 b7f6ed3d 编写于 作者: L Liu Yiqun

Optimize the InferShape of fc, gru and lookup_table.

上级 14e04200
......@@ -47,9 +47,11 @@ value_type DDimLite::count(int start, int end) const {
// Returns a new DDimLite holding the dims in the half-open range
// [start, end) of this one. Bounds are clamped to [0, data_.size()];
// an empty or inverted range yields a zero-rank dim.
DDimLite DDimLite::Slice(int start, int end) const {
  start = std::max(start, 0);
  end = std::min(end, static_cast<int>(data_.size()));
  // Guard against an inverted range: (end - start) would otherwise wrap
  // to a huge value when converted to the constructor's size_t parameter.
  if (end < start) end = start;
  // Build the result directly at its final size instead of going through
  // a fixed-size temporary buffer (which capped the rank at kMaxDimLength).
  DDimLite new_dim(end - start);
  for (int i = start; i < end; ++i) {
    new_dim[i - start] = data_[i];
  }
  return new_dim;
}
std::string DDimLite::repr() const {
......
......@@ -73,10 +73,7 @@ class DDimLite {
// Default: a rank-0 (empty) dimension.
DDimLite() = default;
// Copies the dimension values out of a vector (delegates to ConstructFrom).
explicit DDimLite(const std::vector<value_type> &x) { ConstructFrom(x); }
// Copies `size` dimension values from a raw array.
explicit DDimLite(const value_type *arr, size_t size) {
  data_.resize(size);
  memcpy(data_.data(), arr, data_.size() * sizeof(value_type));
}
// Allocates `size` dims (value-initialized by resize); the caller is
// expected to fill them in, e.g. DDimLite::Slice.
explicit DDimLite(size_t size) { data_.resize(size); }
void ConstructFrom(const std::vector<value_type> &x) {
data_.resize(x.size());
......
......@@ -83,12 +83,12 @@ class FCFunctor {
// NOTE: here need to mutable_data for temporary Tensor X1 and Y1,
// the overhead is unmeasured.
lite::Tensor X1;
X1.Resize({M * KK});
Tensor X1;
X1.Resize(std::vector<int64_t>({M * KK}));
T* X1_data = X1.mutable_data<T>();
lite::Tensor Y1;
Y1.Resize({M * (N + 4)});
Tensor Y1;
Y1.Resize(std::vector<int64_t>({M * NN}));
Y1_data = Y1.mutable_data<T>();
auto parallel_memcpy_x = [&](int64_t begin, int64_t end) {
......@@ -115,7 +115,7 @@ class FCFunctor {
if (!B) {
auto parallel_memcpy_y = [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; i++) {
memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(T));
memcpy(Y + i * N, Y1_data + i * NN, N * sizeof(T));
}
};
lite::x86::RunParallelFor(0, M, parallel_memcpy_y);
......
......@@ -49,16 +49,17 @@ bool FcOpLite::CheckShape() const {
}
bool FcOpLite::InferShape() const {
const auto input_dims = param_.input->dims();
const auto w_dims = param_.w->dims();
const auto &input_dims = param_.input->dims();
const auto &w_dims = param_.w->dims();
int in_num_col_dims = param_.in_num_col_dims;
// Set output dims
std::vector<int64_t> output_dims(param_.in_num_col_dims + 1, 0);
for (int i = 0; i < param_.in_num_col_dims; ++i) {
DDim output_dims(in_num_col_dims + 1);
for (int i = 0; i < in_num_col_dims; ++i) {
output_dims[i] = input_dims[i];
}
output_dims.back() = w_dims[1];
param_.output->Resize(lite::DDim(output_dims));
output_dims[in_num_col_dims] = w_dims[1];
param_.output->Resize(output_dims);
// share LoD
// param_.output->set_lod(param_.input->lod());
......
......@@ -52,15 +52,17 @@ bool GRUOpLite::CheckShape() const {
}
bool GRUOpLite::InferShape() const {
auto input_dims = param_.input->dims();
auto weight_dims = param_.weight->dims();
auto &input_dims = param_.input->dims();
auto &weight_dims = param_.weight->dims();
int frame_size = weight_dims[0];
auto batch_size = input_dims[0];
param_.batch_gate->Resize(input_dims);
param_.batch_reset_hidden_prev->Resize(lite::DDim({batch_size, frame_size}));
param_.batch_hidden->Resize(lite::DDim({batch_size, frame_size}));
param_.hidden->Resize(lite::DDim({batch_size, frame_size}));
auto out_dims = DDim({batch_size, frame_size});
param_.batch_reset_hidden_prev->Resize(out_dims);
param_.batch_hidden->Resize(out_dims);
param_.hidden->Resize(out_dims);
*(param_.hidden->mutable_lod()) = param_.input->lod();
return true;
......
......@@ -37,17 +37,14 @@ bool LookupTableOpLite::CheckShape() const {
}
// Infers the output shape of lookup_table: Out takes the Ids shape with
// its last dimension replaced by the embedding width W.dims()[1], and
// Out inherits the LoD of Ids.
bool LookupTableOpLite::InferShape() const {
  // Bind by const-ish reference to avoid copying the DDim objects.
  auto &table_dims = param_.W->dims();
  auto &ids_dims = param_.Ids->dims();

  int ids_rank = ids_dims.size();
  // Copy the Ids shape once and overwrite the trailing dim in place,
  // instead of rebuilding a std::vector<int64_t> element by element.
  auto out_dims = ids_dims;
  out_dims[ids_rank - 1] = table_dims[1];
  param_.Out->Resize(out_dims);
  param_.Out->set_lod(param_.Ids->lod());
  return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册