未验证 提交 4030e756 编写于 作者: H HappyAngel 提交者: GitHub

[arm]fix concat axis < 0 compute error problem (#3802)

* fix concatt axis < 0 errorr,ttest=develop

* fix format. test=develop
上级 30616633
......@@ -25,21 +25,21 @@ namespace mir {
void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
std::vector<std::string> act_types{"relu"};
bool has_int8 = false;
bool has_arm_float = false;
bool has_arm = false;
bool has_cuda = false;
for (auto& place : graph->valid_places()) {
if (place.precision == PRECISION(kInt8)) {
has_int8 = true;
}
if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) {
has_arm_float = true;
if (place.target == TARGET(kARM)) {
has_arm = true;
}
if (place.target == TARGET(kCUDA)) {
has_cuda = true;
}
}
if (!has_int8 && has_arm_float) {
if (has_arm) {
act_types.push_back("relu6");
act_types.push_back("leaky_relu");
}
......
......@@ -52,11 +52,7 @@ void ConcatFunc(const std::vector<lite::Tensor*> inputs,
output_offset += in_stride[0];
}
} else {
std::vector<lite::Tensor*> inputs_concat(inputs.size());
for (int j = 0; j < inputs.size(); ++j) {
inputs_concat[j] = inputs[j];
}
lite::arm::math::concat_func<T>(inputs_concat, axis, out);
lite::arm::math::concat_func<T>(inputs, axis, out);
}
}
......@@ -71,6 +67,9 @@ void ConcatCompute::Run() {
auto* axis_tensor_data = axis_tensor->data<int>();
axis = axis_tensor_data[0];
}
if (axis < 0) {
axis += inputs[0]->dims().size();
}
switch (inputs.front()->precision()) {
case PRECISION(kFloat):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册