未验证 提交 4030e756 编写于 作者: H HappyAngel 提交者: GitHub

[arm]fix concat axis < 0 compute error problem (#3802)

* fix concat axis < 0 error, test=develop

* fix format. test=develop
上级 30616633
...@@ -25,21 +25,21 @@ namespace mir { ...@@ -25,21 +25,21 @@ namespace mir {
void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
std::vector<std::string> act_types{"relu"}; std::vector<std::string> act_types{"relu"};
bool has_int8 = false; bool has_int8 = false;
bool has_arm_float = false; bool has_arm = false;
bool has_cuda = false; bool has_cuda = false;
for (auto& place : graph->valid_places()) { for (auto& place : graph->valid_places()) {
if (place.precision == PRECISION(kInt8)) { if (place.precision == PRECISION(kInt8)) {
has_int8 = true; has_int8 = true;
} }
if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) { if (place.target == TARGET(kARM)) {
has_arm_float = true; has_arm = true;
} }
if (place.target == TARGET(kCUDA)) { if (place.target == TARGET(kCUDA)) {
has_cuda = true; has_cuda = true;
} }
} }
if (!has_int8 && has_arm_float) { if (has_arm) {
act_types.push_back("relu6"); act_types.push_back("relu6");
act_types.push_back("leaky_relu"); act_types.push_back("leaky_relu");
} }
......
...@@ -52,11 +52,7 @@ void ConcatFunc(const std::vector<lite::Tensor*> inputs, ...@@ -52,11 +52,7 @@ void ConcatFunc(const std::vector<lite::Tensor*> inputs,
output_offset += in_stride[0]; output_offset += in_stride[0];
} }
} else { } else {
std::vector<lite::Tensor*> inputs_concat(inputs.size()); lite::arm::math::concat_func<T>(inputs, axis, out);
for (int j = 0; j < inputs.size(); ++j) {
inputs_concat[j] = inputs[j];
}
lite::arm::math::concat_func<T>(inputs_concat, axis, out);
} }
} }
...@@ -71,6 +67,9 @@ void ConcatCompute::Run() { ...@@ -71,6 +67,9 @@ void ConcatCompute::Run() {
auto* axis_tensor_data = axis_tensor->data<int>(); auto* axis_tensor_data = axis_tensor->data<int>();
axis = axis_tensor_data[0]; axis = axis_tensor_data[0];
} }
if (axis < 0) {
axis += inputs[0]->dims().size();
}
switch (inputs.front()->precision()) { switch (inputs.front()->precision()) {
case PRECISION(kFloat): case PRECISION(kFloat):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册