diff --git a/lite/kernels/mlu/bridges/CMakeLists.txt b/lite/kernels/mlu/bridges/CMakeLists.txt
index 38f66ee95a31bcebed1ed2239581d902c909d911..91323925e1ef49462c180fd96392d638e273fd69 100644
--- a/lite/kernels/mlu/bridges/CMakeLists.txt
+++ b/lite/kernels/mlu/bridges/CMakeLists.txt
@@ -55,6 +55,18 @@ set(mlu_subgraph_bridges
     CACHE INTERNAL "mlu_subgraph_bridges")
 
+if (LITE_BUILD_EXTRA)
+  lite_cc_library(subgraph_bridge_lrn_op_mlu SRCS lrn_op.cc DEPS ${subgraph_bridge_deps_mlu})
+  lite_cc_library(subgraph_bridge_gather_op_mlu SRCS gather_op.cc DEPS ${subgraph_bridge_deps_mlu})
+  lite_cc_library(subgraph_bridge_norm_op_mlu SRCS norm_op.cc DEPS ${subgraph_bridge_deps_mlu})
+  set(mlu_subgraph_bridges
+      "${mlu_subgraph_bridges}"
+      subgraph_bridge_lrn_op_mlu
+      subgraph_bridge_gather_op_mlu
+      subgraph_bridge_norm_op_mlu
+      CACHE INTERNAL "mlu_subgraph_bridges")
+endif()
+
 lite_cc_library(subgraph_test_helper_mlu SRCS test_helper.cc DEPS ${mlu_subgraph_bridges})
 lite_cc_test(test_conv_converter_mlu SRCS conv_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_act_converter_mlu SRCS act_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
@@ -76,4 +88,11 @@ lite_cc_test(test_argmax_converter_mlu SRCS argmax_op_test.cc DEPS scope optimiz
 lite_cc_test(test_squeeze_converter_mlu SRCS squeeze_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_reshape_converter_mlu SRCS reshape_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
 lite_cc_test(test_flatten_converter_mlu SRCS flatten_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
+
+if (LITE_BUILD_EXTRA)
+  lite_cc_test(test_norm_converter_mlu SRCS norm_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
+  lite_cc_test(test_lrn_converter_mlu SRCS lrn_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
+  lite_cc_test(test_gather_converter_mlu SRCS gather_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
+endif()
+
 message(STATUS "+++++ mlu_subgraph_bridges: ${mlu_subgraph_bridges}")
diff --git a/lite/kernels/mlu/bridges/gather_op.cc b/lite/kernels/mlu/bridges/gather_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b68f1af76456eede14ec550c623d6a8355f5d5e8
--- /dev/null
+++ b/lite/kernels/mlu/bridges/gather_op.cc
@@ -0,0 +1,64 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/mlu/bridges/graph.h"
+#include "lite/kernels/mlu/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace mlu {
+
+int GatherConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto scope = op->scope();
+  VLOG(3) << "[MLU] Converting " + op_type + "...";
+
+  auto x_var_name = op_info->Input("X").front();
+  auto index_var_name = op_info->Input("Index").front();
+  auto out_var_name = op_info->Output("Out").front();
+  auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
+  auto output_dims = output->dims().Vectorize();
+
+  CHECK(graph->HasNode(x_var_name));
+  auto x_tensor = graph->GetNode(x_var_name);
+  auto index_tensor = graph->GetNode(index_var_name);
+
+  auto output_tensor = graph->AddNode(
+      out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
+
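+  // cnmlCreateGatherV2Op selects slices of X along the outermost (N)
+  // dimension according to the integer Index tensor; CNML_DIM_N pins the
+  // gather axis to dim 0, matching the reference semantics in the test.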
+  cnmlBaseOp_t gather_op;
+  CNML_CALL(cnmlCreateGatherV2Op(&gather_op,
+                                 x_tensor->mlu_tensor(),
+                                 index_tensor->mlu_tensor(),
+                                 output_tensor->mlu_tensor(),
+                                 CNML_DIM_N));
+  graph->FuseOp(gather_op);
+  CNML_CALL(cnmlDestroyBaseOp(&gather_op));
+  return SUCCESS;
+}
+
+}  // namespace mlu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(gather,
+                         kMLU,
+                         paddle::lite::subgraph::mlu::GatherConverter);
diff --git a/lite/kernels/mlu/bridges/gather_op_test.cc b/lite/kernels/mlu/bridges/gather_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..413de7c9d7fda750b387c2daa21ef1e40e7982c7
--- /dev/null
+++ b/lite/kernels/mlu/bridges/gather_op_test.cc
@@ -0,0 +1,133 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/gather_op.h"
+#include <gtest/gtest.h>
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+#include "lite/kernels/mlu/bridges/test_helper.h"
+#include "lite/kernels/mlu/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace mlu {
+
+template <typename dtype>
+void gather_ref(const std::shared_ptr<operators::GatherOp> op) {
+  Scope* scope = op->scope();
+  const OpInfo* op_info = op->op_info();
+  auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
+  auto index =
+      scope->FindVar(op_info->Input("Index").front())->GetMutable<Tensor>();
+  auto out =
+      scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
+
+  auto x_dims = x->dims();
+  auto index_dims = index->dims();
+  CHECK(index_dims.size() == 1 ||
+        (index_dims.size() == 2 && index_dims[1] == 1));
+
+  int batch_size = index_dims[0];
+  DDim out_dims = x_dims;
+  out_dims[0] = batch_size;
+  out->Resize(out_dims);
+
+  auto x_data = x->data<dtype>();
+  auto index_data = index->data<int>();
+  auto out_data = out->mutable_data<dtype>();
+
+  auto slice_num = x_dims[0];
+  auto slice_size = x_dims.Slice(1, x_dims.size()).production();
+  for (int i = 0; i < batch_size; i++) {
+    auto index = index_data[i];
+    CHECK_LT(index, slice_num) << "index must be < slice_num";
+    CHECK_GE(index, 0) << "index must be >= 0";
+    memcpy(out_data + i * slice_size,
+           x_data + index * slice_size,
+           slice_size * sizeof(dtype));
+  }
+}
+
+void test_gather() {
+  // prepare input&output variables
+  std::string x_var_name = "x";
+  std::string out_var_name = "out";
+  std::string out_ref_var_name = "out_ref";
+  std::string index_var_name = "index";
+
+  Scope scope;
+  auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
+  auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
+  auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
+  auto* index = scope.Var(index_var_name)->GetMutable<Tensor>();
+
+  x->Resize({5, 4, 3, 2});
+  index->Resize({2});
+  // initialize input&output data
+  FillTensor<float>(x);
+  FillTensor<int>(index, 1, 3);
+
+  // initialize op desc
+  cpp::OpDesc opdesc;
+  opdesc.SetType("gather");
+  opdesc.SetInput("X", {x_var_name});
+  opdesc.SetInput("Index", {index_var_name});
+  opdesc.SetOutput("Out", {out_var_name});
+
+  auto op = CreateOp<operators::GatherOp>(opdesc, &scope);
+  gather_ref<float>(op);
+  out_ref->CopyDataFrom(*out);
+
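+  // The MLU subgraph computes in NHWC layout, so transpose the NCHW input
+  // before launching, then transpose the MLU output back for comparison.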
+  Tensor input;
+  input.Resize({5, 4, 3, 2});
+  transpose(x->mutable_data<float>(),
+            input.mutable_data<float>(),
+            {static_cast<int>(5),
+             static_cast<int>(4),
+             static_cast<int>(3),
+             static_cast<int>(2)},
+            {0, 2, 3, 1});
+  x->CopyDataFrom(input);
+  LaunchOp(op, {x_var_name, index_var_name}, {out_var_name});
+
+  // compare results
+  auto* out_data = out->mutable_data<float>();
+  auto* out_ref_data = out_ref->mutable_data<float>();
+
+  Tensor output;
+  output.Resize(out->dims());
+  transpose(out_data,
+            output.mutable_data<float>(),
+            {static_cast<int>(out->dims()[0]),
+             static_cast<int>(out->dims()[2]),
+             static_cast<int>(out->dims()[3]),
+             static_cast<int>(out->dims()[1])},
+            {0, 3, 1, 2});
+  out_data = output.mutable_data<float>();
+  for (int i = 0; i < out->dims().production(); i++) {
+    VLOG(5) << i;
+    EXPECT_NEAR(out_data[i], out_ref_data[i], 5e-4);
+  }
+}
+
+TEST(MLUBridges, gather) { test_gather(); }
+
+}  // namespace mlu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+USE_SUBGRAPH_BRIDGE(gather, kMLU);
diff --git a/lite/kernels/mlu/bridges/lrn_op.cc b/lite/kernels/mlu/bridges/lrn_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..657f0dd6781590e1a9ca90bf25e4efcf789863dd
--- /dev/null
+++ b/lite/kernels/mlu/bridges/lrn_op.cc
@@ -0,0 +1,79 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/mlu/bridges/graph.h"
+#include "lite/kernels/mlu/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace mlu {
+
+int LrnConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto scope = op->scope();
+  VLOG(3) << "[MLU] Converting " + op_type + "...";
+
+  // Create lrn node and get params from op
+  auto fp_type = graph->FPType();
+  auto x_var_name = op_info->Input("X").front();
+  auto out_var_name = op_info->Output("Out").front();
+  auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
+  auto output_dims = output->dims().Vectorize();
+  auto output_tensor = graph->AddNode(
+      out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, fp_type);
+  CHECK(graph->HasNode(x_var_name));
+  auto input_tensor = graph->GetNode(x_var_name);
+
+  auto alpha = op_info->GetAttr<float>("alpha");
+  auto beta = op_info->GetAttr<float>("beta");
+  auto k = op_info->GetAttr<float>("k");
+  if (op_info->HasAttr("norm_region")) {
+    CHECK(op_info->GetAttr<std::string>("norm_region") == "AcrossChannels")
+        << "Unsupported norm_region: WithinChannel";
+  }
+  auto local_size = op_info->GetAttr<int>("n");
+  CHECK(op_info->HasAttr("input_scale"));
+  auto input_scale = op_info->GetAttr<float>("input_scale");
+  VLOG(5) << "lrn input scale: " << input_scale;
+
+  cnmlLrnOpParam_t param;
+  cnmlBaseOp_t lrn_op;
+  CNML_CALL(
+      cnmlCreateLrnOpParam(&param, CNML_LRN_V3, local_size, alpha, beta, k));
+  CNML_CALL(cnmlCreateLrnOp(
+      &lrn_op, param, input_tensor->mlu_tensor(), output_tensor->mlu_tensor()));
+  CNML_CALL(cnmlDestroyLrnOpParam(&param));
+
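+  // LRN runs quantized on the MLU: hand the inverse of the input_scale
+  // attribute to the graph so float inputs can be mapped onto the integer
+  // computing data type, while the output stays in fp_type.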
+  graph->SetComputingDataType(
+      lrn_op, input_tensor->mlu_tensor(), 1 / input_scale);
+  CNML_CALL(cnmlSetOperationComputingDataType(
+      lrn_op, output_tensor->mlu_tensor(), fp_type, nullptr));
+
+  graph->FuseOp(lrn_op);
+  CNML_CALL(cnmlDestroyBaseOp(&lrn_op));
+  return SUCCESS;
+}
+
+}  // namespace mlu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(lrn, kMLU, paddle::lite::subgraph::mlu::LrnConverter);
diff --git a/lite/kernels/mlu/bridges/lrn_op_test.cc b/lite/kernels/mlu/bridges/lrn_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..21f7e816baeac264bf1b43b7520d464afa38c395
--- /dev/null
+++ b/lite/kernels/mlu/bridges/lrn_op_test.cc
@@ -0,0 +1,242 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/lrn_op.h"
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <tuple>
+
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+#include "lite/kernels/mlu/bridges/test_helper.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace mlu {
+
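+// Reference LRN (AcrossChannels): out = x / (k + alpha * sum(x^2))^beta,
+// with the sum taken per pixel over a window of `size` neighbouring channels.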
+/**
+ * @brief get sum of x^2 between channels [size elements]
+ *
+ * @param input
+ * @param channel_id: the c-th channel within n-th graph.
+ * @param offset_within_channel: the pixel's offset within a channel.
+ * @param offset_num: the first address of n-th graph.
+ * @param c
+ * @param h
+ * @param w
+ * @param size
+ * @return float
+ */
+float lrn_square(const float* input,
+                 int channel_id,
+                 int offset_within_channel,
+                 int offset_num,
+                 int c,
+                 int h,
+                 int w,
+                 int size) {
+  int pre_pad = (size - 1) / 2;
+  float res = 0;
+  const float* src = input + offset_num;
+
+  // handle left channels with padding situation.
+  if (channel_id - pre_pad < 0) {
+    for (int i = 0; i <= channel_id; ++i) {
+      res += src[i * h * w + offset_within_channel] *
+             src[i * h * w + offset_within_channel];
+    }
+  }
+
+  // handle left channels.
+  if (channel_id - pre_pad >= 0) {
+    for (int i = channel_id - pre_pad; i <= channel_id; ++i) {
+      res += src[i * h * w + offset_within_channel] *
+             src[i * h * w + offset_within_channel];
+    }
+  }
+
+  // handle right channels.
+  if (channel_id + pre_pad < c) {
+    for (int i = channel_id + 1; i <= channel_id + pre_pad; ++i) {
+      res += src[i * h * w + offset_within_channel] *
+             src[i * h * w + offset_within_channel];
+    }
+  }
+
+  // handle right channels with padding situation.
+  if (channel_id + pre_pad >= c && channel_id + 1 < c) {
+    for (int i = channel_id + 1; i < c; ++i) {
+      res += src[i * h * w + offset_within_channel] *
+             src[i * h * w + offset_within_channel];
+    }
+  }
+
+  return res;
+}
+
+void lrn_compute_ref(std::shared_ptr<operators::LrnOpLite> op) {
+  Scope* scope = op->scope();
+  const OpInfo* op_info = op->op_info();
+  auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
+  auto out =
+      scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
+
+  const float* x_data = x->data<float>();
+  float* out_data = out->mutable_data<float>();
+  auto x_dims = x->dims();
+
+  auto alpha = op_info->GetAttr<float>("alpha");
+  auto beta = op_info->GetAttr<float>("beta");
+  auto k = op_info->GetAttr<float>("k");
+  auto norm_region = op_info->GetAttr<std::string>("norm_region");
+  auto local_size = op_info->GetAttr<int>("n");
+
+  int N = x_dims[0];
+  int C = x_dims[1];
+  int H = x_dims[2];
+  int W = x_dims[3];
+
+  int offset_num = 0;
+  int offset_within_channel = 0;
+  int dst_id;
+
+  float square;
+
+  for (int n = 0; n < N; ++n) {
+    offset_num = n * C * H * W;
+
+    for (int c = 0; c < C; ++c) {
+      for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+          offset_within_channel = h * W + w;
+          dst_id = offset_num + c * H * W + offset_within_channel;
+          square = lrn_square(x_data,
+                              c,
+                              offset_within_channel,
+                              offset_num,
+                              C,
+                              H,
+                              W,
+                              local_size);
+          out_data[dst_id] = x_data[dst_id] * pow(k + alpha * square, -beta);
+        }
+      }
+    }
+  }
+}
+
+void test_lrn(float alpha,
+              float beta,
+              float k,
+              int local_size,
+              int n,
+              int c,
+              int h,
+              int w,
+              const std::string& norm_region) {
+  Scope scope;
+  std::string x_var_name("X_test");
+  std::string out_var_name("Out_test");
+  std::string out_ref_var_name("Out_ref");
+  auto* x = scope.NewTensor(x_var_name);
+  auto* out = scope.NewTensor(out_var_name);
+  auto* out_ref = scope.NewTensor(out_ref_var_name);
+
+  std::vector<int64_t> x_dim{n, c, h, w};
+  x->Resize(x_dim);
+  out->Resize(x_dim);
+  out_ref->Resize(x_dim);
+  auto* x_data = x->mutable_data<float>();
+  FillTensor<float>(x, 0.f, 1.f);
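+  // Derive the input_scale attribute from the generated data range; the
+  // converter passes its inverse to SetComputingDataType for quantization.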
+  float *dmax, *dmin;
+  std::tie(dmin, dmax) =
+      std::minmax_element(x_data, x_data + x->data_size());
+
+  cpp::OpDesc opdesc;
+  opdesc.SetType("lrn");
+  opdesc.SetInput("X", {x_var_name});
+  opdesc.SetOutput("Out", {out_var_name});
+  opdesc.SetAttr("alpha", alpha);
+  opdesc.SetAttr("beta", beta);
+  opdesc.SetAttr("k", k);
+  opdesc.SetAttr("n", local_size);
+  opdesc.SetAttr("norm_region", norm_region);
+  opdesc.SetAttr("input_scale", (*dmax - *dmin) / 255.f);
+
+  auto op = CreateOp<operators::LrnOpLite>(opdesc, &scope);
+
+  // baseline
+  lrn_compute_ref(op);
+  out_ref->CopyDataFrom(*out);
+
+  Tensor input_x;
+  input_x.Resize(x->dims());
+  transpose(x->mutable_data<float>(),
+            input_x.mutable_data<float>(),
+            {static_cast<int>(x_dim[0]),
+             static_cast<int>(x_dim[1]),
+             static_cast<int>(x_dim[2]),
+             static_cast<int>(x_dim[3])},
+            {0, 2, 3, 1});
+  x->CopyDataFrom(input_x);
+
+  LaunchOp(op, {x_var_name}, {out_var_name});
+
+  Tensor output_trans;
+  auto os = out->dims();
+  output_trans.Resize(os);
+  transpose(out->mutable_data<float>(),
+            output_trans.mutable_data<float>(),
+            {static_cast<int>(os[0]),
+             static_cast<int>(os[2]),
+             static_cast<int>(os[3]),
+             static_cast<int>(os[1])},
+            {0, 3, 1, 2});
+
+  auto output_data = output_trans.mutable_data<float>();
+  auto* output_ref_data = out_ref->mutable_data<float>();
+  for (size_t i = 0; i < out->data_size(); i++) {
+    EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-4);
+  }
+}
+
+TEST(MLUBridges, lrn) {
+  int local_size = 5;
+  float alpha = 0.0001f;
+  float beta = 0.75f;
+  float k = 2.0f;
+  std::string norm_region = "AcrossChannels";
+  for (int w : {2, 4, 8}) {
+    for (int h : {2, 4, 8}) {
+      for (int c : {1, 2, 3, 4}) {
+        for (int n : {1, 2, 3, 4}) {
+          test_lrn(alpha, beta, k, local_size, n, c, h, w, norm_region);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace mlu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+USE_SUBGRAPH_BRIDGE(lrn, kMLU);
diff --git a/lite/kernels/mlu/bridges/norm_op.cc b/lite/kernels/mlu/bridges/norm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..492c3932a8c8a68f7eba687dde30d888d6e0f297
--- /dev/null
+++ b/lite/kernels/mlu/bridges/norm_op.cc
@@ -0,0 +1,111 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/mlu/bridges/graph.h"
+#include "lite/kernels/mlu/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace mlu {
+
+int NormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto scope = op->scope();
+  VLOG(3) << "[MLU] Converting " + op_type + "...";
+
+  // Get input vars and op attributes
+  auto x_var_name = op_info->Input("X").front();
+  auto x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
+  auto x_dims = x->dims().Vectorize();
+
+  auto out_var_name = op_info->Output("Out").front();
+  auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
+  auto output_dims = output->dims().Vectorize();
+  int axis = op_info->GetAttr<int>("axis");
+  float epsilon = op_info->GetAttr<float>("epsilon");
+  if (axis < 0) {
+    axis = axis + x_dims.size();
+  }
+  std::vector<int> nchw2nhwc = {0, 3, 1, 2};
+  int nhwc_axis = nchw2nhwc[axis];
+
+  CHECK(graph->HasNode(x_var_name));
+  auto input_tensor = graph->GetNode(x_var_name);
+  auto output_tensor = graph->AddNode(
+      out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, graph->FPType());
+
+  // ======== DEBUG ===============
+  VLOG(6) << "x name=" << x_var_name;
+  VLOG(6) << "out name=" << out_var_name;
+  VLOG(6) << "x dims=" << x->dims();
+  VLOG(6) << "out dims=" << output->dims();
+  VLOG(6) << "axis =" << axis;
+  VLOG(6) << "nhwc axis=" << nhwc_axis;
+  VLOG(6) << "epsilon =" << epsilon;
+  // cnmlPrintTensor(input_tensor->mlu_tensor(), CNML_TENSOR);
+  // cnmlPrintTensor(output_tensor->mlu_tensor(), CNML_TENSOR);
+  // ======== DEBUG END ============
+  cnmlBaseOp_t norm_op{nullptr};
+
+  cnmlNormalizeOpParam_t param;
+  int mode = -1;
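+  // Map the Paddle (NCHW) axis onto CNML's numeric normalize-mode constant;
+  // the magic values follow the CNML normalize-op convention.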
+  switch (axis) {
+    case 0:
+      mode = 3;  // N
+      break;
+    case 1:
+      mode = 0;  // C
+      break;
+    case 2:
+      mode = 4;  // H
+      break;
+    case 3:
+      mode = 5;  // W
+      break;
+    default:
+      CHECK(0) << "Unsupported axis";
+      break;
+  }
+  cnmlCreateNormalizeOpParamV2(&param,
+                               0,  // p
+                               0,  // use_scale
+                               mode,
+                               1,  // weight
+                               epsilon);
+
+  CNML_CALL(cnmlCreateNormalizeOp(&norm_op,
+                                  param,
+                                  input_tensor->mlu_tensor(),
+                                  output_tensor->mlu_tensor(),
+                                  nullptr,
+                                  false /*is_fix8_mode*/));
+  graph->FuseOp(norm_op);
+  CNML_CALL(cnmlDestroyBaseOp(&norm_op));
+  return SUCCESS;
+}
+
+}  // namespace mlu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(norm,
+                         kMLU,
+                         paddle::lite::subgraph::mlu::NormConverter);
diff --git a/lite/kernels/mlu/bridges/norm_op_test.cc b/lite/kernels/mlu/bridges/norm_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..35b5eabbb9ffacd96c3ca6500dd9181f4d5bec5b
--- /dev/null
+++ b/lite/kernels/mlu/bridges/norm_op_test.cc
@@ -0,0 +1,148 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/norm_op.h"
+
+#include <gtest/gtest.h>
+
+#include <cmath>
+#include <memory>
+
+#include "lite/core/op_registry.h"
+#include "lite/kernels/mlu/bridges/test_helper.h"
+#include "lite/kernels/mlu/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace mlu {
+
+// void ToFile(std::string file_name, Tensor* tensor) {
+//   int count = tensor->dims().production();
+//   auto data = tensor->mutable_data<float>();
+//   std::ostringstream outs;
+//   for (size_t i = 0; i < count; i++) {
+//     outs << data[i] << std::endl;
+//   }
+//   std::ofstream of;
+//   of.open(file_name, std::ios::out);
+//   of << outs.str();
+//   of.close();
+// }
+
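+// Reference L2 normalization along `axis`:
+//   out = x / sqrt(epsilon + sum_over_axis(x^2))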
+void norm_ref(const std::shared_ptr<operators::NormOp> op) {
+  Scope* scope = op->scope();
+  const OpInfo* op_info = op->op_info();
+  auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
+  auto out =
+      scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
+  int axis = op_info->GetAttr<int>("axis");
+  float epsilon = op_info->GetAttr<float>("epsilon");
+  auto x_dims = x->dims();
+  if (axis < 0) {
+    axis += x_dims.size();
+  }
+  out->Resize(x_dims.Vectorize());
+  auto* out_data = out->mutable_data<float>();
+
+  const auto* x_data = x->data<float>();
+  int pre_n = x_dims.count(0, axis);
+  int n = x_dims[axis];
+  int post_n = x_dims.count(axis + 1, x_dims.size());
+  for (int i = 0; i < pre_n; i++) {
+    for (int k = 0; k < post_n; k++) {
+      float sum = epsilon;
+      const float* in_tmp = x_data + i * n * post_n + k;
+      for (int j = 0; j < n; j++) {
+        sum += in_tmp[j * post_n] * in_tmp[j * post_n];
+      }
+      sum = std::sqrt(sum);
+      float* out_tmp = out_data + i * n * post_n + k;
+      for (int j = 0; j < n; j++) {
+        out_tmp[j * post_n] = in_tmp[j * post_n] / sum;
+      }
+    }
+  }
+}
+
+void test_norm(const std::vector<int64_t>& input_shape, int axis) {
+  // prepare input&output variables
+  Scope scope;
+  std::string x_var_name = "x";
+  std::string out_var_name = "out";
+  std::string out_ref_var_name = "out_ref";
+  auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
+  auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
+  auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
+  x->Resize(input_shape);
+  // initialize input&output data
+  FillTensor<float>(x, -9, 9);
+  // initialize op desc
+  cpp::OpDesc opdesc;
+  float epsilon = 1e-9f;
+  opdesc.SetType("norm");
+  opdesc.SetInput("X", {x_var_name});
+  opdesc.SetOutput("Out", {out_var_name});
+  opdesc.SetAttr("axis", static_cast<int>(axis));
+  opdesc.SetAttr("epsilon", static_cast<float>(epsilon));
+
+  // create and convert op to MLU model, then run it on MLU
+  auto op = CreateOp<operators::NormOp>(opdesc, &scope);
+  norm_ref(op);
+  out_ref->CopyDataFrom(*out);
+  Tensor input_x;
+  input_x.Resize(DDim(input_shape));
+  // change input layout from NCHW to NHWC
+  transpose(x->mutable_data<float>(),
+            input_x.mutable_data<float>(),
+            {static_cast<int>(input_shape[0]),
+             static_cast<int>(input_shape[1]),
+             static_cast<int>(input_shape[2]),
+             static_cast<int>(input_shape[3])},
+            {0, 2, 3, 1});
+  x->CopyDataFrom(input_x);
+
+  LaunchOp(op, {x_var_name}, {out_var_name});
+  auto* out_data = out->mutable_data<float>();
+  auto* out_ref_data = out_ref->mutable_data<float>();
+  std::vector<int64_t> out_shape = input_shape;
+  Tensor output_trans;
+  output_trans.Resize(out_shape);
+  // Change output layout from NHWC to NCHW
+  transpose(out_data,
+            output_trans.mutable_data<float>(),
+            {static_cast<int>(out_shape[0]),
+             static_cast<int>(out_shape[2]),
+             static_cast<int>(out_shape[3]),
+             static_cast<int>(out_shape[1])},
+            {0, 3, 1, 2});
+  out_data = output_trans.mutable_data<float>();
+
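+  // The tolerance is loose (1e-2) because the MLU pipeline may compute in a
+  // reduced-precision type, so exact agreement with the float reference is
+  // not expected.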
opdesc.SetAttr("epsilon", static_cast(epsilon)); + + // create and convert op to MLU model, then run it on MLU + auto op = CreateOp(opdesc, &scope); + norm_ref(op); + out_ref->CopyDataFrom(*out); + Tensor input_x; + input_x.Resize(DDim(input_shape)); + // change input layout from NCHW to NHWC + transpose(x->mutable_data(), + input_x.mutable_data(), + {static_cast(input_shape[0]), + static_cast(input_shape[1]), + static_cast(input_shape[2]), + static_cast(input_shape[3])}, + {0, 2, 3, 1}); + x->CopyDataFrom(input_x); + + LaunchOp(op, {x_var_name}, {out_var_name}); + auto* out_data = out->mutable_data(); + auto* out_ref_data = out_ref->mutable_data(); + std::vector out_shape = input_shape; + Tensor output_trans; + output_trans.Resize(out_shape); + // Change output layout from NHWC to NCHW + transpose(out_data, + output_trans.mutable_data(), + {static_cast(out_shape[0]), + static_cast(out_shape[2]), + static_cast(out_shape[3]), + static_cast(out_shape[1])}, + {0, 3, 1, 2}); + out_data = output_trans.mutable_data(); + + for (int i = 0; i < out->dims().production(); i++) { + EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2); + } +} + +TEST(MLUBridges, norm) { + test_norm({1, 2, 3, 4}, 1); + test_norm({1, 2, 3, 4}, 2); + test_norm({1, 2, 3, 4}, 3); +} + +} // namespace mlu +} // namespace subgraph +} // namespace lite +} // namespace paddle + +USE_SUBGRAPH_BRIDGE(norm, kMLU); diff --git a/lite/kernels/mlu/bridges/paddle_use_bridges.h b/lite/kernels/mlu/bridges/paddle_use_bridges.h index 2a145c4be2f3231ec6decae81a1bfcb3ed5bafe3..be5c64b3b7056d0b8de1589d198db541b5a3777b 100644 --- a/lite/kernels/mlu/bridges/paddle_use_bridges.h +++ b/lite/kernels/mlu/bridges/paddle_use_bridges.h @@ -43,3 +43,8 @@ USE_SUBGRAPH_BRIDGE(flatten, kMLU); USE_SUBGRAPH_BRIDGE(flatten2, kMLU); USE_SUBGRAPH_BRIDGE(reshape, kMLU); USE_SUBGRAPH_BRIDGE(reshape2, kMLU); +#ifdef LITE_BUILD_EXTRA +USE_SUBGRAPH_BRIDGE(gather, kMLU); +USE_SUBGRAPH_BRIDGE(lrn, kMLU) +USE_SUBGRAPH_BRIDGE(norm, kMLU) +#endif