Commit 96bcd27a authored by mindspore-ci-bot, committed via Gitee

!4394 Fix fc op's bug

Merge pull request !4394 from zhanyuan/dev
......@@ -352,6 +352,7 @@ table FullConnection {
hasBias: bool;
axis: int;
useAxis: bool;
activationType: ActivationType = 0;
}
// Mean(input_tensor, axis, keep_dims)
......
......@@ -24,7 +24,7 @@ int FullConnection::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
MS_ASSERT(this->primitive != nullptr);
auto input0 = inputs_.front();
MS_ASSERT(input0 != nullptr);
auto input1 = inputs_.at(1);
auto input1 = inputs_[1];
MS_ASSERT(input1 != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
......@@ -33,27 +33,45 @@ int FullConnection::InferShape(std::vector<tensor::Tensor *> inputs_, std::vecto
MS_LOG(ERROR) << "Input tensors num error";
return RET_INPUT_TENSOR_ERROR;
}
if (fc_prim->axis() < 1 || fc_prim->axis() > input0->shape().size()) {
auto axis = fc_prim->axis();
auto use_axis = fc_prim->useAxis();
if (use_axis && (axis < 1 || axis >= input0->shape().size())) {
MS_LOG(ERROR) << "FullConnection axis invalid";
return RET_INPUT_TENSOR_ERROR;
}
int new_k = 1;
for (size_t i = fc_prim->axis(); i < input0->shape().size(); ++i) {
new_k *= input0->shape().at(i);
}
if (new_k != input1->shape().at(1)) {
MS_LOG(ERROR) << "Input1 size invalid";
return RET_PARAM_INVALID;
if (use_axis) {
for (int i = axis; i < input0->shape().size(); ++i) {
new_k *= input0->shape()[i];
}
if (new_k != input1->shape()[1]) {
MS_LOG(ERROR) << "Input1 size invalid";
return RET_PARAM_INVALID;
}
} else {
new_k = input1->shape()[1];
}
if (fc_prim->hasBias()) {
if (inputs_.at(2)->shape()[0] != input1->shape()[0]) {
if (inputs_[2]->shape()[0] != input1->shape()[0]) {
MS_LOG(ERROR) << "bias size invalid";
return RET_PARAM_INVALID;
}
}
std::vector<int> out_shape{inputs_[0]->shape()};
out_shape.resize(fc_prim->axis() + 1);
out_shape[fc_prim->axis()] = input1->shape()[0];
if (use_axis) {
out_shape.resize(fc_prim->axis() + 1);
out_shape[fc_prim->axis()] = input1->shape()[0];
} else {
int total = 1;
for (int i = 0; i < input0->shape().size(); ++i) {
total *= input0->shape()[i];
}
out_shape.resize(2);
auto batch_size = total / new_k;
out_shape[0] = batch_size;
out_shape[1] = input1->shape()[0];
}
output->set_shape(out_shape);
output->set_data_type(input0->data_type());
output->SetFormat(input0->GetFormat());
......
......@@ -226,7 +226,14 @@ OpParameter *PopulateFullconnectionParameter(const lite::Primitive *primitive) {
matmul_param->b_transpose_ = true;
matmul_param->a_transpose_ = false;
matmul_param->has_bias_ = param->hasBias();
matmul_param->act_type_ = ActType_No;
if (param->activationType() == schema::ActivationType_RELU) {
matmul_param->act_type_ = ActType_Relu;
} else if (param->activationType() == schema::ActivationType_RELU6) {
matmul_param->act_type_ = ActType_Relu6;
} else {
matmul_param->act_type_ = ActType_No;
}
return reinterpret_cast<OpParameter *>(matmul_param);
}
......
......@@ -48,6 +48,22 @@ int PowerTestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite
return out_t->ElementsNum();
}
// Builds a single-input Power test fixture: allocates an input tensor filled
// from a_ptr/a_shape and an empty output tensor of c_shape, appends both to
// the caller's vectors (caller owns and must delete them), and returns the
// output tensor's element count.
int PowerTestInit2(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
                   float *a_ptr, std::vector<int> a_shape, std::vector<int> c_shape) {
  auto input_tensor =
      new lite::tensor::Tensor(kNumberTypeFloat, a_shape, schema::Format_NHWC, static_cast<schema::NodeType>(1));
  input_tensor->MallocData();
  // Copy the caller-provided values into the freshly allocated input buffer.
  memcpy(input_tensor->Data(), a_ptr, sizeof(float) * input_tensor->ElementsNum());
  inputs_->push_back(input_tensor);
  auto output_tensor =
      new lite::tensor::Tensor(kNumberTypeFloat, c_shape, schema::Format_NHWC, static_cast<schema::NodeType>(1));
  output_tensor->MallocData();
  outputs_->push_back(output_tensor);
  return output_tensor->ElementsNum();
}
TEST_F(TestPowerFp32, Simple) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
......@@ -62,13 +78,12 @@ TEST_F(TestPowerFp32, Simple) {
int total_size = PowerTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
auto ctx = new lite::Context;
ctx->thread_num_ = 1;
kernel::PowerCPUKernel *op = new kernel::PowerCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_,
ctx, nullptr);
kernel::PowerCPUKernel *op =
new kernel::PowerCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_, ctx, nullptr);
op->Init();
op->Run();
float correct[] = {1, 64, 2187, 65536};
float *output = reinterpret_cast<float *>(outputs_[0]->Data());
for (int i = 0; i < 4; ++i) printf("%f ", output[i]);
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
delete op;
for (auto t : inputs_) delete t;
......@@ -79,18 +94,17 @@ TEST_F(TestPowerFp32, Broadcast) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto param = new PowerParameter();
param->power_ = 2;
param->scale_ = 1;
param->shift_ = 0;
float a[] = {1, 2, 3, 4};
float b[] = {2};
std::vector<int> a_shape = {2, 2};
std::vector<int> b_shape = {1};
std::vector<int> c_shape = {2, 2};
int total_size = PowerTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
int total_size = PowerTestInit2(&inputs_, &outputs_, a, a_shape, c_shape);
auto ctx = new lite::Context;
ctx->thread_num_ = 2;
kernel::PowerCPUKernel *op = new kernel::PowerCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_,
ctx, nullptr);
kernel::PowerCPUKernel *op =
new kernel::PowerCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_, ctx, nullptr);
op->Init();
op->Run();
float correct[] = {1, 4, 9, 16};
......
......@@ -38,6 +38,13 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT
MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser";
std::unique_ptr<schema::FullConnectionT> attr(new schema::FullConnectionT());
const auto &tflite_attr = tfliteOp->builtin_options.AsFullyConnectedOptions();
if (tflite_attr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name << " attr failed";
return RET_NULL_PTR;
}
attr->activationType = GetActivationFunctionType(tflite_attr->fused_activation_function);
auto weight_index = tfliteOp->inputs[1];
const auto &weight_tensor = tfliteTensors[weight_index];
if (weight_tensor == nullptr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册