Commit f7ee19a5 authored by zhanyuan

1. Fix bugs in several InferShape implementations. 2. Fix the activation-range bug in the fc int8 kernel.

Parent b4925d3e
@@ -141,8 +141,8 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
   pad_u_ = GetPadUp();
   pad_d_ = GetPadDown();
   pad_r_ = GetPadRight();
-  output_h = GetStrideH() * (input_h - 1) * GetKernelH() - pad_u_ - pad_d_;
-  output_w = GetStrideW() * (input_w - 1) * GetKernelW() - pad_l_ - pad_r_;
+  output_h = GetStrideH() * (input_h - 1) + GetKernelH() - pad_u_ - pad_d_;
+  output_w = GetStrideW() * (input_w - 1) + GetKernelW() - pad_l_ - pad_r_;
   if ((output_h + GetPadUp() + GetPadDown() - GetKernelH()) % GetStrideH() != 0) {
     output_h += (output_h + GetPadLeft() + GetPadRight() - GetKernelH()) % GetStrideH();
   }
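Note on the hunk above: for a transposed (de-)convolution the spatial output size is stride * (input - 1) + kernel - pad_begin - pad_end; the old code multiplied by the kernel size instead of adding it. A minimal standalone sketch of the corrected formula (the function name and layout here are illustrative, not the MindSpore Lite API):

#include <cstdio>

// Output size of a transposed convolution along one dimension:
// output = stride * (input - 1) + kernel - pad_begin - pad_end
int DeconvOutputSize(int input, int kernel, int stride, int pad_begin, int pad_end) {
  return stride * (input - 1) + kernel - pad_begin - pad_end;
}

int main() {
  // Example: 4-wide input, 3-wide kernel, stride 2, no padding -> 2 * 3 + 3 = 9.
  std::printf("%d\n", DeconvOutputSize(4, 3, 2, 0, 0));  // prints 9
  return 0;
}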
@@ -28,7 +28,7 @@ void FullConnection::SetHasBias(bool has_bias) { this->primitive->value.AsFullCo
 void FullConnection::SetAxis(int axis) { this->primitive->value.AsFullConnection()->axis = axis; }
 void FullConnection::SetUseAxis(bool use_axis) { this->primitive->value.AsFullConnection()->useAxis = use_axis; }
 void FullConnection::SetActivationType(int activationType) {
-  this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType) activationType;
+  this->primitive->value.AsFullConnection()->activationType = (schema::ActivationType)activationType;
 }
 #else
@@ -47,43 +47,58 @@ int FullConnection::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
   MS_ASSERT(this->primitive != nullptr);
   auto input0 = inputs_.front();
   MS_ASSERT(input0 != nullptr);
-  auto input1 = inputs_.at(1);
+  auto input1 = inputs_[1];
   MS_ASSERT(input1 != nullptr);
   auto output = outputs_.front();
   MS_ASSERT(output != nullptr);
+  output->set_data_type(input0->data_type());
+  output->SetFormat(input0->GetFormat());
+  if (!GetInferFlag()) {
+    return RET_OK;
+  }
   if ((GetHasBias() && inputs_.size() != kMultiNum) || (!GetHasBias() && inputs_.size() != kDoubleNum)) {
     MS_LOG(ERROR) << "Input tensors num error";
-    return 1;
+    return RET_INPUT_TENSOR_ERROR;
   }
-  if (GetAxis() < 1 || GetAxis() > static_cast<int>(input0->shape().size())) {
+  if (GetUseAxis() && (GetAxis() < 1 || GetAxis() > static_cast<int>(input0->shape().size()))) {
     MS_LOG(ERROR) << "FullConnection axis invalid";
-    return 1;
+    return RET_ERROR;
   }
   int new_k = 1;
-  for (size_t i = GetAxis(); i < input0->shape().size(); ++i) {
-    new_k *= input0->shape().at(i);
-  }
-  if (new_k != input1->shape().at(1)) {
-    MS_LOG(ERROR) << "Input1 size invalid";
-    return 1;
+  if (GetUseAxis()) {
+    for (int i = GetAxis(); i < input0->shape().size(); ++i) {
+      new_k *= input0->shape()[i];
+    }
+    if (new_k != input1->shape()[1]) {
+      MS_LOG(ERROR) << "Input1 size invalid";
+      return RET_INPUT_TENSOR_ERROR;
+    }
+  } else {
+    new_k = input1->shape()[1];
   }
   if (GetHasBias()) {
-    if (inputs_.at(2)->shape()[0] != input1->shape()[0]) {
+    if (inputs_[2]->shape()[0] != input1->shape()[0]) {
       MS_LOG(ERROR) << "bias size invalid";
-      return 1;
+      return RET_INPUT_TENSOR_ERROR;
     }
   }
   std::vector<int> out_shape{inputs_[0]->shape()};
-  out_shape.resize(GetAxis() + 1);
-  out_shape[GetAxis()] = input1->shape()[0];
+  if (GetUseAxis()) {
+    out_shape.resize(GetAxis() + 1);
+    out_shape[GetAxis()] = input1->shape()[0];
+  } else {
+    int total = 1;
+    for (int i = 0; i < input0->shape().size(); ++i) {
+      total *= input0->shape()[i];
+    }
+    out_shape.resize(2);
+    auto batch_size = total / new_k;
+    out_shape[0] = batch_size;
+    out_shape[1] = input1->shape()[0];
+  }
   output->set_shape(out_shape);
   output->set_data_type(input0->data_type());
   output->SetFormat(input0->GetFormat());
-  return 0;
+  return RET_OK;
 }
 }  // namespace lite
 }  // namespace mindspore
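Note on the InferShape hunk above: with useAxis set, the output keeps the leading dimensions up to axis and places the weight's output-channel count at position axis; without useAxis, the activation is flattened to [batch, out_channels] with batch = total_elements / k, where k is the weight's second dimension. A self-contained sketch of that shape rule (names and signature are illustrative only, not the actual primitive code):

#include <cassert>
#include <vector>

// in_shape: activation shape; weight_shape: [out_channels, k].
std::vector<int> FcOutputShape(const std::vector<int> &in_shape, const std::vector<int> &weight_shape,
                               bool use_axis, int axis) {
  std::vector<int> out_shape = in_shape;
  if (use_axis) {
    // Keep dims [0, axis) and put out_channels at position `axis`.
    assert(axis >= 1 && axis <= static_cast<int>(in_shape.size()));
    out_shape.resize(axis + 1);
    out_shape[axis] = weight_shape[0];
  } else {
    // Flatten to [batch, out_channels] with batch = total / k.
    int total = 1;
    for (int dim : in_shape) total *= dim;
    out_shape.assign({total / weight_shape[1], weight_shape[0]});
  }
  return out_shape;
}

// Example: a [2, 4, 5] activation with a [10, 20] weight and use_axis == false
// flattens 40 elements into batches of 20, giving an output shape of [2, 10].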
@@ -91,8 +91,8 @@ int FullconnectionInt8CPUKernel::ReSize() {
   QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
                          &quant_params_.right_shift);
   CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6,
-                                    quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_max,
-                                    &quant_params_.out_act_min);
+                                    quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_min,
+                                    &quant_params_.out_act_max);
   return RET_OK;
 }
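Note on the int8 hunk above: the fix swaps the last two arguments so the computed lower bound lands in out_act_min and the upper bound in out_act_max; with the old order, ReLU/ReLU6 outputs were clamped with reversed bounds. A rough standalone sketch of how such a quantized activation range can be derived (an assumption-laden illustration, not the actual CalculateActivationRangeQuantized implementation or signature):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative only: compute the int8 clamp range for ReLU/ReLU6 on a tensor
// quantized as q = round(x / scale) + zp, writing the minimum first and the
// maximum second -- the argument order the corrected call site relies on.
void ActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale,
                              int8_t *out_min, int8_t *out_max) {
  auto quantize = [&](float x) {
    int q = static_cast<int>(std::lround(x / scale)) + zp;
    return static_cast<int8_t>(std::min(127, std::max(-128, q)));
  };
  int8_t qmin = -128;
  int8_t qmax = 127;
  if (is_relu) {
    qmin = std::max(qmin, quantize(0.0f));  // ReLU: nothing below real 0
  } else if (is_relu6) {
    qmin = std::max(qmin, quantize(0.0f));  // ReLU6: clamp to real [0, 6]
    qmax = std::min(qmax, quantize(6.0f));
  }
  *out_min = qmin;
  *out_max = qmax;
}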