未验证 提交 be7cc8f8 编写于 作者: M MaxwellDing 提交者: GitHub

[MLU] feat: add extra kernels, test=develop (#3919)

add op lrn  norm  gather
上级 37cb221e
......@@ -55,6 +55,18 @@ set(mlu_subgraph_bridges
CACHE INTERNAL "mlu_subgraph_bridges")
if (LITE_BUILD_EXTRA)
lite_cc_library(subgraph_bridge_lrn_op_mlu SRCS lrn_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_gather_op_mlu SRCS gather_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_norm_op_mlu SRCS norm_op.cc DEPS ${subgraph_bridge_deps_mlu})
set(mlu_subgraph_bridges
"${mlu_subgraph_bridges}"
subgraph_bridge_lrn_op_mlu
subgraph_bridge_gather_op_mlu
subgraph_bridge_norm_op_mlu
CACHE INTERNAL "mlu_subgraph_bridges")
endif()
lite_cc_library(subgraph_test_helper_mlu SRCS test_helper.cc DEPS ${mlu_subgraph_bridges})
lite_cc_test(test_conv_converter_mlu SRCS conv_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_act_converter_mlu SRCS act_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
......@@ -76,4 +88,11 @@ lite_cc_test(test_argmax_converter_mlu SRCS argmax_op_test.cc DEPS scope optimiz
lite_cc_test(test_squeeze_converter_mlu SRCS squeeze_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_reshape_converter_mlu SRCS reshape_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_flatten_converter_mlu SRCS flatten_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
if (LITE_BUILD_EXTRA)
lite_cc_test(test_norm_converter_mlu SRCS norm_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_lrn_converter_mlu SRCS lrn_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_gather_converter_mlu SRCS gather_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
endif()
message(STATUS "+++++ mlu_subgraph_bridges: ${mlu_subgraph_bridges}")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int GatherConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
auto x_var_name = op_info->Input("X").front();
auto index_var_name = op_info->Input("Index").front();
auto out_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
CHECK(graph->HasNode(x_var_name));
auto x_tensor = graph->GetNode(x_var_name);
auto index_tensor = graph->GetNode(index_var_name);
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
cnmlBaseOp_t gather_op;
CNML_CALL(cnmlCreateGatherV2Op(&gather_op,
x_tensor->mlu_tensor(),
index_tensor->mlu_tensor(),
output_tensor->mlu_tensor(),
CNML_DIM_N));
graph->FuseOp(gather_op);
CNML_CALL(cnmlDestroyBaseOp(&gather_op));
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(gather,
kMLU,
paddle::lite::subgraph::mlu::GatherConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/gather_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
template <typename dtype>
void gather_ref(const std::shared_ptr<operators::GatherOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto index =
scope->FindVar(op_info->Input("Index").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto x_dims = x->dims();
auto index_dims = index->dims();
CHECK(index_dims.size() == 1 ||
(index_dims.size() == 2 && index_dims[1] == 1));
int batch_size = index_dims[0];
DDim out_dims = x_dims;
out_dims[0] = batch_size;
out->Resize(out_dims);
auto x_data = x->data<float>();
auto index_data = index->data<int>();
auto out_data = out->mutable_data<float>();
auto slice_num = x_dims[0];
auto slice_size = x_dims.Slice(1, x_dims.size()).production();
for (int i = 0; i < batch_size; i++) {
auto index = index_data[i];
CHECK_LT(index, slice_num) << "index <= slice_num";
CHECK_GE(index, 0) << "index > 0";
memcpy(out_data + i * slice_size,
x_data + index * slice_size,
slice_size * sizeof(float));
}
}
void test_gather() {
// prepare input&output variables
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
std::string index_var_name = "index";
Scope scope;
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
auto* index = scope.Var(index_var_name)->GetMutable<Tensor>();
x->Resize({5, 4, 3, 2});
index->Resize({2});
// initialize input&output data
FillTensor<float>(x);
FillTensor<int>(index, 1, 3);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("gather");
opdesc.SetInput("X", {x_var_name});
opdesc.SetInput("Index", {index_var_name});
opdesc.SetOutput("Out", {out_var_name});
auto op = CreateOp<operators::GatherOp>(opdesc, &scope);
gather_ref<float>(op);
out_ref->CopyDataFrom(*out);
Tensor input;
input.Resize({5, 4, 3, 2});
transpose<float>(x->mutable_data<float>(),
input.mutable_data<float>(),
{static_cast<int>(5),
static_cast<int>(4),
static_cast<int>(3),
static_cast<int>(2)},
{0, 2, 3, 1});
x->CopyDataFrom(input);
LaunchOp(op, {x_var_name, index_var_name}, {out_var_name});
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
Tensor output;
output.Resize(out->dims());
transpose<float>(out_data,
output.mutable_data<float>(),
{static_cast<int>(out->dims()[0]),
static_cast<int>(out->dims()[2]),
static_cast<int>(out->dims()[3]),
static_cast<int>(out->dims()[1])},
{0, 3, 1, 2});
out_data = output.mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data[i], out_ref_data[i], 5e-4);
}
}
TEST(MLUBridges, gather) { test_gather(); }
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(gather, kMLU);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int LrnConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
// Create lrn node and get params from op
auto fp_type = graph->FPType();
auto x_var_name = op_info->Input("X").front();
auto out_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, fp_type);
CHECK(graph->HasNode(x_var_name));
auto input_tensor = graph->GetNode(x_var_name);
auto alpha = op_info->GetAttr<float>("alpha");
auto beta = op_info->GetAttr<float>("beta");
auto k = op_info->GetAttr<float>("k");
if (op_info->HasAttr("norm_region")) {
CHECK(op_info->GetAttr<std::string>("norm_region") == "AcrossChannels")
<< "Unsuport WithinChannel";
}
auto local_size = op_info->GetAttr<int>("n");
CHECK(op_info->HasAttr("input_scale"));
auto input_scale = op_info->GetAttr<float>("input_scale");
VLOG(5) << "lrn input scale: " << input_scale;
cnmlLrnOpParam_t param;
cnmlBaseOp_t lrn_op;
CNML_CALL(
cnmlCreateLrnOpParam(&param, CNML_LRN_V3, local_size, alpha, beta, k));
CNML_CALL(cnmlCreateLrnOp(
&lrn_op, param, input_tensor->mlu_tensor(), output_tensor->mlu_tensor()));
CNML_CALL(cnmlDestroyLrnOpParam(&param));
graph->SetComputingDataType(
lrn_op, input_tensor->mlu_tensor(), 1 / input_scale);
CNML_CALL(cnmlSetOperationComputingDataType(
lrn_op, output_tensor->mlu_tensor(), fp_type, nullptr));
graph->FuseOp(lrn_op);
CNML_CALL(cnmlDestroyBaseOp(&lrn_op));
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(lrn, kMLU, paddle::lite::subgraph::mlu::LrnConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/lrn_op.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <cmath>
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
/**
* @brief get sum of x^2 between channels [size elements]
*
* @tparam float
* @param input
* @param channel_id: the c-th channel within n-th graph.
* @param offset_within_channel: the pixel's offset within a channel.
* @param offset_num: the first address of n-th graph.
* @param c
* @param h
* @param w
* @param size
* @return float
*/
float lrn_square(const float* input,
int channel_id,
int offset_within_channel,
int offset_num,
int c,
int h,
int w,
int size) {
int pre_pad = (size - 1) / 2;
float res = 0;
const float* src = input + offset_num;
// handle left channels with padding situation.
if (channel_id - pre_pad < 0) {
for (int i = 0; i <= channel_id; ++i) {
res += src[i * h * w + offset_within_channel] *
src[i * h * w + offset_within_channel];
}
}
// handle left channels.
if (channel_id - pre_pad >= 0) {
for (int i = channel_id - pre_pad; i <= channel_id; ++i) {
res += src[i * h * w + offset_within_channel] *
src[i * h * w + offset_within_channel];
}
}
// handle right channels.
if (channel_id + pre_pad < c) {
for (int i = channel_id + 1; i <= channel_id + pre_pad; ++i) {
res += src[i * h * w + offset_within_channel] *
src[i * h * w + offset_within_channel];
}
}
// handle right channels with padding situation.
if (channel_id + pre_pad >= c && channel_id + 1 < c) {
for (int i = channel_id + 1; i < c; ++i) {
res += src[i * h * w + offset_within_channel] *
src[i * h * w + offset_within_channel];
}
}
return res;
}
void lrn_compute_ref(std::shared_ptr<operators::LrnOpLite> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x =
scope->FindVar(op_info->Input("X").front())->GetMutable<lite::Tensor>();
auto out = scope->FindVar(op_info->Output("Out").front())
->GetMutable<lite::Tensor>();
const float* x_data = x->data<const float>();
float* out_data = out->mutable_data<float>();
auto x_dims = x->dims();
auto alpha = op_info->GetAttr<float>("alpha");
auto beta = op_info->GetAttr<float>("beta");
auto k = op_info->GetAttr<float>("k");
auto norm_region = op_info->GetAttr<std::string>("norm_region");
auto local_size = op_info->GetAttr<int>("n");
int N = x_dims[0];
int C = x_dims[1];
int H = x_dims[2];
int W = x_dims[3];
int offset_num = 0;
int offset_within_channel = 0;
int dst_id;
float square;
for (int n = 0; n < N; ++n) {
offset_num = n * C * H * W;
for (int c = 0; c < C; ++c) {
for (int h = 0; h < H; ++h) {
for (int w = 0; w < W; ++w) {
offset_within_channel = h * W + w;
dst_id = offset_num + c * H * W + offset_within_channel;
square = lrn_square(x_data,
c,
offset_within_channel,
offset_num,
C,
H,
W,
local_size);
out_data[dst_id] = x_data[dst_id] * pow(k + alpha * square, -beta);
}
}
}
}
}
void test_lrn(float alpha,
float beta,
float k,
int local_size,
int n,
int c,
int h,
int w,
const std::string& norm_region) {
Scope scope;
std::string x_var_name("X_test");
std::string out_var_name("Out_test");
std::string out_ref_var_name("Out_ref");
auto* x = scope.NewTensor(x_var_name);
auto* out = scope.NewTensor(out_var_name);
auto* out_ref = scope.NewTensor(out_ref_var_name);
std::vector<int64_t> x_dim{n, c, h, w};
x->Resize(x_dim);
out->Resize(x_dim);
out_ref->Resize(x_dim);
auto* x_data = x->mutable_data<float>();
FillTensor<float, float>(x, 0.f, 1.f);
float *dmax, *dmin;
std::tie(dmin, dmax) =
std::minmax_element(x_data, x_data + x->data_size() - 1);
cpp::OpDesc opdesc;
opdesc.SetType("lrn");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("alpha", alpha);
opdesc.SetAttr("beta", beta);
opdesc.SetAttr("k", k);
opdesc.SetAttr("n", local_size);
opdesc.SetAttr("norm_region", norm_region);
opdesc.SetAttr<float>("input_scale", (*dmax - *dmin) / 255.f);
auto op = CreateOp<operators::LrnOpLite>(opdesc, &scope);
// baseline
lrn_compute_ref(op);
out_ref->CopyDataFrom(*out);
Tensor input_x;
input_x.Resize(x->dims());
transpose(x->mutable_data<float>(),
input_x.mutable_data<float>(),
{static_cast<int>(x_dim[0]),
static_cast<int>(x_dim[1]),
static_cast<int>(x_dim[2]),
static_cast<int>(x_dim[3])},
{0, 2, 3, 1});
x->CopyDataFrom(input_x);
LaunchOp(op, {x_var_name}, {out_var_name});
Tensor output_trans;
auto os = out->dims();
output_trans.Resize(os);
transpose(out->mutable_data<float>(),
output_trans.mutable_data<float>(),
{static_cast<int>(os[0]),
static_cast<int>(os[2]),
static_cast<int>(os[3]),
static_cast<int>(os[1])},
{0, 3, 1, 2});
auto output_data = output_trans.mutable_data<float>();
auto* output_ref_data = out_ref->mutable_data<float>();
for (size_t i = 0; i < out->data_size(); i++) {
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-4);
}
}
TEST(MLUBridges, lrn) {
int local_size = 5;
float alpha = 0.0001f;
float beta = 0.75;
float k = 2.0f;
std::string norm_region = "AcrossChannels";
for (int w : {2, 4, 8}) {
for (int h : {2, 4, 8}) {
for (int c : {1, 2, 3, 4}) {
for (int n : {1, 2, 3, 4}) {
test_lrn(alpha, beta, k, local_size, n, c, h, w, norm_region);
}
}
}
}
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(lrn, kMLU)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
int NormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[MLU] Converting " + op_type + "...";
// Get input vars and op attributes
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
auto x_dims = x->dims().Vectorize();
auto out_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
auto output_dims = output->dims().Vectorize();
int axis = op_info->GetAttr<int>("axis");
int epsilon = op_info->GetAttr<float>("epsilon");
if (axis < 0) {
axis = axis + x_dims.size();
}
std::vector<int> nchw2nhwc = {0, 3, 1, 2};
int nhwc_axis = nchw2nhwc[axis];
CHECK(graph->HasNode(x_var_name));
auto input_tensor = graph->GetNode(x_var_name);
auto output_tensor = graph->AddNode(
out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
// ======== DEBUG ===============
VLOG(6) << "x name=" << x_var_name;
VLOG(6) << "out name=" << out_var_name;
VLOG(6) << "x dims=" << x->dims();
VLOG(6) << "out dims=" << output->dims();
VLOG(6) << "axis =" << axis;
VLOG(6) << "nwhc axis=" << nhwc_axis;
VLOG(6) << "epsilon =" << epsilon;
// cnmlPrintTensor(input_tensor->mlu_tensor(), CNML_TENSOR);
// cnmlPrintTensor(output_tensor->mlu_tensor(), CNML_TENSOR);
// ======== DEBUG END ============
cnmlBaseOp_t norm_op{nullptr};
cnmlNormalizeOpParam_t param;
int mode = -1;
switch (axis) {
case 0:
mode = 3; // N
break;
case 1:
mode = 0; // C
break;
case 2:
mode = 4; // H
break;
case 3:
mode = 5; // W
break;
default:
CHECK(0);
break;
}
cnmlCreateNormalizeOpParamV2(&param,
0, // p
0, // use_scale
mode,
1, // weight
epsilon);
CNML_CALL(cnmlCreateNormalizeOp(&norm_op,
param,
input_tensor->mlu_tensor(),
output_tensor->mlu_tensor(),
nullptr,
false /*is_fix8_mode*/));
graph->FuseOp(norm_op);
CNML_CALL(cnmlDestroyBaseOp(&norm_op));
return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(norm,
kMLU,
paddle::lite::subgraph::mlu::NormConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/norm_op.h"
#include <gtest/gtest.h>
#include <cmath>
#include <iostream>
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
// void ToFile(std::string file_name, Tensor* tensor) {
// int count = tensor->dims().production();
// auto data = tensor->mutable_data<float>();
// std::ostringstream outs;
// for (size_t i = 0; i < count; i++) {
// outs << data[i] << std::endl;
// }
// std::ofstream of;
// of.open(file_name, std::ios::out);
// of << outs.str();
// of.close();
// }
void norm_ref(const std::shared_ptr<operators::NormOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
int axis = op_info->GetAttr<int>("axis");
int epsilon = op_info->GetAttr<float>("epsilon");
auto x_dims = x->dims();
if (axis < 0) {
axis += x_dims.size();
}
out->Resize(x_dims.Vectorize());
auto* out_data = out->mutable_data<float>();
const auto* x_data = x->data<float>();
int pre_n = x_dims.count(0, axis);
int n = x_dims[axis];
int post_n = x_dims.count(axis + 1, x_dims.size());
for (int i = 0; i < pre_n; i++) {
for (int k = 0; k < post_n; k++) {
float sum = epsilon;
const float* in_tmp = x_data + i * n * post_n + k;
for (int j = 0; j < n; j++) {
sum += in_tmp[j * post_n] * in_tmp[j * post_n];
}
sum = std::sqrt(sum);
float* out_tmp = out_data + i * n * post_n + k;
for (int j = 0; j < n; j++) {
out_tmp[j * post_n] = in_tmp[j * post_n] / sum;
}
}
}
}
void test_norm(const std::vector<int64_t>& input_shape, int axis) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize(input_shape);
// initialize input&output data
FillTensor<float, float>(x, -9, 9);
// initialize op desc
cpp::OpDesc opdesc;
float epsilon = 1e-9f;
opdesc.SetType("norm");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("axis", static_cast<int>(axis));
opdesc.SetAttr("epsilon", static_cast<float>(epsilon));
// create and convert op to MLU model, then run it on MLU
auto op = CreateOp<operators::NormOp>(opdesc, &scope);
norm_ref(op);
out_ref->CopyDataFrom(*out);
Tensor input_x;
input_x.Resize(DDim(input_shape));
// change input layout from NCHW to NHWC
transpose<float>(x->mutable_data<float>(),
input_x.mutable_data<float>(),
{static_cast<int>(input_shape[0]),
static_cast<int>(input_shape[1]),
static_cast<int>(input_shape[2]),
static_cast<int>(input_shape[3])},
{0, 2, 3, 1});
x->CopyDataFrom(input_x);
LaunchOp(op, {x_var_name}, {out_var_name});
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
std::vector<int64_t> out_shape = input_shape;
Tensor output_trans;
output_trans.Resize(out_shape);
// Change output layout from NHWC to NCHW
transpose<float>(out_data,
output_trans.mutable_data<float>(),
{static_cast<int>(out_shape[0]),
static_cast<int>(out_shape[2]),
static_cast<int>(out_shape[3]),
static_cast<int>(out_shape[1])},
{0, 3, 1, 2});
out_data = output_trans.mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(MLUBridges, norm) {
test_norm({1, 2, 3, 4}, 1);
test_norm({1, 2, 3, 4}, 2);
test_norm({1, 2, 3, 4}, 3);
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(norm, kMLU);
......@@ -43,3 +43,8 @@ USE_SUBGRAPH_BRIDGE(flatten, kMLU);
USE_SUBGRAPH_BRIDGE(flatten2, kMLU);
USE_SUBGRAPH_BRIDGE(reshape, kMLU);
USE_SUBGRAPH_BRIDGE(reshape2, kMLU);
#ifdef LITE_BUILD_EXTRA
USE_SUBGRAPH_BRIDGE(gather, kMLU);
USE_SUBGRAPH_BRIDGE(lrn, kMLU)
USE_SUBGRAPH_BRIDGE(norm, kMLU)
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册