diff --git a/lite/operators/fc_op.cc b/lite/operators/fc_op.cc index 28a220da2de0920643d46f1ed9c610dfa613cf95..e0591c5eae3bf351aa6dae5ff981e3b9c81249e0 100644 --- a/lite/operators/fc_op.cc +++ b/lite/operators/fc_op.cc @@ -50,7 +50,7 @@ bool FcOpLite::CheckShape() const { bool FcOpLite::InferShapeImpl() const { const auto& input_dims = param_.input->dims(); - const auto& w_dims = param_.w->dims(); + const auto& w_dims = param_.w_dims; int in_num_col_dims = param_.in_num_col_dims; int64_t w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1]; @@ -77,6 +77,7 @@ bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { param_.input = scope->FindVar(input)->GetMutable(); param_.w = scope->FindVar(W)->GetMutable(); + param_.w_dims = param_.w->dims(); std::vector input_arg_names = op_desc.InputArgumentNames(); if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") != input_arg_names.end()) { diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index dffa15188431460fbf8a41f76f98258f213f6493..b3bbd648ed612cc9d835e6550261311bf02cb8fa 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -103,6 +103,8 @@ struct FcParam : ParamBase { lite::Tensor* bias{nullptr}; lite::Tensor* output{nullptr}; lite::DDim in_mat_dims; + // original dims of input weight + lite::DDim w_dims; int in_num_col_dims{1}; std::string activation_type{""}; bool padding_weights{false};