From db21f878952a7583274a6e45d304114542b33056 Mon Sep 17 00:00:00 2001
From: dingminghui
Date: Tue, 28 Apr 2020 11:37:15 +0800
Subject: [PATCH] feat(kernel): add mlu lrn convertor

---
 lite/core/mir/subgraph/CMakeLists.txt         |   3 +
 lite/kernels/mlu/bridges/CMakeLists.txt       |   9 +
 lite/kernels/mlu/bridges/lrn_op.cc            |  79 ++++++
 lite/kernels/mlu/bridges/lrn_op_test.cc       | 224 ++++++++++++++++++
 lite/kernels/mlu/bridges/paddle_use_bridges.h |   3 +
 5 files changed, 318 insertions(+)
 create mode 100644 lite/kernels/mlu/bridges/lrn_op.cc
 create mode 100644 lite/kernels/mlu/bridges/lrn_op_test.cc

diff --git a/lite/core/mir/subgraph/CMakeLists.txt b/lite/core/mir/subgraph/CMakeLists.txt
index f8aa09676c..ef3cfd07e7 100644
--- a/lite/core/mir/subgraph/CMakeLists.txt
+++ b/lite/core/mir/subgraph/CMakeLists.txt
@@ -4,6 +4,9 @@ lite_cc_library(subgraph_detector
 lite_cc_library(subgraph_pass
     SRCS subgraph_pass.cc
     DEPS mir_pass types context ${mir_fusers} subgraph_detector)
+if (LITE_BUILD_EXTRA)
+  target_compile_definitions(subgraph_pass PUBLIC "-DLITE_BUILD_EXTRA")
+endif()
 if (WITH_TESTING AND NOT LITE_WITH_CUDA)
   lite_cc_test(test_subgraph_detector
     SRCS subgraph_detector_test.cc
diff --git a/lite/kernels/mlu/bridges/CMakeLists.txt b/lite/kernels/mlu/bridges/CMakeLists.txt
index ff8f03621b..4721b2d217 100644
--- a/lite/kernels/mlu/bridges/CMakeLists.txt
+++ b/lite/kernels/mlu/bridges/CMakeLists.txt
@@ -38,6 +38,12 @@ set(mlu_subgraph_bridges
         subgraph_bridge_dropout_op_mlu
         CACHE INTERNAL "mlu_subgraph_bridges")
 
+
+if (LITE_BUILD_EXTRA)
+  lite_cc_library(subgraph_bridge_lrn_op_mlu SRCS lrn_op.cc DEPS ${subgraph_bridge_deps_mlu})
+  list(APPEND mlu_subgraph_bridges subgraph_bridge_lrn_op_mlu)
+endif()
+
 lite_cc_library(subgraph_test_helper_mlu SRCS test_helper.cc DEPS ${mlu_subgraph_bridges})
 lite_cc_test(test_conv_converter_mlu SRCS conv_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_act_converter_mlu SRCS act_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
@@ -51,4 +57,7 @@ lite_cc_test(test_interp_converter_mlu SRCS interpolate_op_test.cc DEPS scope op
 lite_cc_test(test_concat_converter_mlu SRCS concat_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_transpose_converter_mlu SRCS transpose_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_dropout_converter_mlu SRCS dropout_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
+if (LITE_BUILD_EXTRA)
+  lite_cc_test(test_lrn_converter_mlu SRCS lrn_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
+endif()
 message(STATUS "+++++ mlu_subgraph_bridges: ${mlu_subgraph_bridges}")
diff --git a/lite/kernels/mlu/bridges/lrn_op.cc b/lite/kernels/mlu/bridges/lrn_op.cc
new file mode 100644
index 0000000000..aa098eaee4
--- /dev/null
+++ b/lite/kernels/mlu/bridges/lrn_op.cc
@@ -0,0 +1,79 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/mlu/bridges/graph.h"
+#include "lite/kernels/mlu/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace mlu {
+
+int LrnConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto scope = op->scope();
+  VLOG(3) << "[MLU] Converting " + op_type + "...";
+
+  // Create lrn node and get params from op
+  auto fp_type = graph->FPType();
+  auto x_var_name = op_info->Input("X").front();
+  auto out_var_name = op_info->Output("Out").front();
+  auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
+  auto output_dims = output->dims().Vectorize();
+  auto output_tensor = graph->AddNode(
+      out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, fp_type);
+  CHECK(graph->HasNode(x_var_name));
+  auto input_tensor = graph->GetNode(x_var_name);
+
+  auto alpha = op_info->GetAttr<float>("alpha");
+  auto beta = op_info->GetAttr<float>("beta");
+  auto k = op_info->GetAttr<float>("k");
+  if (op_info->HasAttr("norm_region")) {
+    CHECK(op_info->GetAttr<std::string>("norm_region") == "AcrossChannels")
+        << "Unsupported WithinChannel";
+  }
+  auto local_size = op_info->GetAttr<int>("n");
+  CHECK(op_info->HasAttr("input_scale"));
+  auto input_scale = op_info->GetAttr<float>("input_scale");
+  VLOG(3) << "input scale: " << input_scale;
+
+  cnmlLrnOpParam_t param;
+  cnmlBaseOp_t lrn_op;
+  CNML_CALL(
+      cnmlCreateLrnOpParam(&param, CNML_LRN_V3, local_size, alpha, beta, k));
+  CNML_CALL(cnmlCreateLrnOp(
+      &lrn_op, param, input_tensor->mlu_tensor(), output_tensor->mlu_tensor()));
+  CNML_CALL(cnmlDestroyLrnOpParam(&param));
+
+  graph->SetComputingDataType(
+      lrn_op, input_tensor->mlu_tensor(), 1 / input_scale);
+  CNML_CALL(cnmlSetOperationComputingDataType(
+      lrn_op, output_tensor->mlu_tensor(), fp_type, nullptr));
+
+  graph->FuseOp(lrn_op);
+  CNML_CALL(cnmlDestroyBaseOp(&lrn_op));
+  return SUCCESS;
+}
+
+}  // namespace mlu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(lrn, kMLU, paddle::lite::subgraph::mlu::LrnConverter);
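
Note: CNML_LRN_V3 is expected here to map to the Caffe-style cross-channel LRN
that Paddle's lrn op defines: out = x * (k + alpha * S)^(-beta), where S is
the sum of x^2 over a window of n ("local_size") channels centered on the
current channel and clipped to the valid channel range. A minimal scalar
sketch of that formula (standalone illustration with a hypothetical helper
name, not code from this patch; it mirrors lrn_square/lrn_compute_ref in the
test below):

#include <algorithm>
#include <cmath>

// One LRN output element over an NCHW-flattened buffer (hypothetical helper).
float lrn_ref_at(const float* in, int C, int H, int W,
                 int n, int c, int h, int w,
                 int local_size, float alpha, float beta, float k) {
  const int pre_pad = (local_size - 1) / 2;
  const int lo = std::max(0, c - pre_pad);       // window clipped at channel 0
  const int hi = std::min(C - 1, c + pre_pad);   // and at the last channel
  float sum = 0.f;                               // sum of x^2 over the window
  for (int j = lo; j <= hi; ++j) {
    const float v = in[((n * C + j) * H + h) * W + w];
    sum += v * v;
  }
  const float x = in[((n * C + c) * H + h) * W + w];
  return x * std::pow(k + alpha * sum, -beta);
}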
diff --git a/lite/kernels/mlu/bridges/lrn_op_test.cc b/lite/kernels/mlu/bridges/lrn_op_test.cc
new file mode 100644
index 0000000000..10058cea79
--- /dev/null
+++ b/lite/kernels/mlu/bridges/lrn_op_test.cc
@@ -0,0 +1,224 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/lrn_op.h"
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+#include "lite/kernels/mlu/bridges/test_helper.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace mlu {
+
+/**
+ * @brief get the sum of x^2 over a window of `size` channels
+ *
+ * @param input: input data pointer
+ * @param channel_id: the c-th channel within the n-th batch image.
+ * @param offset_within_channel: the pixel's offset within a channel.
+ * @param offset_num: the element offset of the n-th batch image.
+ * @param c
+ * @param h
+ * @param w
+ * @param size
+ * @return float
+ */
+float lrn_square(const float* input,
+                 int channel_id,
+                 int offset_within_channel,
+                 int offset_num,
+                 int c,
+                 int h,
+                 int w,
+                 int size) {
+  int pre_pad = (size - 1) / 2;
+  float res = 0;
+  const float* src = input + offset_num;
+
+  // left half of the window (including the center), clipped at channel 0.
+  if (channel_id - pre_pad < 0) {
+    for (int i = 0; i <= channel_id; ++i) {
+      res += src[i * h * w + offset_within_channel] *
+             src[i * h * w + offset_within_channel];
+    }
+  }
+
+  // left half of the window (including the center), no clipping needed.
+  if (channel_id - pre_pad >= 0) {
+    for (int i = channel_id - pre_pad; i <= channel_id; ++i) {
+      res += src[i * h * w + offset_within_channel] *
+             src[i * h * w + offset_within_channel];
+    }
+  }
+
+  // right half of the window, no clipping needed.
+  if (channel_id + pre_pad < c) {
+    for (int i = channel_id + 1; i <= channel_id + pre_pad; ++i) {
+      res += src[i * h * w + offset_within_channel] *
+             src[i * h * w + offset_within_channel];
+    }
+  }
+
+  // right half of the window, clipped at the last channel.
+  if (channel_id + pre_pad >= c && channel_id + 1 < c) {
+    for (int i = channel_id + 1; i < c; ++i) {
+      res += src[i * h * w + offset_within_channel] *
+             src[i * h * w + offset_within_channel];
+    }
+  }
+
+  return res;
+}
+
+void lrn_compute_ref(std::shared_ptr<operators::LrnOpLite> op) {
+  Scope* scope = op->scope();
+  const OpInfo* op_info = op->op_info();
+  auto x =
+      scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
+  auto out = scope->FindVar(op_info->Output("Out").front())
+                 ->GetMutable<Tensor>();
+
+  const float* x_data = x->data<float>();
+  float* out_data = out->mutable_data<float>();
+  auto x_dims = x->dims();
+
+  auto alpha = op_info->GetAttr<float>("alpha");
+  auto beta = op_info->GetAttr<float>("beta");
+  auto k = op_info->GetAttr<float>("k");
+  auto norm_region = op_info->GetAttr<std::string>("norm_region");
+  auto local_size = op_info->GetAttr<int>("n");
+
+  int N = x_dims[0];
+  int C = x_dims[1];
+  int H = x_dims[2];
+  int W = x_dims[3];
+
+  int offset_num = 0;
+  int offset_within_channel = 0;
+  int dst_id;
+
+  float square;
+
+  for (int n = 0; n < N; ++n) {
+    offset_num = n * C * H * W;
+
+    for (int c = 0; c < C; ++c) {
+      for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+          offset_within_channel = h * W + w;
+          dst_id = offset_num + c * H * W + offset_within_channel;
+          square = lrn_square(x_data,
+                              c,
+                              offset_within_channel,
+                              offset_num,
+                              C,
+                              H,
+                              W,
+                              local_size);
+          out_data[dst_id] = x_data[dst_id] * pow(k + alpha * square, -beta);
+        }
+      }
+    }
+  }
+}
+
+void test_lrn(float alpha,
+              float beta,
+              float k,
+              int local_size,
+              int n,
+              int c,
+              int h,
+              int w,
+              const std::string& norm_region) {
+  Scope scope;
+  std::string x_var_name("X_test");
+  std::string out_var_name("Out_test");
+  std::string out_ref_var_name("Out_ref");
+  auto* x = scope.NewTensor(x_var_name);
+  auto* out = scope.NewTensor(out_var_name);
+  auto* out_ref = scope.NewTensor(out_ref_var_name);
+
+  std::vector<int64_t> x_dim{n, c, h, w};
+  x->Resize(x_dim);
+  out->Resize(x_dim);
+  out_ref->Resize(x_dim);
+  auto* x_data = x->mutable_data<float>();
+  FillTensor<float>(x, 0.f, 1.f);
+  float *dmax, *dmin;
+  std::tie(dmin, dmax) =
+      std::minmax_element(x_data, x_data + x->data_size());
+  VLOG(3) << "max: " << *dmax << ", min: " << *dmin;
+
+  cpp::OpDesc opdesc;
+  opdesc.SetType("lrn");
+  opdesc.SetInput("X", {x_var_name});
+  opdesc.SetOutput("Out", {out_var_name});
+  opdesc.SetAttr("alpha", alpha);
+  opdesc.SetAttr("beta", beta);
+  opdesc.SetAttr("k", k);
+  opdesc.SetAttr("n", local_size);
+  opdesc.SetAttr("norm_region", norm_region);
+  opdesc.SetAttr("input_scale", (*dmax - *dmin) / 255.f);
+
+  auto op = CreateOp<operators::LrnOpLite>(opdesc, &scope);
+
+  // baseline
+  lrn_compute_ref(op);
+  out_ref->CopyDataFrom(*out);
+
+  LaunchOp(op, {x_var_name}, {out_var_name});
+
+  auto* output_data = out->mutable_data<float>();
+  auto* output_ref_data = out_ref->mutable_data<float>();
+  for (size_t i = 0; i < out->data_size(); i++) {
+    EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-4);
+  }
+}
+
+TEST(MLUBridges, lrn) {
+  int local_size = 5;
+  float alpha = 0.0001f;
+  float beta = 0.75f;
+  float k = 2.0f;
+  std::string norm_region = "AcrossChannels";
+  for (int w : {2, 4, 8}) {
+    for (int h : {2, 4, 8}) {
+      for (int c : {1, 2, 3, 4}) {
+        for (int n : {1, 2, 3, 4}) {
+          test_lrn(alpha, beta, k, local_size, n, c, h, w, norm_region);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace mlu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+USE_SUBGRAPH_BRIDGE(lrn, kMLU);
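
A note on the input_scale attribute exercised above: the test derives it from
the observed data range as (max - min) / 255, i.e. one step per 8-bit
quantization level, and the converter passes its reciprocal, 1 / input_scale,
to SetComputingDataType so the MLU side can evaluate the op on quantized
input. A minimal sketch of that convention (hypothetical helper name, not
code from this patch):

// Scale such that real_value ~= quantized_level * scale; the converter's
// multiplier is therefore 1 / scale.
float make_input_scale(float dmin, float dmax) {
  return (dmax - dmin) / 255.f;  // same expression the test stores in opdesc
}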
diff --git a/lite/kernels/mlu/bridges/paddle_use_bridges.h b/lite/kernels/mlu/bridges/paddle_use_bridges.h
index 559b348cf4..8109b0f395 100644
--- a/lite/kernels/mlu/bridges/paddle_use_bridges.h
+++ b/lite/kernels/mlu/bridges/paddle_use_bridges.h
@@ -31,3 +31,6 @@ USE_SUBGRAPH_BRIDGE(scale, kMLU);
 USE_SUBGRAPH_BRIDGE(sigmoid, kMLU);
 USE_SUBGRAPH_BRIDGE(elementwise_mul, kMLU);
 USE_SUBGRAPH_BRIDGE(dropout, kMLU);
+#ifdef LITE_BUILD_EXTRA
+USE_SUBGRAPH_BRIDGE(lrn, kMLU);
+#endif
--
GitLab