未验证 提交 05757e32 编写于 作者: B BUG1989 提交者: GitHub

support resize, scale op for TIM-VX (#653)

上级 78e2c2eb
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* Copyright (c) 2021, Open AI Lab
* Author: hhchen@openailab.com
*/
#include "timvx_executor.hpp"
extern "C"
{
#include "operator/op.h"
#include "resize_param.h"
}
bool VXEngine::AddResizeNode(struct node* ir_node)
{
struct graph* ir_graph = ir_node->graph;
struct tensor* output_tensor = get_ir_graph_tensor(ir_graph, ir_node->output_tensors[0]);
struct resize_param* param = (struct resize_param*)ir_node->op.param_mem;
tim::vx::ResizeType resize_type;
if (param->type == 0)
{
resize_type = tim::vx::ResizeType::NEAREST_NEIGHBOR;
}
else if(param->type == 1)
{
resize_type = tim::vx::ResizeType::BILINEAR;
}
else
{
TLOG_ERR("Tengine: VX does not support resize type(%d).\n", (int)resize_type);
}
std::vector<std::shared_ptr<tim::vx::Tensor> > add_in_tensor(ir_node->input_num);
for (int i = 0; i < ir_node->input_num; i++)
{
struct tensor* input_tensor = get_ir_graph_tensor(ir_graph, ir_node->input_tensors[i]);
add_in_tensor[i] = this->vx_tensor_map[input_tensor->index];
}
auto resize = graph->CreateOperation<tim::vx::ops::Resize>(resize_type, 0.0f, false, false, output_tensor->dims[2], output_tensor->dims[3]);
vx_node_map[ir_node->index] = resize;
(*resize)
.BindInputs(add_in_tensor)
.BindOutputs({ this->vx_tensor_map[output_tensor->index] });
return true;
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* License); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* Copyright (c) 2021, Open AI Lab
* Author: hhchen@openailab.com
*/
#include "timvx_executor.hpp"
extern "C"
{
#include "operator/op.h"
}
bool VXEngine::AddScaleNode(struct node* ir_node)
{
struct graph* ir_graph = ir_node->graph;
struct tensor* input_tensor = get_ir_graph_tensor(ir_graph, ir_node->input_tensors[0]);
struct tensor* output_tensor = get_ir_graph_tensor(ir_graph, ir_node->output_tensors[0]);
std::vector<uint32_t> perm;
for (int i = output_tensor->dim_num - 1; i >= 0; i--)
{
perm.push_back(output_tensor->dims[i]);
}
auto reshape = graph->CreateOperation<tim::vx::ops::Reshape>(perm);
vx_node_map[ir_node->index] = reshape;
(*reshape)
.BindInputs({ this->vx_tensor_map[input_tensor->index] })
.BindOutputs({ this->vx_tensor_map[output_tensor->index] });
return true;
}
......@@ -230,6 +230,12 @@ int VXEngine::Build(struct subgraph* subgraph)
case OP_RESHAPE:
this->AddReshapeNode(ir_node);
break;
case OP_RESIZE:
this->AddResizeNode(ir_node);
break;
case OP_SCALE:
this->AddScaleNode(ir_node);
break;
case OP_SIGMOID:
this->AddSigmoidNode(ir_node);
break;
......
......@@ -109,6 +109,8 @@ private:
bool AddReluNode(struct node* ir_node);
bool AddRelu1Node(struct node* ir_node);
bool AddReshapeNode(struct node* ir_node);
bool AddResizeNode(struct node* ir_node);
bool AddScaleNode(struct node* ir_node);
bool AddSigmoidNode(struct node* ir_node);
bool AddSliceNode(struct node* ir_node);
bool AddSoftmaxNode(struct node* ir_node);
......
......@@ -92,14 +92,14 @@ const int timvx_supported_ops[] = {
OP_RELU6,
// OP_REORG,
OP_RESHAPE,
// OP_RESIZE,
OP_RESIZE,
// OP_REVERSE,
// OP_RNN,
// OP_ROIALIGN,
// OP_ROIPOOLING,
// OP_ROUND,
// OP_RPN,
// OP_SCALE,
OP_SCALE,
// OP_SELU,
// OP_SHUFFLECHANNEL,
OP_SIGMOID,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册