Unverified commit b95d214f, authored by Leo, committed by GitHub

add argmax op (#73)

1. add argmax op

2. add transpose template function
Parent 116c4a2a
lite/kernels/mlu/bridges/CMakeLists.txt
@@ -20,6 +20,7 @@ lite_cc_library(subgraph_bridge_interp_op_mlu SRCS interpolate_op.cc DEPS ${subg
lite_cc_library(subgraph_bridge_concat_op_mlu SRCS concat_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_transpose_op_mlu SRCS transpose_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_dropout_op_mlu SRCS dropout_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_argmax_op_mlu SRCS argmax_op.cc DEPS ${subgraph_bridge_deps_mlu})
set(mlu_subgraph_bridges
subgraph_bridge_registry
subgraph_bridge_utility_mlu
@@ -36,6 +37,7 @@ set(mlu_subgraph_bridges
subgraph_bridge_interp_op_mlu
subgraph_bridge_concat_op_mlu
subgraph_bridge_dropout_op_mlu
subgraph_bridge_argmax_op_mlu
CACHE INTERNAL "mlu_subgraph_bridges")
@@ -57,6 +59,7 @@ lite_cc_test(test_interp_converter_mlu SRCS interpolate_op_test.cc DEPS scope op
lite_cc_test(test_concat_converter_mlu SRCS concat_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_transpose_converter_mlu SRCS transpose_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_dropout_converter_mlu SRCS dropout_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_argmax_converter_mlu SRCS argmax_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
if (LITE_BUILD_EXTRA)
lite_cc_test(test_lrn_converter_mlu SRCS lrn_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
endif()
lite/kernels/mlu/bridges/argmax_op.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int ArgmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
// Get input vars and op attributes
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto x_dims = x->dims().Vectorize();
auto out_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
  int axis = static_cast<int>(op_info->GetAttr<int64_t>("axis"));
  cnmlDimension_t argmax_mode = static_cast<cnmlDimension_t>(axis);
  auto mlu_output_dim = x->dims().Vectorize();
  // Shapes are expressed in NCHW order while the MLU data layout is NHWC;
  // argmax keeps the reduced axis as a dimension of size 1.
  mlu_output_dim[axis] = 1;
auto output_tensor = graph->AddNode(
out_var_name, mlu_output_dim, CNML_TENSOR, CNML_NCHW, graph->FPType());
CHECK(graph->HasNode(x_var_name));
auto input_tensor = graph->GetNode(x_var_name);
cnmlBaseOp_t argmax_op{nullptr};
CNML_CALL(cnmlCreateArgmaxOp(&argmax_op,
argmax_mode,
input_tensor->mlu_tensor(),
output_tensor->mlu_tensor()));
graph->FuseOp(argmax_op);
CNML_CALL(cnmlDestroyBaseOp(&argmax_op));
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(argmax,
kMLU,
paddle::lite::subgraph::mlu::ArgmaxConverter);
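The converter's only shape computation is the keep-dim rule above: the argmax axis collapses to extent 1 on the MLU output tensor while every other dimension passes through. A minimal standalone sketch of that rule (the helper name is hypothetical, not part of the bridge):

#include <cstdint>
#include <vector>

// Keep-dim output shape for argmax over `axis`, mirroring
// `mlu_output_dim[axis] = 1` in ArgmaxConverter above.
std::vector<int64_t> ArgmaxOutShape(std::vector<int64_t> dims, int axis) {
  dims[axis] = 1;  // e.g. {1, 2, 3, 4} with axis = 1 -> {1, 1, 3, 4}
  return dims;
}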
lite/kernels/mlu/bridges/argmax_op_test.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/argmax_op.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
template <typename dtype, typename out_dtype>
void argmax_ref(const std::shared_ptr<operators::ArgmaxOpLite> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
int axis = op_info->GetAttr<int64_t>("axis");
auto x_dims = x->dims();
if (axis < 0) {
axis += x_dims.size();
}
auto y_shape = x_dims.Vectorize();
y_shape.erase(y_shape.begin() + axis);
out->Resize(y_shape);
auto out_dims = out->dims();
auto* x_data = x->mutable_data<dtype>();
auto* out_data = out->mutable_data<out_dtype>();
const int size = x_dims[axis];
const int in_channel = x_dims.count(axis, x_dims.size());
const int out_channel = out_dims.count(axis, out_dims.size());
const int in_stride = x_dims.count(axis + 1, x_dims.size());
const int out_stride = x_dims.count(0, axis);
for (int n = 0; n < out_stride; n++) {
for (int k = 0; k < in_stride; k++) {
      const dtype* in_ptr = x_data + n * in_channel + k;
      std::vector<std::pair<dtype, int>> vec;
vec.resize(size);
for (int i = 0; i < size; i++) {
vec[i] = std::make_pair(in_ptr[i * in_stride], i);
}
      // Partial sort: only the largest element is needed.
      std::partial_sort(vec.begin(),
                        vec.begin() + 1,
                        vec.end(),
                        std::greater<std::pair<dtype, int>>());
out_dtype* out_ptr = out_data + n * out_channel + k;
*out_ptr = vec[0].second;
}
}
}
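To make the stride arithmetic in argmax_ref concrete: for input_shape {1, 2, 3, 4} and axis = 1, size = 2, in_channel = 24, out_channel = 12, in_stride = 12, and out_stride = 1, so the single outer iteration visits 12 inner positions, each comparing 2 elements spaced 12 floats apart. A self-contained sketch of that inner step, using the same partial_sort idiom (the function name is illustrative):

#include <algorithm>
#include <functional>
#include <utility>
#include <vector>

// Index of the maximum among `size` elements spaced `stride` apart.
int ArgmaxOfStridedSlice(const float* base, int size, int stride) {
  std::vector<std::pair<float, int>> vec(size);
  for (int i = 0; i < size; ++i) vec[i] = {base[i * stride], i};
  // Only the largest element matters, so sort just the first position.
  std::partial_sort(vec.begin(), vec.begin() + 1, vec.end(),
                    std::greater<std::pair<float, int>>());
  return vec[0].second;
}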
void test_argmax(const std::vector<int64_t>& input_shape, int axis) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize(input_shape);
// initialize input&output data
FillTensor<float, float>(x, -9, 9);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("argmax");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("axis", static_cast<int64_t>(axis));
// create and convert op to MLU model, then run it on MLU
auto op = CreateOp<operators::ArgmaxOpLite>(opdesc, &scope);
argmax_ref<float, int>(op);
out_ref->CopyDataFrom(*out);
Tensor input_x;
input_x.Resize(DDim(input_shape));
// change input layout from NCHW to NHWC
transpose<float*>(x->mutable_data<float>(),
input_x.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]),
static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3])},
{0, 2, 3, 1});
x->CopyDataFrom(input_x);
LaunchOp(op, {x_var_name}, {out_var_name});
auto* out_data = out->mutable_data<int>();
auto* out_ref_data = out_ref->mutable_data<int>();
std::vector<int64_t> out_shape = input_shape;
out_shape[axis] = 1;
Tensor output_trans;
output_trans.Resize(out_shape);
// Change output layout from NHWC to NCHW
transpose<int*>(out_data,
output_trans.mutable_data<int>(),
{static_cast<int>(out_shape[0]),
static_cast<int>(out_shape[2]),
static_cast<int>(out_shape[3]),
static_cast<int>(out_shape[1])},
{0, 3, 1, 2});
out_data = output_trans.mutable_data<int>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(MLUBridges, argmax) {
test_argmax({1, 2, 3, 4}, 1);
test_argmax({1, 2, 3, 4}, 2);
test_argmax({1, 2, 3, 4}, 3);
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(argmax, kMLU);
lite/kernels/mlu/bridges/paddle_use_bridges.h
@@ -31,6 +31,7 @@ USE_SUBGRAPH_BRIDGE(scale, kMLU);
USE_SUBGRAPH_BRIDGE(sigmoid, kMLU);
USE_SUBGRAPH_BRIDGE(elementwise_mul, kMLU);
USE_SUBGRAPH_BRIDGE(dropout, kMLU);
USE_SUBGRAPH_BRIDGE(argmax, kMLU);
#ifdef LITE_BUILD_EXTRA
USE_SUBGRAPH_BRIDGE(lrn, kMLU);
#endif
lite/kernels/mlu/bridges/utility.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/kernels/mlu/bridges/utility.h"
#include <utility>
namespace paddle {
...
lite/kernels/mlu/bridges/utility.h
@@ -16,9 +16,11 @@
#include <cnml.h>
#include <cnrt.h>
#include <memory>
#include <string>
#include <vector>
#include "lite/backends/mlu/mlu_utils.h"
#include "lite/core/op_lite.h"
#include "lite/core/tensor.h"
@@ -32,6 +34,38 @@ namespace mlu {
void transpose2d(float* input_data,
float* output_data,
std::vector<int> input_shape);
// 4-D transpose: permutes `input_data` of `input_shape` by `axis` into
// `output_data`; dtype is expected to be a pointer type (e.g. float*).
template <typename dtype>
void transpose(dtype input_data,
               dtype output_data,
               std::vector<int> input_shape,
               std::vector<int> axis);
template <typename dtype>
void transpose(dtype input_data,
dtype output_data,
std::vector<int> input_shape,
std::vector<int> axis) {
  int old_index = -1;
  int new_index = -1;
  int dim[4] = {0};  // current coordinate in the original (4-D) shape
  std::vector<int> shape = input_shape;
for (dim[0] = 0; dim[0] < input_shape[0]; dim[0]++) {
for (dim[1] = 0; dim[1] < input_shape[1]; dim[1]++) {
for (dim[2] = 0; dim[2] < input_shape[2]; dim[2]++) {
for (dim[3] = 0; dim[3] < input_shape[3]; dim[3]++) {
          // Row-major offset in the original layout.
          old_index = dim[0] * shape[1] * shape[2] * shape[3] +
                      dim[1] * shape[2] * shape[3] + dim[2] * shape[3] + dim[3];
          // Row-major offset in the permuted layout, whose k-th extent is
          // shape[axis[k]].
          new_index =
              dim[axis[0]] * shape[axis[1]] * shape[axis[2]] * shape[axis[3]] +
              dim[axis[1]] * shape[axis[2]] * shape[axis[3]] +
              dim[axis[2]] * shape[axis[3]] + dim[axis[3]];
output_data[new_index] = input_data[old_index];
}
}
}
}
}
void transpose(float* input_data,
float* output_data,
std::vector<int> input_shape,
...
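The argmax test above instantiates the template with dtype = float* to move buffers between layouts. A hypothetical wrapper showing the common NCHW-to-NHWC case (the wrapper name is illustrative, and the transpose template above is assumed to be in scope):

#include <vector>

// NCHW -> NHWC for a 4-D float buffer: the permutation {0, 2, 3, 1}
// moves the channel axis last.
void NchwToNhwc(float* src, float* dst, int n, int c, int h, int w) {
  transpose<float*>(src, dst, {n, c, h, w}, {0, 2, 3, 1});
}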