Unverified commit 81774daf authored by MaxwellDing, committed by GitHub

[MLU] feat: add kernels, test=develop (#3915)

Add MLU subgraph bridges for arg_max, flatten/flatten2, slice, and transpose/transpose2, with unit tests and CMake/bridge registrations.
Parent 55db1963
@@ -18,12 +18,16 @@ lite_cc_library(subgraph_bridge_fc_op_mlu SRCS fc_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_scale_op_mlu SRCS scale_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_interp_op_mlu SRCS interpolate_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_concat_op_mlu SRCS concat_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_transpose_op_mlu SRCS transpose_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_dropout_op_mlu SRCS dropout_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_slice_op_mlu SRCS slice_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_split_op_mlu SRCS split_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_cast_op_mlu SRCS cast_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_layout_op_mlu SRCS layout_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_argmax_op_mlu SRCS argmax_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_squeeze_op_mlu SRCS squeeze_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_reshape_op_mlu SRCS reshape_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_flatten_op_mlu SRCS flatten_op.cc DEPS ${subgraph_bridge_deps_mlu})
set(mlu_subgraph_bridges
subgraph_bridge_registry
subgraph_bridge_utility_mlu
@@ -34,16 +38,20 @@ set(mlu_subgraph_bridges
subgraph_bridge_pool_op_mlu
subgraph_bridge_softmax_op_mlu
subgraph_bridge_fc_op_mlu
subgraph_bridge_transpose_op_mlu
subgraph_bridge_batch_norm_op_mlu
subgraph_bridge_scale_op_mlu
subgraph_bridge_interp_op_mlu
subgraph_bridge_concat_op_mlu
subgraph_bridge_dropout_op_mlu
subgraph_bridge_slice_op_mlu
subgraph_bridge_split_op_mlu
subgraph_bridge_cast_op_mlu
subgraph_bridge_layout_op_mlu
subgraph_bridge_argmax_op_mlu
subgraph_bridge_squeeze_op_mlu
subgraph_bridge_reshape_op_mlu
subgraph_bridge_flatten_op_mlu
CACHE INTERNAL "mlu_subgraph_bridges")
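# exported with CACHE INTERNAL so the converter tests below and sibling
# directories can link the full bridge list at once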
@@ -58,10 +66,14 @@ lite_cc_test(test_fc_converter_mlu SRCS fc_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_scale_converter_mlu SRCS scale_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_interp_converter_mlu SRCS interpolate_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_concat_converter_mlu SRCS concat_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_transpose_converter_mlu SRCS transpose_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_dropout_converter_mlu SRCS dropout_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_slice_converter_mlu SRCS slice_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_split_converter_mlu SRCS split_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_layout_converter_mlu SRCS layout_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_cast_converter_mlu SRCS cast_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_argmax_converter_mlu SRCS argmax_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_squeeze_converter_mlu SRCS squeeze_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_reshape_converter_mlu SRCS reshape_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_flatten_converter_mlu SRCS flatten_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
message(STATUS "+++++ mlu_subgraph_bridges: ${mlu_subgraph_bridges}")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int ArgmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
// Get input vars and op attributes
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto x_dims = x->dims().Vectorize();
auto out_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
int axis = static_cast<int>(op_info->GetAttr<int64_t>("axis"));
if (axis < 0) {
axis = axis + x_dims.size();
}
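// NOTE: casting the NCHW axis straight to cnmlDimension_t assumes the enum
// orders its values as N, C, H, W (0..3); that matches how the code uses it
// below, but it is an assumption about the CNML headers, not a documented fact.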
cnmlDimension_t argmax_mode = static_cast<cnmlDimension_t>(axis);
auto mlu_output_dim = x->dims().Vectorize();
// dims are expressed in NCHW order although the MLU buffer layout is NHWC;
// the reduced axis is kept as a size-1 dim on the MLU side
mlu_output_dim[axis] = 1;
CHECK(graph->HasNode(x_var_name));
auto input_tensor = graph->GetNode(x_var_name);
// if use_fp16 and axis is not c, cast input datatype from fp16 to fp32, so
// output datatype is int32
bool cast_to_fp32 =
graph->FPType() == CNML_DATA_FLOAT16 && argmax_mode != CNML_DIM_C;
cnmlBaseOp_t cast_op{nullptr};
std::shared_ptr<MLUTensor> fp32_input_tensor;
if (cast_to_fp32) {
fp32_input_tensor = graph->AddNode(x_var_name + ".fp32",
x_dims,
CNML_TENSOR,
CNML_NCHW,
CNML_DATA_FLOAT32);
CNML_CALL(cnmlCreateCastOp(&cast_op,
CNML_CAST_FLOAT16_TO_FLOAT32,
input_tensor->mlu_tensor(),
fp32_input_tensor->mlu_tensor()));
}
auto output_tensor = graph->AddNode(
out_var_name, mlu_output_dim, CNML_TENSOR, CNML_NCHW, CNML_DATA_INT32);
cnmlBaseOp_t argmax_op{nullptr};
// ======================= DEBUG INFO =====================
VLOG(6) << "x_var_name: " << x_var_name;
VLOG(6) << "out_var_name: " << out_var_name;
VLOG(6) << "x dims: " << x->dims();
VLOG(6) << "output dims: " << output->dims();
VLOG(6) << "axis: " << axis;
VLOG(6) << "cast_to_fp32: " << cast_to_fp32;
cnmlPrintTensor(input_tensor->mlu_tensor(), CNML_TENSOR);
cnmlPrintTensor(output_tensor->mlu_tensor(), CNML_TENSOR);
// ======================= DEBUG END =====================
CNML_CALL(cnmlCreateArgmaxOp(&argmax_op,
argmax_mode,
cast_to_fp32 ? fp32_input_tensor->mlu_tensor()
: input_tensor->mlu_tensor(),
output_tensor->mlu_tensor()));
if (cast_to_fp32) {
graph->FuseOp(cast_op);
}
graph->FuseOp(argmax_op);
CNML_CALL(cnmlDestroyBaseOp(&argmax_op));
if (cast_op) {
CNML_CALL(cnmlDestroyBaseOp(&cast_op));
}
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(arg_max,
kMLU,
paddle::lite::subgraph::mlu::ArgmaxConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/argmax_op.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
template <typename dtype, typename out_dtype>
void argmax_ref(const std::shared_ptr<operators::ArgmaxOpLite> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
int axis = static_cast<int>(op_info->GetAttr<int64_t>("axis"));
auto x_dims = x->dims();
if (axis < 0) {
axis += x_dims.size();
}
auto y_shape = x_dims.Vectorize();
y_shape.erase(y_shape.begin() + axis);
out->Resize(y_shape);
auto out_dims = out->dims();
auto* x_data = x->mutable_data<dtype>();
auto* out_data = out->mutable_data<out_dtype>();
const int size = x_dims[axis];
const int in_channel = x_dims.count(axis, x_dims.size());
const int out_channel = out_dims.count(axis, out_dims.size());
const int in_stride = x_dims.count(axis + 1, x_dims.size());
const int out_stride = x_dims.count(0, axis);
// n walks the dims before the reduced axis, k the dims after it; each
// (n, k) pair addresses one output element
for (int n = 0; n < out_stride; n++) {
for (int k = 0; k < in_stride; k++) {
const dtype* in_ptr = x_data + n * in_channel + k;
std::vector<std::pair<dtype, int>> vec;
vec.resize(size);
for (int i = 0; i < size; i++) {
vec[i] = std::make_pair(in_ptr[i * in_stride], i);
}
// partial_sort moves the largest (value, index) pair to the front
std::partial_sort(vec.begin(),
vec.begin() + 1,
vec.end(),
std::greater<std::pair<dtype, int>>());
out_dtype* out_ptr = out_data + n * out_channel + k;
*out_ptr = vec[0].second;
}
}
}
void test_argmax(const std::vector<int64_t>& input_shape, int axis) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize(input_shape);
// initialize input&output data
FillTensor<float, float>(x, -9, 9);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("arg_max");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("axis", static_cast<int64_t>(axis));
// create and convert op to MLU model, then run it on MLU
auto op = CreateOp<operators::ArgmaxOpLite>(opdesc, &scope);
argmax_ref<float, int>(op);
out_ref->CopyDataFrom(*out);
Tensor input_x;
input_x.Resize(DDim(input_shape));
// change input layout from NCHW to NHWC
transpose<float>(x->mutable_data<float>(),
input_x.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]),
static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3])},
{0, 2, 3, 1});
x->CopyDataFrom(input_x);
LaunchOp(op, {x_var_name}, {out_var_name});
auto* out_data = out->mutable_data<int>();
auto* out_ref_data = out_ref->mutable_data<int>();
std::vector<int64_t> out_shape = input_shape;
out_shape[axis] = 1;
Tensor output_trans;
output_trans.Resize(out_shape);
// Change output layout from NHWC to NCHW
transpose<int>(out_data,
output_trans.mutable_data<int>(),
{static_cast<int>(out_shape[0]),
static_cast<int>(out_shape[2]),
static_cast<int>(out_shape[3]),
static_cast<int>(out_shape[1])},
{0, 3, 1, 2});
out_data = output_trans.mutable_data<int>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(MLUBridges, arg_max) {
test_argmax({1, 2, 3, 4}, 1);
test_argmax({1, 2, 3, 4}, 2);
test_argmax({1, 2, 3, 4}, 3);
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(arg_max, kMLU);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int FlattenConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
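// The MLU buffer holds data in NHWC order while flatten semantics are
// defined on the NCHW shape, so the op is lowered as a three-op sandwich:
// transpose NHWC->NCHW, reshape to the flattened dims, transpose back.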
// ================== Trans1: NHWC => NCHW ===========================
auto input_tensor = graph->GetNode(x_var_name);
auto trans_1_axis = GetAxisNHWC2NCHW<int>(x->dims().size());
auto trans1_out = graph->AddNode(x_var_name + ".trans.i",
x->dims().Vectorize(),
CNML_TENSOR,
CNML_NCHW,
graph->FPType(),
CNML_NCHW);
cnmlBaseOp_t trans1_op{nullptr};
cnmlNdTransposeOpParam_t trans1_param{nullptr};
CNML_CALL(cnmlCreateNdTransposeOpParam(
&trans1_param, trans_1_axis.data(), trans_1_axis.size()));
CNML_CALL(cnmlCreateNdTransposeProOp(&trans1_op,
input_tensor->mlu_tensor(),
trans1_out->mlu_tensor(),
trans1_param));
// ======================== Trans1 End ==================================
// ======================= Flatten op ===================================
cnmlBaseOp_t flatten_op{nullptr};
auto trans2_input = graph->AddNode(out_var_name + ".trans.o",
output_dims,
CNML_TENSOR,
CNML_NCHW,
graph->FPType(),
CNML_NCHW);
int cnml_trans2_input_shape[4];
CNML_CALL(
cnmlGetTensorShape(trans2_input->mlu_tensor(), cnml_trans2_input_shape));
cnmlReshapeOpParam_t reshape_param{nullptr};
CNML_CALL(cnmlCreateNdReshapeOpParam(
&reshape_param, cnml_trans2_input_shape, output->dims().size()));
// Use cnmlCreatexxxOpForward to create op.
CNML_CALL(cnmlCreateReshapeOp(&flatten_op,
reshape_param,
trans1_out->mlu_tensor(),
trans2_input->mlu_tensor()));
// ======================= Flatten End ===================================
// ================== Trans2: NCHW => NHWC ===============================
auto trans_2_axis = GetAxisNCHW2NHWC<int>(output->dims().size());
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
cnmlBaseOp_t trans2_op{nullptr};
cnmlNdTransposeOpParam_t trans2_param{nullptr};
CNML_CALL(cnmlCreateNdTransposeOpParam(
&trans2_param, trans_2_axis.data(), trans_2_axis.size()));
CNML_CALL(cnmlCreateNdTransposeProOp(&trans2_op,
trans2_input->mlu_tensor(),
output_tensor->mlu_tensor(),
trans2_param));
// ======================== Trans2 End ==================================
// ============== DEBUG LOG ===============
VLOG(6) << "x_var_name: " << x_var_name;
VLOG(6) << "out_var_name: " << out_var_name;
VLOG(6) << "input dim: " << x->dims();
VLOG(6) << "output dim: " << output->dims();
// cnmlPrintTensor(input_tensor->mlu_tensor(), CNML_TENSOR);
// cnmlPrintTensor(trans1_out->mlu_tensor(), CNML_TENSOR);
// cnmlPrintTensor(trans2_input->mlu_tensor(), CNML_TENSOR);
// cnmlPrintTensor(output_tensor->mlu_tensor(), CNML_TENSOR);
// ============== DEBUG END ===============
graph->FuseOp(trans1_op);
graph->FuseOp(flatten_op);
graph->FuseOp(trans2_op);
CNML_CALL(cnmlDestroyBaseOp(&trans1_op));
CNML_CALL(cnmlDestroyBaseOp(&flatten_op));
CNML_CALL(cnmlDestroyBaseOp(&trans2_op));
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(flatten,
kMLU,
paddle::lite::subgraph::mlu::FlattenConverter);
REGISTER_SUBGRAPH_BRIDGE(flatten2,
kMLU,
paddle::lite::subgraph::mlu::FlattenConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/flatten_op.h"
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
void test_flatten(std::vector<int64_t> input_shape, int axis) {
// prepare input&output variables
Scope scope;
std::string x_var_name("x");
std::string out_var_name("out");
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
x->Resize(input_shape);
Tensor x_cpu;
// initialize input&output data
FillTensor<float, int>(x);
x_cpu.CopyDataFrom(*x);
Tensor input_trans;
input_trans.Resize(input_shape);
transpose(x->mutable_data<float>(),
input_trans.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]),
static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3])},
{0, 2, 3, 1});
x->CopyDataFrom(input_trans);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("flatten2");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr<int>("axis", axis);
auto op = CreateOp<operators::FlattenOp>(opdesc, &scope);
LaunchOp(op, {x_var_name}, {out_var_name});
// compare results
auto* out_data = out->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], x_cpu.mutable_data<float>()[i], 1e-5);
}
}
TEST(MLUBridges, flatten) { test_flatten({1, 2, 4, 4}, 2); }
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(flatten, kMLU);
USE_SUBGRAPH_BRIDGE(flatten2, kMLU);
@@ -25,15 +25,21 @@ USE_SUBGRAPH_BRIDGE(batch_norm, kMLU);
USE_SUBGRAPH_BRIDGE(fc, kMLU);
USE_SUBGRAPH_BRIDGE(nearest_interp, kMLU);
USE_SUBGRAPH_BRIDGE(leaky_relu, kMLU);
USE_SUBGRAPH_BRIDGE(transpose, kMLU);
USE_SUBGRAPH_BRIDGE(transpose2, kMLU);
USE_SUBGRAPH_BRIDGE(concat, kMLU);
USE_SUBGRAPH_BRIDGE(scale, kMLU);
USE_SUBGRAPH_BRIDGE(sigmoid, kMLU);
USE_SUBGRAPH_BRIDGE(elementwise_mul, kMLU);
USE_SUBGRAPH_BRIDGE(dropout, kMLU);
USE_SUBGRAPH_BRIDGE(arg_max, kMLU);
USE_SUBGRAPH_BRIDGE(split, kMLU);
USE_SUBGRAPH_BRIDGE(cast, kMLU);
USE_SUBGRAPH_BRIDGE(layout, kMLU);
USE_SUBGRAPH_BRIDGE(slice, kMLU);
USE_SUBGRAPH_BRIDGE(squeeze, kMLU);
USE_SUBGRAPH_BRIDGE(squeeze2, kMLU);
USE_SUBGRAPH_BRIDGE(flatten, kMLU);
USE_SUBGRAPH_BRIDGE(flatten2, kMLU);
USE_SUBGRAPH_BRIDGE(reshape, kMLU);
USE_SUBGRAPH_BRIDGE(reshape2, kMLU);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int SliceConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
VLOG(3) << "[MLU] Converting " + op_type + "...";
// input
auto input_var_name = op_info->Input("Input").front();
auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
auto input_shape = input->dims().Vectorize();
// output
auto output_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
// attr
auto axes = op_info->GetAttr<std::vector<int32_t>>("axes");
auto starts = op_info->GetAttr<std::vector<int32_t>>("starts");
auto ends = op_info->GetAttr<std::vector<int32_t>>("ends");
CHECK(graph->HasNode(input_var_name));
auto input_tensor = graph->GetNode(input_var_name);
auto output_tensor = graph->AddNode(output_var_name,
output->dims().Vectorize(),
CNML_TENSOR,
CNML_NCHW,
graph->FPType());
std::vector<int32_t> begin_index(input_shape.size(), 0);
std::vector<int32_t> end_index(input_shape.size());
// strides are fixed at 1: this version of the op only has axes/starts/ends
std::vector<int32_t> strides(input_shape.size(), 1);
auto nhwc2nchw_axis = GetAxisNHWC2NCHW<int>(input_shape.size());
for (size_t i = 0; i < input_shape.size(); ++i) {
end_index[nhwc2nchw_axis[i]] = input_shape[i];
}
for (size_t i = 0; i < axes.size(); i++) {
int dim_value = input_shape[axes[i]];
int end = ends[i] < 0 ? std::max(ends[i] + dim_value, 0) : ends[i];
begin_index[nhwc2nchw_axis[axes[i]]] =
starts[i] < 0 ? std::max(starts[i] + dim_value, 0) : starts[i];
end_index[nhwc2nchw_axis[axes[i]]] = std::min(end, dim_value);
}
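// Worked example (assuming GetAxisNHWC2NCHW<int>(4) returns {0, 3, 1, 2}):
// an NCHW input of {3, 4, 5, 6} with axes = {2}, starts = {2}, ends = {-1}
// maps, in the NHWC index space of the MLU tensor, to
// begin_index = {0, 2, 0, 0} and end_index = {3, 4, 6, 4}.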
cnmlNdStridedSliceOpParam_t param{nullptr};
cnmlBaseOp_t slice_op{nullptr};
CNML_CALL(cnmlCreateNdStridedSliceOpParam(&param,
input_shape.size(),
begin_index.data(),
end_index.data(),
strides.data()));
CNML_CALL(cnmlCreateNdStridedSliceOp(&slice_op,
param,
input_tensor->mlu_tensor(),
output_tensor->mlu_tensor()));
CNML_CALL(cnmlDestroyNdStridedSliceOpParam(&param));
graph->FuseOp(slice_op);
CNML_CALL(cnmlDestroyBaseOp(&slice_op));
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(slice,
kMLU,
paddle::lite::subgraph::mlu::SliceConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/slice_op.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <utility>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
static void slice_ref(const float* input,
std::vector<int64_t> in_dims,
std::vector<int> axes,
std::vector<int> starts,
std::vector<int> ends,
float* out) {
auto out_dims = in_dims;
std::vector<int> real_starts(in_dims.size(), 0);
std::vector<int> real_ends(in_dims.size(), 0);
std::vector<int> real_step(in_dims.size(), 0);
for (size_t i = 0; i < in_dims.size(); i++) {
real_ends[i] = in_dims[i];
}
for (size_t i = 0; i < axes.size(); i++) {
int dim_value = in_dims[axes[i]];
if (dim_value > 0) {
int start = starts[i] < 0 ? (starts[i] + dim_value) : starts[i];
int end = ends[i] < 0 ? (ends[i] + dim_value) : ends[i];
start = std::max(start, 0);
end = std::max(end, 0);
end = std::min(end, dim_value);
out_dims[axes[i]] = end - start;
real_starts[axes[i]] = start;
real_ends[axes[i]] = end;
}
}
// use std::vector instead of variable-length arrays, which are non-standard C++
std::vector<int> dst_step(in_dims.size(), 1);
std::vector<int> src_step(in_dims.size(), 1);
int out_num = out_dims[in_dims.size() - 1];
for (int i = in_dims.size() - 2; i >= 0; i--) {
dst_step[i] = out_dims[i + 1] * dst_step[i + 1];
src_step[i] = in_dims[i + 1] * src_step[i + 1];
out_num *= out_dims[i];
}
for (int dst_id = 0; dst_id < out_num; dst_id++) {
int src_id = 0;
int index_id = dst_id;
for (size_t j = 0; j < out_dims.size(); j++) {
int cur_id = index_id / dst_step[j];
index_id = index_id % dst_step[j];
src_id += (cur_id + real_starts[j]) * src_step[j];
}
out[dst_id] = input[src_id];
}
}
static void test_case(std::vector<int64_t> x_shape,
std::vector<int64_t> out_shape,
std::vector<int> starts,
std::vector<int> ends,
std::vector<int> axes) {
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
auto* x = scope.NewTensor(x_var_name);
auto* out = scope.NewTensor(out_var_name);
x->Resize(lite::DDim(x_shape));
out->Resize(lite::DDim(out_shape));
auto x_data = x->mutable_data<float>();
FillTensor<float, float>(x, 0.f, 2.f);
cpp::OpDesc opdesc;
opdesc.SetType("slice");
opdesc.SetInput("Input", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("axes", axes);
opdesc.SetAttr("starts", starts);
opdesc.SetAttr("ends", ends);
std::vector<float> out_ref(out->data_size(), 0);
slice_ref(x_data, x_shape, axes, starts, ends, out_ref.data());
auto type_cast = [](int64_t in) { return static_cast<int>(in); };
std::vector<int> i_dims;
std::transform(
x_shape.cbegin(), x_shape.cend(), std::back_inserter(i_dims), type_cast);
auto nchw2nhwc_axis = GetAxisNCHW2NHWC<int>(x_shape.size());
Tensor input_x;
input_x.Resize(x->dims());
transpose<float>(x->mutable_data<float>(),
input_x.mutable_data<float>(),
i_dims,
nchw2nhwc_axis);
x->CopyDataFrom(input_x);
auto op = CreateOp<operators::SliceOp>(opdesc, &scope);
LaunchOp(op, {x_var_name}, {out_var_name});
Tensor output_trans;
auto os = out->dims().Vectorize();
output_trans.Resize(os);
std::vector<int> o_dims(os.size());
for (size_t i = 0; i < os.size(); ++i) {
o_dims[i] = os[nchw2nhwc_axis[i]];
}
transpose<float>(out->mutable_data<float>(),
output_trans.mutable_data<float>(),
o_dims,
GetAxisNHWC2NCHW<int>(x_shape.size()));
auto out_data = output_trans.mutable_data<float>();
for (DDim::value_type i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_ref[i], out_data[i], 1e-4);
}
}
TEST(MLUBridges, slice) {
/* test_case({3}, {3}, {-3}, {3}, {0}); */
test_case({3, 4}, {3, 4}, {-3, 0}, {3, 100}, {0, 1});
test_case({3, 4, 5}, {3, 4, 2}, {-3, 0, 2}, {3, 100, -1}, {0, 1, 2});
test_case({3, 4, 5, 6}, {3, 4, 2, 6}, {-3, 0, 2}, {3, 100, -1}, {0, 1, 2});
/* test_case({3, 4, 5, 6, 3}, {3, 4, 2, 6, 3}, {-3, 0, 2}, {3, 100, -1}, {0,
* 1, 2}); */
/* test_case({3, 4, 5, 6, 5, 2}, {3, 4, 2, 6, 5, 2}, {-3, 0, 2}, {3, 100, 1},
* {0, 1, 2}); */
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(slice, kMLU);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
std::vector<int> axis_to_nhwc(const std::vector<int>& axis) {
std::vector<int> new_axis(axis.size());
auto nhwc2nchw_axis = GetAxisNHWC2NCHW<int>(axis.size());
auto nchw2nhwc_axis = GetAxisNCHW2NHWC<int>(axis.size());
for (size_t i = 0; i < new_axis.size(); ++i) {
new_axis[i] = nhwc2nchw_axis[axis[nchw2nhwc_axis[i]]];
}
return new_axis;
}
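// Worked example (assuming GetAxisNCHW2NHWC<int>(4) returns {0, 2, 3, 1} and
// GetAxisNHWC2NCHW<int>(4) returns {0, 3, 1, 2}): the NCHW-space permutation
// axis = {0, 1, 3, 2} (swap H and W) remaps to new_axis = {0, 2, 1, 3},
// i.e. the same H/W swap expressed on an NHWC-laid-out tensor.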
int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
// Get input vars and op attributes
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto x_dims = x->dims().Vectorize();
auto out_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
auto axis = op_info->GetAttr<std::vector<int>>("axis");
std::vector<int> axis_nhwc = axis_to_nhwc(axis);
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
CHECK(graph->HasNode(x_var_name));
auto input_tensor = graph->GetNode(x_var_name);
cnmlBaseOp_t transpose_op{nullptr};
cnmlNdTransposeOpParam_t transpose_param{nullptr};
CNML_CALL(cnmlCreateNdTransposeOpParam(
&transpose_param, axis_nhwc.data(), axis_nhwc.size()));
// Use cnmlCreatexxxOpForward to create op.
CNML_CALL(cnmlCreateNdTransposeProOp(&transpose_op,
input_tensor->mlu_tensor(),
output_tensor->mlu_tensor(),
transpose_param));
graph->FuseOp(transpose_op);
CNML_CALL(cnmlDestroyBaseOp(&transpose_op));
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(transpose,
kMLU,
paddle::lite::subgraph::mlu::TransposeConverter);
REGISTER_SUBGRAPH_BRIDGE(transpose2,
kMLU,
paddle::lite::subgraph::mlu::TransposeConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/transpose_op.h"
#include <gtest/gtest.h>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
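// row-major offset of position {n, c, h, w} within 4-D NCHW dims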
int data_index(std::vector<int> pos, DDimLite dims) {
int d1 = dims[1];
int d2 = dims[2];
int d3 = dims[3];
return pos[3] + pos[2] * d3 + pos[1] * d3 * d2 + pos[0] * d3 * d2 * d1;
}
std::vector<int> pos_trans(std::vector<int> in_pos, std::vector<int> axis) {
std::vector<int> out_pos(in_pos.size());
for (size_t i = 0; i < axis.size(); i++) {
out_pos[axis[i]] = in_pos[i];
}
return out_pos;
}
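// e.g. with axis = {0, 1, 3, 2} an input position {n, c, h, w} maps to the
// output position {n, c, w, h}: out_pos[axis[i]] = in_pos[i] scatters each
// input coordinate to the axis it moves to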
template <typename dtype>
void transpose_ref(const std::shared_ptr<operators::TransposeOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto input =
scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto output =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto x_dims = input->dims();
auto y_dims = output->dims();
auto axis = op_info->GetAttr<std::vector<int>>("axis");
auto* input_data = input->mutable_data<dtype>();
auto* output_data = output->mutable_data<dtype>();
int input_n = x_dims[0];
int input_c = x_dims[1];
int input_h = x_dims[2];
int input_w = x_dims[3];
for (int n = 0; n < input_n; ++n) {
for (int c = 0; c < input_c; ++c) {
for (int h = 0; h < input_h; ++h) {
for (int w = 0; w < input_w; ++w) {
std::vector<int> in_pos{n, c, h, w};
std::vector<int> out_pos = pos_trans(in_pos, axis);
int in_index = data_index(in_pos, x_dims);
int out_index = data_index(out_pos, y_dims);
output_data[out_index] = input_data[in_index];
}
}
}
}
}
void test_transpose(const std::vector<int64_t>& input_shape,
std::vector<int> axis) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize(input_shape);
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("transpose");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("axis", axis);
// create and convert op to MLU model, then run it on MLU
auto op = CreateOp<operators::TransposeOp>(opdesc, &scope);
// transpose_ref must run before LaunchOp, otherwise it fails with a
// "Cannot access memory" error; run the reference implementation first and
// save its result to out_ref
transpose_ref<float>(op);
out_ref->CopyDataFrom(*out);
Tensor input_x;
input_x.Resize(DDim(input_shape));
transpose(x->mutable_data<float>(),
input_x.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]),
static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3])},
{0, 2, 3, 1});
x->CopyDataFrom(input_x);
LaunchOp(op, {x_var_name}, {out_var_name});
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
Tensor output_trans;
output_trans.Resize(out->dims());
auto os = out->dims();
transpose(out_data,
output_trans.mutable_data<float>(),
{static_cast<int>(os[0]),
static_cast<int>(os[2]),
static_cast<int>(os[3]),
static_cast<int>(os[1])},
{0, 3, 1, 2});
out_data = output_trans.mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
// TODO(pmshst): fix the transpose test
TEST(MLUBridges, transpose) {
std::vector<int64_t> input_shape = {2, 3, 4, 5};
test_transpose(input_shape, std::vector<int>{0, 1, 3, 2});
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(transpose, kMLU);
USE_SUBGRAPH_BRIDGE(transpose2, kMLU);