Commit 87e30f5c authored by Yiqun Liu, committed by GitHub

Optimize the InferShape of several operators. (#2839)

* Optimize the InferShape of several operators.
test=develop

* Remove the new function, resize and CheckPositive in DDim.
test=develop

* Fix a bug in fc_op's InferShape.
test=develop
Parent e3933e1b
@@ -25,21 +25,17 @@ using value_type = int64_t;
 
 value_type DDimLite::production() const {
   value_type res = 1;
-  for (size_t i = 0; i < this->size(); i++) {
-    res *= (*this)[i];
+  for (size_t i = 0; i < data_.size(); i++) {
+    res *= data_[i];
   }
   return res;
 }
 
 value_type DDimLite::count(int start, int end) const {
-  if (start < 0) {
-    start = 0;
-  }
-  if (end > size()) {
-    end = size();
-  }
+  start = std::max(start, 0);
+  end = std::min(end, static_cast<int>(data_.size()));
   if (end < start) {
-    end = start;
+    return 0;
   }
   value_type sum = 1;
   for (auto i = start; i < end; ++i) {
@@ -49,11 +45,13 @@ value_type DDimLite::count(int start, int end) const {
 }
 
 DDimLite DDimLite::Slice(int start, int end) const {
-  std::vector<value_type> vec;
+  start = std::max(start, 0);
+  end = std::min(end, static_cast<int>(data_.size()));
+  std::vector<value_type> new_dim(end - start);
   for (int i = start; i < end; i++) {
-    vec.push_back((*this)[i]);
+    new_dim[i - start] = data_[i];
   }
-  return DDimLite(vec);
+  return DDim(new_dim);
 }
 
 std::string DDimLite::repr() const {
...
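Note: count() now clamps both bounds and returns 0 for an inverted range, where the old code clamped end to start and returned the empty product 1. A minimal standalone sketch of the new semantics, using a plain std::vector<int64_t> in place of DDimLite's data_ member (the function name is illustrative):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the new DDimLite::count(): product of the dims in the clamped
// range [start, end).
int64_t CountDims(const std::vector<int64_t>& data, int start, int end) {
  start = std::max(start, 0);
  end = std::min(end, static_cast<int>(data.size()));
  // New behavior: an inverted range yields 0, not the empty product 1.
  if (end < start) return 0;
  int64_t sum = 1;
  for (int i = start; i < end; ++i) sum *= data[i];
  return sum;
}

int main() {
  std::vector<int64_t> dims{2, 3, 4, 5};
  assert(CountDims(dims, 1, 3) == 12);     // 3 * 4
  assert(CountDims(dims, -1, 99) == 120);  // out-of-range bounds are clamped
  return 0;
}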
@@ -85,7 +85,11 @@ class DDimLite {
   }
 
   friend bool operator!=(const DDimLite &a, const DDimLite &b) {
-    return !(a == b);
+    if (a.size() != b.size()) return true;
+    for (size_t i = 0; i < a.size(); i++) {
+      if (a[i] != b[i]) return true;
+    }
+    return false;
   }
 
  private:
@@ -118,7 +122,7 @@ class TensorLite {
   }
 
   void Resize(const DDimLite &ddim) { dims_ = ddim; }
-  void Resize(const std::vector<int64_t> &x) { dims_ = DDimLite(x); }
+  void Resize(const std::vector<int64_t> &x) { dims_.ConstructFrom(x); }
   const DDimLite &dims() const { return dims_; }
 
   int64_t numel() const { return dims_.production(); }
...
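Note: operator!= now returns on the first mismatching dim, and the vector overload of Resize() fills the existing DDimLite in place via ConstructFrom() instead of constructing a temporary with `dims_ = DDimLite(x)`. A hypothetical stand-in (not the real lite classes) sketching the Resize pattern:

#include <cstdint>
#include <vector>

// Illustrative stand-in for DDimLite.
class Dim {
 public:
  Dim() = default;
  explicit Dim(const std::vector<int64_t>& x) : data_(x) {}
  // Overwrites the stored sizes in place, so the vector overload of
  // Resize() can skip the temporary Dim that `dims_ = Dim(x)` would
  // construct and then assign.
  void ConstructFrom(const std::vector<int64_t>& x) { data_ = x; }

 private:
  std::vector<int64_t> data_;
};

// The two Resize overloads as in the patched TensorLite:
class Tensor {
 public:
  void Resize(const Dim& ddim) { dims_ = ddim; }
  void Resize(const std::vector<int64_t>& x) { dims_.ConstructFrom(x); }

 private:
  Dim dims_;
};

int main() {
  Tensor t;
  t.Resize(std::vector<int64_t>{2, 3, 4});  // no temporary Dim constructed
  return 0;
}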
@@ -31,20 +31,6 @@ namespace lite {
 namespace kernels {
 namespace x86 {
 
-inline void FCOutputSize(const lite::DDim& in_dims,
-                         const lite::DDim& w_dims,
-                         std::vector<int64_t>& out_dims,  // NOLINT
-                         int in_num_col_dims,
-                         bool padding_weights) {
-  auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
-
-  out_dims.reserve(static_cast<size_t>(in_num_col_dims + 1));
-  for (int i = 0; i < in_num_col_dims; ++i) {
-    out_dims.push_back(in_dims[i]);
-  }
-  out_dims.push_back(w_dims1);
-}
-
 template <lite::TargetType Target, typename T>
 class FCFunctor {
  public:
@@ -84,11 +70,11 @@ class FCFunctor {
     // NOTE: here need to mutable_data for temporary Tensor X1 and Y1,
     // the overhead is unmeasured.
     lite::Tensor X1;
-    X1.Resize({M * KK});
+    X1.Resize(std::vector<int64_t>{M * KK});
     T* X1_data = X1.mutable_data<T>();
 
     lite::Tensor Y1;
-    Y1.Resize({M * (N + 4)});
+    Y1.Resize(std::vector<int64_t>{M * NN});
     Y1_data = Y1.mutable_data<T>();
 
     auto parallel_memcpy_x = [&](int64_t begin, int64_t end) {
@@ -115,7 +101,7 @@ class FCFunctor {
     if (!B) {
       auto parallel_memcpy_y = [&](int64_t begin, int64_t end) {
         for (int64_t i = begin; i < end; i++) {
-          memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(T));
+          memcpy(Y + i * N, Y1_data + i * NN, N * sizeof(T));
         }
       };
       lite::x86::RunParallelFor(0, M, parallel_memcpy_y);
@@ -145,22 +131,14 @@ class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
     auto* w = param.w;
     auto* bias = param.bias;
     auto* output = param.output;
-    int in_num_col_dims = param.in_num_col_dims;
     bool with_relu = (param.activation_type == "relu") ? true : false;
 
-    auto w_dims = w->dims();
     bool padding_weights = param.padding_weights;
-
-    std::vector<int64_t> output_dims;
-    FCOutputSize(
-        input->dims(), w_dims, output_dims, in_num_col_dims, padding_weights);
-    output->Resize(output_dims);
-    output->set_lod(input->lod());
-
-    auto out_dims = output->dims();
+    const auto& w_dims = w->dims();
     auto w_dims0 = padding_weights ? w_dims[0] - 4 : w_dims[0];
    auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
-    int M = out_dims.production() / w_dims1;
+
+    int M = output->dims().production() / w_dims1;
 
     const T* input_data = input->data<T>();
     const T* w_data = w->data<T>();
...
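Note: with FCOutputSize() removed, the kernel no longer resizes its output; it relies on the dims already set by fc_op's InferShape. The remaining buffer logic reads M rows at stride NN from the padded buffer and keeps only the N real columns per row. A standalone sketch of that copy; M, N, and NN here are illustrative, and NN is assumed to be the padded row width (e.g. N + 4):

#include <cstring>
#include <vector>

// Each of the M result rows lives at stride NN in the padded buffer Y1,
// but only the first N columns are real data; the rest is padding.
void CopyRowsDropPadding(const float* Y1, float* Y, int M, int N, int NN) {
  for (int i = 0; i < M; i++) {
    std::memcpy(Y + i * N, Y1 + i * NN, N * sizeof(float));
  }
}

int main() {
  const int M = 2, N = 3, NN = 7;  // NN > N stands in for the padded width
  std::vector<float> Y1(M * NN, 1.0f);
  std::vector<float> Y(M * N, 0.0f);
  CopyRowsDropPadding(Y1.data(), Y.data(), M, N, NN);
  return 0;
}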
@@ -28,8 +28,9 @@ namespace x86 {
 
 template <typename T>
 void Compute(const lite::Tensor* in, lite::Tensor* out) {
+  // In CopyDataFrom, the target tensor's dims will be set to the source
+  // tensor's dims.
   auto out_dims = out->dims();
-  auto in_dims = in->dims();
   out->CopyDataFrom(*in);
   out->Resize(out_dims);
 }
...
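Note: the deleted in_dims variable was unused, and the new comment documents why out_dims must be saved: CopyDataFrom() also copies the source dims. A hypothetical minimal tensor (not the lite API) sketching the save/restore pattern:

#include <cstdint>
#include <vector>

// Hypothetical minimal tensor with the two members that matter here.
struct Tensor {
  std::vector<int64_t> dims;
  std::vector<float> data;
  void CopyDataFrom(const Tensor& src) {
    data = src.data;
    dims = src.dims;  // dims follow the source, as the new comment states
  }
};

void Compute(const Tensor* in, Tensor* out) {
  auto out_dims = out->dims;  // snapshot the shape set during InferShape
  out->CopyDataFrom(*in);     // clobbers out->dims with in->dims
  out->dims = out_dims;       // restore (the real code calls Resize)
}

int main() {
  Tensor in{{6}, std::vector<float>(6, 1.0f)};
  Tensor out{{2, 3}, {}};
  Compute(&in, &out);  // out keeps shape {2, 3} but now holds in's data
  return 0;
}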
@@ -27,11 +27,8 @@ bool ConcatOpLite::CheckShape() const {
 }
 
 bool ConcatOpLite::InferShape() const {
-  std::vector<lite::DDim> input_dims;
-  for (auto p : param_.x) {
-    input_dims.push_back(p->dims());
-  }
-  const size_t n = input_dims.size();
+  const std::vector<Tensor *> &inputs = param_.x;
+  const size_t n = inputs.size();
   CHECK_GT_OR_FALSE(n, 0);
 
   int axis = 0;
@@ -42,17 +39,18 @@ bool ConcatOpLite::InferShape() const {
     axis = axis_tensor_val[0];
   }
   if (axis < 0) {
-    axis += input_dims[0].size();
+    axis += inputs[0]->dims().size();
   }
 
-  auto &out_dims = input_dims[0];
+  auto out_dims = inputs[0]->dims();
   size_t in_zero_dims_size = out_dims.size();
   for (size_t i = 1; i < n; i++) {
+    const auto &input_dims_i = inputs[i]->dims();
     for (size_t j = 0; j < in_zero_dims_size; j++) {
       if (j == static_cast<size_t>(axis)) {
-        out_dims[axis] += input_dims[i][j];
+        out_dims[axis] += input_dims_i[j];
       } else {
-        CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]);
+        CHECK_EQ_OR_FALSE(out_dims[j], input_dims_i[j]);
       }
     }
   }
@@ -60,7 +58,7 @@ bool ConcatOpLite::InferShape() const {
     out_dims[axis] = -1;
   }
   // Set output dims
-  param_.output->Resize(lite::DDim(out_dims));
+  param_.output->Resize(out_dims);
   auto out_lod = param_.output->mutable_lod();
   *out_lod = param_.x[0]->lod();
   return true;
...
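For reference, a standalone sketch of the concat rule implemented above: all inputs must agree on every axis except `axis`, whose extents are summed. The helper below is illustrative, not part of the codebase:

#include <cassert>
#include <cstdint>
#include <vector>

// All inputs must match on every axis except `axis`, which is summed.
std::vector<int64_t> ConcatOutDims(
    const std::vector<std::vector<int64_t>>& in_dims, int axis) {
  std::vector<int64_t> out = in_dims[0];
  for (size_t i = 1; i < in_dims.size(); i++) {
    for (size_t j = 0; j < out.size(); j++) {
      if (j == static_cast<size_t>(axis)) {
        out[axis] += in_dims[i][j];
      } else {
        assert(out[j] == in_dims[i][j]);  // the op returns false here
      }
    }
  }
  return out;
}

int main() {
  // {2, 3, 4} concat {2, 5, 4} along axis 1 -> {2, 8, 4}
  assert((ConcatOutDims({{2, 3, 4}, {2, 5, 4}}, 1) ==
          std::vector<int64_t>{2, 8, 4}));
  return 0;
}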
@@ -49,23 +49,25 @@ bool FcOpLite::CheckShape() const {
 }
 
 bool FcOpLite::InferShape() const {
-  const auto input_dims = param_.input->dims();
-  const auto w_dims = param_.w->dims();
+  const auto& input_dims = param_.input->dims();
+  const auto& w_dims = param_.w->dims();
+  int in_num_col_dims = param_.in_num_col_dims;
+  int64_t w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1];
 
   // Set output dims
-  std::vector<int64_t> output_dims(param_.in_num_col_dims + 1, 0);
-  for (int i = 0; i < param_.in_num_col_dims; ++i) {
+  std::vector<DDim::value_type> output_dims(in_num_col_dims + 1);
+  for (int i = 0; i < in_num_col_dims; ++i) {
     output_dims[i] = input_dims[i];
   }
-  output_dims.back() = w_dims[1];
-  param_.output->Resize(lite::DDim(output_dims));
+  output_dims[in_num_col_dims] = w_dims_1;
+  param_.output->Resize(output_dims);
 
   // share LoD
   param_.output->set_lod(param_.input->lod());
   return true;
 }
 
-bool FcOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
+bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
   auto input = op_desc.Input("Input").front();
   auto W = op_desc.Input("W").front();
   auto out = op_desc.Output("Out").front();
...
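This hunk carries the fc_op bug fix named in the commit message: when padding_weights is set, the output width must use the unpadded w_dims[1] - 4. A standalone sketch with a worked check; the shapes in main are made-up examples:

#include <cassert>
#include <cstdint>
#include <vector>

// The first in_num_col_dims input dims are kept, and the last output dim
// is W's column count minus the 4-element padding when padding_weights
// is set (the bug fix in this commit).
std::vector<int64_t> FcOutDims(const std::vector<int64_t>& input_dims,
                               const std::vector<int64_t>& w_dims,
                               int in_num_col_dims,
                               bool padding_weights) {
  int64_t w_dims_1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
  std::vector<int64_t> out(in_num_col_dims + 1);
  for (int i = 0; i < in_num_col_dims; ++i) out[i] = input_dims[i];
  out[in_num_col_dims] = w_dims_1;
  return out;
}

int main() {
  // input {8, 3, 32, 32} flattened after the first dim, W {3*32*32, 100}:
  assert((FcOutDims({8, 3, 32, 32}, {3 * 32 * 32, 100}, 1, false) ==
          std::vector<int64_t>{8, 100}));
  // with padded weights the real output width is 100, not 104:
  assert((FcOutDims({8, 3, 32, 32}, {3 * 32 * 32 + 4, 104}, 1, true) ==
          std::vector<int64_t>{8, 100}));
  return 0;
}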
@@ -28,8 +28,8 @@ bool GRUOpLite::CheckShape() const {
   CHECK_OR_FALSE(param_.batch_hidden)
   CHECK_OR_FALSE(param_.hidden)
 
-  auto input_dims = param_.input->dims();
-  auto weight_dims = param_.weight->dims();
+  const auto& input_dims = param_.input->dims();
+  const auto& weight_dims = param_.weight->dims();
   int input_size = input_dims[1];
   int frame_size = weight_dims[0];
   CHECK_EQ_OR_FALSE(input_size, frame_size * 3)
@@ -52,21 +52,23 @@ bool GRUOpLite::CheckShape() const {
 }
 
 bool GRUOpLite::InferShape() const {
-  auto input_dims = param_.input->dims();
-  auto weight_dims = param_.weight->dims();
+  const auto& input_dims = param_.input->dims();
+  const auto& weight_dims = param_.weight->dims();
   int frame_size = weight_dims[0];
   auto batch_size = input_dims[0];
 
   param_.batch_gate->Resize(input_dims);
-  param_.batch_reset_hidden_prev->Resize(lite::DDim({batch_size, frame_size}));
-  param_.batch_hidden->Resize(lite::DDim({batch_size, frame_size}));
-  param_.hidden->Resize(lite::DDim({batch_size, frame_size}));
+
+  DDim out_dims({batch_size, frame_size});
+  param_.batch_reset_hidden_prev->Resize(out_dims);
+  param_.batch_hidden->Resize(out_dims);
+  param_.hidden->Resize(out_dims);
 
   *(param_.hidden->mutable_lod()) = param_.input->lod();
   return true;
 }
 
-bool GRUOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
+bool GRUOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
   auto input = op_desc.Input("Input").front();
   auto weight = op_desc.Input("Weight").front();
   auto batch_gate = op_desc.Output("BatchGate").front();
...
@@ -25,8 +25,8 @@ bool LookupTableOpLite::CheckShape() const {
   CHECK_OR_FALSE(param_.Ids)
   CHECK_OR_FALSE(param_.Out)
 
-  auto table_dims = param_.W->dims();
-  auto ids_dims = param_.Ids->dims();
+  const auto& table_dims = param_.W->dims();
+  const auto& ids_dims = param_.Ids->dims();
 
   int ids_rank = ids_dims.size();
 
@@ -37,25 +37,20 @@ bool LookupTableOpLite::CheckShape() const {
 }
 
 bool LookupTableOpLite::InferShape() const {
-  auto table_dims = param_.W->dims();
-  auto ids_dims = param_.Ids->dims();
+  const auto& table_dims = param_.W->dims();
+  const auto& ids_dims = param_.Ids->dims();
 
+  auto out_dims = ids_dims;
   int ids_rank = ids_dims.size();
+  out_dims[ids_rank - 1] = table_dims[1];
 
-  auto output_dims = ids_dims.Slice(0, ids_rank - 1);
-  std::vector<int64_t> out_dims;
-  for (int i = 0; i < ids_rank - 1; ++i) {
-    out_dims.push_back(ids_dims[i]);
-  }
-  out_dims.push_back(table_dims[1]);
-  param_.Out->Resize(lite::DDim{out_dims});
+  param_.Out->Resize(out_dims);
   param_.Out->set_lod(param_.Ids->lod());
   return true;
 }
 
-bool LookupTableOpLite::AttachImpl(const cpp::OpDesc &op_desc,
-                                   lite::Scope *scope) {
+bool LookupTableOpLite::AttachImpl(const cpp::OpDesc& op_desc,
+                                   lite::Scope* scope) {
   auto input = op_desc.Input("W").front();
   auto ids = op_desc.Input("Ids").front();
   auto out = op_desc.Output("Out").front();
...
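For reference, a standalone sketch of the simplified rule: the output keeps the Ids shape and only the trailing dim is replaced by the embedding width table_dims[1]. Example shapes are made up:

#include <cassert>
#include <cstdint>
#include <vector>

// Output = Ids shape with the trailing dim replaced by the embedding width.
std::vector<int64_t> LookupTableOutDims(const std::vector<int64_t>& table_dims,
                                        const std::vector<int64_t>& ids_dims) {
  std::vector<int64_t> out = ids_dims;
  out.back() = table_dims[1];
  return out;
}

int main() {
  // Ids {batch=4, seq=16, 1} against a {vocab=30000, emb=128} table
  // -> Out {4, 16, 128}.
  assert((LookupTableOutDims({30000, 128}, {4, 16, 1}) ==
          std::vector<int64_t>{4, 16, 128}));
  return 0;
}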
@@ -29,39 +29,43 @@ bool ReduceOp::CheckShape() const {
 }
 
 bool ReduceOp::InferShape() const {
-  auto x_dims = param_.x->dims();
+  const auto &x_dims = param_.x->dims();
   auto x_rank = x_dims.size();
   auto dims = param_.dim;
   for (size_t i = 0; i < dims.size(); ++i) {
-    if (dims[i] < 0) dims[i] = x_rank + dims[i];
+    if (dims[i] < 0) {
+      dims[i] = x_rank + dims[i];
+    }
     CHECK_LT(dims[i], x_rank)
         << "The dim should be in the range [-rank(input), rank(input).";
   }
-  sort(dims.begin(), dims.end());
   bool reduce_all = param_.reduce_all;
   bool keep_dim = param_.keep_dim;
   if (reduce_all) {
     if (keep_dim)
-      param_.output->Resize(lite::DDim(std::vector<int64_t>(x_rank, 1)));
+      param_.output->Resize(std::vector<int64_t>(x_rank, 1));
     else
-      param_.output->Resize(lite::DDim(std::vector<int64_t>{1}));
+      param_.output->Resize(std::vector<int64_t>{1});
   } else {
-    auto dims_vector = x_dims.Vectorize();
+    size_t out_rank = keep_dim ? x_rank : x_rank - dims.size();
+    std::vector<DDim::value_type> out_dims(out_rank);
     if (keep_dim) {
       for (size_t i = 0; i < dims.size(); ++i) {
-        dims_vector[dims[i]] = 1;
+        out_dims[dims[i]] = 1;
       }
     } else {
-      const int kDelFlag = -2;
-      for (size_t i = 0; i < dims.size(); ++i) {
-        dims_vector[dims[i]] = kDelFlag;
+      sort(dims.begin(), dims.end());
+      int dim_index = 0;
+      int out_index = 0;
+      for (size_t i = 0; i < x_rank; ++i) {
+        if (dims[dim_index] == static_cast<DDim::value_type>(i)) {
+          dim_index++;
+        } else {
+          out_dims[out_index++] = x_dims[i];
+        }
       }
-      dims_vector.erase(
-          remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
-          dims_vector.end());
     }
-    auto out_dims = lite::DDim(dims_vector);
     param_.output->Resize(out_dims);
     if (dims[0] != 0) {
       param_.output->set_lod(param_.x->lod());
...
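For reference, a standalone sketch of the new keep_dim == false path: instead of marking reduced axes with a sentinel and erasing them, walk x_dims once with a second cursor over the sorted reduce axes. A bounds guard on dim_index is added here that the diff's inner loop leaves implicit:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Copy only the axes of x_dims that are not listed in `dims`.
std::vector<int64_t> ReduceOutDims(const std::vector<int64_t>& x_dims,
                                   std::vector<int64_t> dims) {
  std::sort(dims.begin(), dims.end());
  std::vector<int64_t> out(x_dims.size() - dims.size());
  size_t dim_index = 0;
  size_t out_index = 0;
  for (size_t i = 0; i < x_dims.size(); ++i) {
    if (dim_index < dims.size() &&
        dims[dim_index] == static_cast<int64_t>(i)) {
      dim_index++;  // this axis is reduced away
    } else {
      out[out_index++] = x_dims[i];  // this axis survives
    }
  }
  return out;
}

int main() {
  // Reducing {4, 5, 6} over axes {1} -> {4, 6}.
  assert((ReduceOutDims({4, 5, 6}, {1}) == std::vector<int64_t>{4, 6}));
  return 0;
}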
@@ -27,14 +27,15 @@ bool ReshapeOp::CheckShape() const {
 }
 
 bool ReshapeOp::InferShape() const {
-  auto shape_tensor_vct = param_.shape_tensor_vct;
+  const auto &shape_tensor_vct = param_.shape_tensor_vct;
   auto *shape_tensor = param_.shape_tensor;
-  auto shape_vct = param_.shape_vct;
-  std::vector<int> final_shape;
+  const auto &shape_vct = param_.shape_vct;
 
+  std::vector<int> final_shape;
   if (shape_tensor_vct.size() > 0) {
-    for (int i = 0; i < shape_tensor_vct.size(); i++) {
-      final_shape.push_back(shape_tensor_vct[i]->data<int>()[0]);
+    final_shape.resize(shape_tensor_vct.size());
+    for (size_t i = 0; i < shape_tensor_vct.size(); i++) {
+      final_shape[i] = shape_tensor_vct[i]->data<int>()[0];
     }
   } else if (shape_tensor != nullptr) {
     auto *shape_tensor_data = shape_tensor->data<int>();
@@ -46,7 +47,7 @@ bool ReshapeOp::InferShape() const {
     LOG(FATAL) << "input shape error";
   }
 
-  auto x_dims = param_.x->dims();
+  const auto &x_dims = param_.x->dims();
   auto output_dims = ValidateShape(final_shape, x_dims);
   param_.output->Resize(output_dims);
   auto out_lod = param_.output->mutable_lod();
@@ -98,8 +99,9 @@ bool Reshape2Op::CheckShape() const {
 
 bool Reshape2Op::InferShape() const {
   ReshapeOp::InferShape();
-  auto x_dims = param_.x->dims();
-  std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
+  const auto &x_dims = param_.x->dims();
+  std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1);
+  xshape_dims[0] = 0;
   for (size_t i = 0; i < x_dims.size(); i++) {
     xshape_dims[i + 1] = x_dims[i];
   }
@@ -116,20 +118,26 @@ bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   return true;
 }
 
-DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
-  const lite::DDim::value_type input_size = input_dims.production();
-  auto input_shape = input_dims.Vectorize();
-  bool all_positive = std::all_of(
-      input_shape.cbegin(), input_shape.cend(), [](lite::DDim::value_type i) {
-        return i > 0;
-      });
+static bool CheckPositive(const DDim &dims) {
+  for (size_t i = 0; i < dims.size(); ++i) {
+    if (dims[i] <= 0) {
+      return false;
+    }
+  }
+  return true;
+}
+
+std::vector<DDim::value_type> ValidateShape(const std::vector<int> &shape,
+                                            const DDim &input_dims) {
+  const DDim::value_type input_size = input_dims.production();
   // only one dimension can be set to -1, whose size will be automatically
   // infered.
   const int unk_dim_val = -1;
   const int copy_dim_val = 0;
 
-  std::vector<lite::DDim::value_type> output_shape(shape.size(), 0);
-  lite::DDim::value_type capacity = 1;
+  std::vector<DDim::value_type> output_dims(shape.size());
+  DDim::value_type capacity = 1;
   int unk_dim_idx = -1;
   for (size_t i = 0; i < shape.size(); ++i) {
     if (shape[i] == unk_dim_val) {
@@ -137,7 +145,7 @@ DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
           << "Only one input dimension of Attr(shape) can be unknown.";
       unk_dim_idx = i;
     } else if (shape[i] == copy_dim_val) {
-      CHECK_LT(static_cast<int>(i), input_shape.size())
+      CHECK_LT(static_cast<int>(i), input_dims.size())
          << "The index of dimension to copy from input shape must be less "
             "than the size of input shape.";
     } else {
@@ -145,28 +153,28 @@ DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
            "be negtive except one unknown dimension.";
     }
 
-    capacity *= (shape[i] ? static_cast<lite::DDim::value_type>(shape[i])
-                          : input_shape[i]);
-    output_shape[i] = (shape[i] ? static_cast<lite::DDim::value_type>(shape[i])
-                                : input_shape[i]);
+    DDim::value_type output_dim_i =
+        shape[i] ? static_cast<DDim::value_type>(shape[i]) : input_dims[i];
+    output_dims[i] = output_dim_i;
+    capacity *= output_dim_i;
   }
 
   if (unk_dim_idx != -1) {
-    if (all_positive) {
+    if (CheckPositive(input_dims)) {
       // input_size < 0 and is un-determinate in compile time, skip the check,
       // for example, input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
       // capacity = -24, input_size = -8, output_shape[0] = 0
       // the following check will fail.
-      output_shape[unk_dim_idx] = -input_size / capacity;
-      CHECK_EQ(output_shape[unk_dim_idx] * capacity, -input_size)
+      output_dims[unk_dim_idx] = -input_size / capacity;
+      CHECK_EQ(output_dims[unk_dim_idx] * capacity, -input_size)
          << "Invalid shape is given.";
     } else {
-      output_shape[unk_dim_idx] = -1;
+      output_dims[unk_dim_idx] = -1;
    }
  } else {
    CHECK_EQ(capacity, input_size) << "Invalid shape is given.";
  }
-  return lite::DDim(output_shape);
+  return output_dims;
 }
 
 }  // namespace operators
...
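A worked example of the shape conventions documented in ValidateShape (0 copies the corresponding input dim; the single -1 is inferred so the element count is preserved). This is a simplified re-implementation without the CHECKs, not the library function:

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> InferReshape(const std::vector<int>& shape,
                                  const std::vector<int64_t>& input_dims) {
  int64_t input_size = 1;
  for (auto d : input_dims) input_size *= d;
  std::vector<int64_t> out(shape.size());
  int64_t capacity = 1;
  int unk_dim_idx = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == -1) unk_dim_idx = static_cast<int>(i);
    out[i] = shape[i] ? shape[i] : input_dims[i];  // 0 copies the input dim
    capacity *= out[i];                            // -1 contributes its sign
  }
  if (unk_dim_idx != -1) out[unk_dim_idx] = -input_size / capacity;
  return out;
}

int main() {
  // input {2, 3, 8} (48 elements), shape {0, -1, 4}:
  // dim 0 is copied (2), dim 2 is fixed (4), dim 1 is inferred as 48/8 = 6.
  assert((InferReshape({0, -1, 4}, {2, 3, 8}) ==
          std::vector<int64_t>{2, 6, 4}));
  return 0;
}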
@@ -56,7 +56,8 @@ class Reshape2Op : public ReshapeOp {
   std::string DebugString() const override { return "reshape2"; }
 };
 
-DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims);
+std::vector<DDim::value_type> ValidateShape(const std::vector<int> &shape,
+                                            const DDim &input_dims);
 
 }  // namespace operators
 }  // namespace lite
...
@@ -47,7 +47,6 @@ void Relu(float* out, int num, int channel) {
 
 DDim ComputeOutDim(const DDim& dim_in, const DDim& wdim, int in_num_col_dim) {
   std::vector<int64_t> out_dim;
   out_dim.resize(in_num_col_dim + 1);
-  auto in_mat_dims = dim_in.Flatten2D(in_num_col_dim);
   for (int i = 0; i < in_num_col_dim; ++i) {
     out_dim[i] = dim_in[i];
   }
...