Commit 5963b4ba authored by: Z zhupengyang, committed by: GitHub

[NPU] add argmax op bridge and unit test (#2580)

test=develop
Parent ce89a79e
@@ -23,6 +23,7 @@ lite_cc_library(npu_bridge_square_op SRCS square_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_sqrt_op SRCS sqrt_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_reduce_mean_op SRCS reduce_mean_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_unsqueeze_op SRCS unsqueeze_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_argmax_op SRCS argmax_op.cc DEPS ${npu_bridge_deps})
set(npu_bridges
npu_bridge_registry
@@ -47,6 +48,7 @@ set(npu_bridges
npu_bridge_sqrt_op
npu_bridge_reduce_mean_op
npu_bridge_unsqueeze_op
npu_bridge_argmax_op
CACHE INTERNAL "npu_bridges")
set(npu_bridge_test_deps ${npu_bridges} ${npu_kernels} ${ops})
@@ -72,5 +74,6 @@ lite_cc_test(test_npu_bridge_square_op SRCS square_op_test.cc test_helper.cc DEP
lite_cc_test(test_npu_bridge_sqrt_op SRCS sqrt_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_reduce_mean_op SRCS reduce_mean_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_unsqueeze_op SRCS unsqueeze_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_argmax_op SRCS argmax_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
message(STATUS "+++++ npu_bridges: ${npu_bridges}")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
node_map_type ArgmaxConverter(const std::shared_ptr<lite::OpLite> argmax_op,
const node_map_type& inputs_map) {
auto scope = argmax_op->scope();
auto op_info = argmax_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "[NPU] Converting " + op_type + "...";
  // the axis attribute is stored as int64_t in the op desc
  int axis = static_cast<int>(op_info->GetAttr<int64_t>("axis"));
std::shared_ptr<ge::op::ArgMax> argmax_node =
std::make_shared<ge::op::ArgMax>(unique_op_type);
auto x_var_name = op_info->Input("X").front();
CHECK(inputs_map.count(x_var_name));
argmax_node->set_input_x1(*inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(argmax_node);
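  // HiAI's ge::op::ArgMax takes the axis as a second input (x2) rather than
  // as an attribute, so pack it into a one-element const tensor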
Tensor x2_t;
x2_t.Resize(std::vector<int64_t>{1});
auto x2_t_data = x2_t.mutable_data<int>();
x2_t_data[0] = axis;
auto x2 = std::make_shared<ge::op::Const>(unique_op_type + "/axis");
x2->set_attr_value(lite::npu::CvtTensor(&x2_t));
argmax_node->set_input_x2(*x2);
lite::npu::OpList::Global().add(x2);
  // The axis is already passed via the x2 input, and the NPU ArgMax only
  // supports output_type == int32, so these attributes are left unset:
  // argmax_node->set_attr_axis(axis);
  // argmax_node->set_attr_output_type(3);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = argmax_node;
return outputs_map;
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(arg_max,
paddle::lite::kernels::npu::bridges::ArgmaxConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/argmax_op.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
template <typename dtype>
void argmax_ref(const std::shared_ptr<operators::ArgmaxOpLite> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindTensor("x");
auto out = scope->FindMutableTensor("out_ref");
  int axis = static_cast<int>(op_info->GetAttr<int64_t>("axis"));
auto x_dims = x->dims();
if (axis < 0) {
axis += x_dims.size();
}
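  // the reduced axis is dropped from the output shape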
auto y_shape = x_dims.Vectorize();
y_shape.erase(y_shape.begin() + axis);
out->Resize(y_shape);
auto out_dims = out->dims();
auto x_data = x->data<dtype>();
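  // indices are stored as dtype (float in this test) and later compared
  // against the NPU's int32 output with EXPECT_NEAR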
auto out_data = out->mutable_data<dtype>();
const int size = x_dims[axis];
const int in_channel = x_dims.count(axis, x_dims.size());
const int out_channel = out_dims.count(axis, out_dims.size());
const int in_stride = x_dims.count(axis + 1, x_dims.size());
const int out_stride = x_dims.count(0, axis);
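  // n walks the dims before the axis, k the dims after it; consecutive
  // elements along the axis are in_stride apart in memory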
for (int n = 0; n < out_stride; n++) {
for (int k = 0; k < in_stride; k++) {
      const dtype* in_ptr = x_data + n * in_channel + k;
      std::vector<std::pair<dtype, int>> vec;
vec.resize(size);
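      // pair each value with its index so the argmax index survives the sort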
for (int i = 0; i < size; i++) {
vec[i] = std::make_pair(in_ptr[i * in_stride], i);
}
      // partial sort of only the first element: just enough to move the
      // maximum value (and its original index) to the front
      std::partial_sort(vec.begin(),
                        vec.begin() + 1,
                        vec.end(),
                        std::greater<std::pair<dtype, int>>());
      // write the index of the maximum to the reference output
      dtype* out_ptr = out_data + n * out_channel + k;
      *out_ptr = vec[0].second;
}
}
}
void test_argmax(const std::vector<int64_t>& input_shape, int axis) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.NewTensor(x_var_name);
auto* out = scope.NewTensor(out_var_name);
auto* out_ref = scope.NewTensor(out_ref_var_name);
x->Resize(input_shape);
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("arg_max");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("axis", static_cast<int64_t>(axis));
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::ArgmaxOpLite>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
// execute reference implementation and save to output tensor
argmax_ref<float>(op);
  // compare results: the NPU bridge outputs int32 indices while the
  // reference implementation stores them as float
  auto* out_data = out->mutable_data<int>();
  auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(NPUBridges, argmax) {
test_argmax({1, 2, 3, 4}, 1);
test_argmax({1, 2, 3, 4}, 2);
test_argmax({1, 2, 3, 4}, 3);
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(arg_max);
USE_NPU_BRIDGE(arg_max);
@@ -25,6 +25,7 @@ USE_NPU_BRIDGE(leaky_relu);
USE_NPU_BRIDGE(softsign);
USE_NPU_BRIDGE(hard_sigmoid);
USE_NPU_BRIDGE(arg_max);
USE_NPU_BRIDGE(batch_norm);
USE_NPU_BRIDGE(concat);
USE_NPU_BRIDGE(conv2d);