diff --git a/lite/kernels/npu/bridges/batch_norm_op_test.cc b/lite/kernels/npu/bridges/batch_norm_op_test.cc
deleted file mode 100644
index 38a876efb7c8ca6c38dee44e3c7a29a141d995d4..0000000000000000000000000000000000000000
--- a/lite/kernels/npu/bridges/batch_norm_op_test.cc
+++ /dev/null
@@ -1,168 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "lite/operators/batch_norm_op.h"
-#include <gtest/gtest.h>
-#include "lite/core/op_registry.h"
-#include "lite/kernels/npu/bridges/registry.h"
-#include "lite/kernels/npu/bridges/test_helper.h"
-
-namespace paddle {
-namespace lite {
-namespace kernels {
-namespace npu {
-namespace bridges {
-
-template <typename dtype>
-void batch_norm_ref(const std::shared_ptr<operators::BatchNormOp> op) {
-  Scope* scope = op->scope();
-  const OpInfo* op_info = op->op_info();
-  auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
-  auto y = scope->FindVar(op_info->Output("Y").front())->GetMutable<Tensor>();
-  auto bias =
-      scope->FindVar(op_info->Input("Bias").front())->GetMutable<Tensor>();
-  auto scale =
-      scope->FindVar(op_info->Input("Scale").front())->GetMutable<Tensor>();
-  auto mean =
-      scope->FindVar(op_info->Input("Mean").front())->GetMutable<Tensor>();
-  auto variance =
-      scope->FindVar(op_info->Input("Variance").front())->GetMutable<Tensor>();
-
-  auto x_data = x->data<dtype>();
-  auto y_data = y->mutable_data<dtype>();
-  auto scale_data = scale->mutable_data<dtype>();
-  auto bias_data = bias->mutable_data<dtype>();
-  auto mean_data = mean->mutable_data<dtype>();
-  auto variance_data = variance->mutable_data<dtype>();
-  DDim x_dims = x->dims();
-
-  float epsilon = op_info->GetAttr<float>("epsilon");
-  float momentum = op_info->GetAttr<float>("momentum");
-  auto data_layout = op_info->GetAttr<std::string>("data_layout");
-
-  bool global_stats = op_info->GetAttr<bool>("use_global_stats");
-  if (global_stats) {
-    int64_t outer_size = 0;
-    int64_t channel_size = 0;
-    int64_t inner_size = 0;
-    if (data_layout == "NCHW") {
-      outer_size = x_dims[0];
-      channel_size = x_dims[1];
-      inner_size = x_dims.Slice(2, x_dims.size()).production();
-    } else {
-      LOG(FATAL) << "Unknown storage order: " << data_layout;
-    }
-    auto x_ptr = x_data;
-    auto y_ptr = y_data;
-    for (int o = 0; o < outer_size; o++) {
-      for (int c = 0; c < channel_size; c++) {
-        for (int i = 0; i < inner_size; i++) {
-          dtype norm_x =
-              (*x_ptr - mean_data[c]) / std::sqrt(variance_data[c] + epsilon);
-          *y_ptr = norm_x * scale_data[c] + bias_data[c];
-          x_ptr++;
-          y_ptr++;
-        }
-      }
-    }
-  }
-}
-
-void test_batch_norm(
-    int bs, int ic, int ih, int iw, float epsilon, float momentum) {
-  // prepare input&output variables
-  Scope scope;
-  std::string x_var_name = "x";
-  std::string out_var_name = "out";
-  std::string out_ref_var_name = "out_ref";
-  std::string scale_var_name = "scale";
-  std::string bias_var_name = "bias";
-  std::string mean_var_name = "mean";
-  std::string variance_var_name = "variance";
-  auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
-  auto* scale = scope.Var(scale_var_name)->GetMutable<Tensor>();
-  auto* bias = scope.Var(bias_var_name)->GetMutable<Tensor>();
-  auto* mean = scope.Var(mean_var_name)->GetMutable<Tensor>();
-  auto* variance = scope.Var(variance_var_name)->GetMutable<Tensor>();
-  auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
-  auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
-  x->Resize({bs, ic, ih, iw});
-  scale->Resize({ic});
-  bias->Resize({ic});
-  mean->Resize({ic});
-  variance->Resize({ic});
-
-  // initialize input&output data
-  FillTensor<float>(x);
-  FillTensor<float>(scale);
-  FillTensor<float>(bias);
-  FillTensor<float>(mean);
-  // variance > 0
-  FillTensor<float>(variance, 1.f, 5.f);
-
-  // initialize op desc
-  cpp::OpDesc opdesc;
-  opdesc.SetType("batch_norm");
-  opdesc.SetInput("X", {x_var_name});
-  opdesc.SetInput("Scale", {scale_var_name});
-  opdesc.SetInput("Bias", {bias_var_name});
-  opdesc.SetInput("Mean", {mean_var_name});
-  opdesc.SetInput("Variance", {variance_var_name});
-  opdesc.SetOutput("Y", {out_var_name});
-  opdesc.SetAttr("is_test", 1);
-  opdesc.SetAttr("use_global_stats", true);
-  opdesc.SetAttr("epsilon", epsilon);
-  opdesc.SetAttr("momentum", momentum);
-  opdesc.SetAttr("data_layout", std::string("NCHW"));
-
-  // create and convert op to NPU model, then run it on NPU
-  auto op = CreateOp<operators::BatchNormOp>(opdesc, &scope);
-  LauchOp(op, {x_var_name}, {out_var_name});
-  out_ref->CopyDataFrom(*out);
-
-  // execute reference implementation and save to output tensor
-  batch_norm_ref<float>(op);
-
-  // compare results
-  auto* out_data = out->mutable_data<float>();
-  auto* out_ref_data = out_ref->mutable_data<float>();
-  for (int i = 0; i < out->dims().production(); i++) {
-    EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
-  }
-}
-
-TEST(NPUBridges, batch_norm) {
-  for (auto bs : {1, 4, 7}) {
-    for (auto ic : {1, 4, 7}) {
-      for (auto ih : {1, 4, 7}) {
-        for (auto iw : {1, 4, 7}) {
-          for (auto epsilon : {1e-4f, 1e-5f}) {
-            for (auto momentum : {0.9f, 0.99f}) {
-              test_batch_norm(bs, ic, ih, iw, epsilon, momentum);
-            }
-          }
-        }
-      }
-    }
-  }
-}
-
-}  // namespace bridges
-}  // namespace npu
-}  // namespace kernels
-}  // namespace lite
-}  // namespace paddle
-
-USE_LITE_OP(batch_norm);
-USE_NPU_BRIDGE(batch_norm);
diff --git a/lite/kernels/npu/bridges/transpose_op.cc b/lite/kernels/npu/bridges/transpose_op.cc
index f758ababaca16a9187ca5ea416c84704b09fc19f..bdac84df3ca96d14891f3636292a13252246be19 100644
--- a/lite/kernels/npu/bridges/transpose_op.cc
+++ b/lite/kernels/npu/bridges/transpose_op.cc
@@ -37,7 +37,7 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   CHECK(x_type->layout() == DATALAYOUT(kNCHW));
   auto x = scope->FindMutableTensor(x_name);
   auto x_dims = x->dims();
-  auto out_name = op_info->Input("Out").front();
+  auto out_name = op_info->Output("Out").front();
   auto axis = op_info->GetAttr<std::vector<int>>("axis");
 
   // X node
diff --git a/lite/kernels/npu/bridges/transpose_op_test.cc b/lite/kernels/npu/bridges/transpose_op_test.cc
deleted file mode 100644
index 9ad2610caa4f1674c1a07afd62a4b85361ec6645..0000000000000000000000000000000000000000
--- a/lite/kernels/npu/bridges/transpose_op_test.cc
+++ /dev/null
@@ -1,153 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "lite/operators/transpose_op.h"
-#include <gtest/gtest.h>
-#include "lite/core/op_registry.h"
-#include "lite/kernels/npu/bridges/registry.h"
-#include "lite/kernels/npu/bridges/test_helper.h"
-
-namespace paddle {
-namespace lite {
-namespace kernels {
-namespace npu {
-namespace bridges {
-
-int data_index(std::vector<int> pos, DDimLite dims) {
-  int d1 = dims[1];
-  int d2 = dims[2];
-  int d3 = dims[3];
-  return pos[3] + pos[2] * d3 + pos[1] * d3 * d2 + pos[0] * d3 * d2 * d1;
-}
-
-std::vector<int> pos_trans(std::vector<int> in_pos, std::vector<int> axis) {
-  std::vector<int> out_pos(in_pos.size());
-  for (int i = 0; i < axis.size(); i++) {
-    out_pos[axis[i]] = in_pos[i];
-  }
-  return out_pos;
-}
-
-void transpose_ref(const std::shared_ptr<operators::TransposeOp> op) {
-  Scope* scope = op->scope();
-  const OpInfo* op_info = op->op_info();
-  auto input =
-      scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
-  auto output =
-      scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
-  auto x_dims = input->dims();
-  auto y_dims = output->dims();
-  auto axis = op_info->GetAttr<std::vector<int>>("axis");
-
-  auto* input_data = input->data<float>();
-  auto* output_data = output->mutable_data<float>();
-
-  int input_n = x_dims[0];
-  int input_c = x_dims[1];
-  int input_h = x_dims[2];
-  int input_w = x_dims[3];
-  int output_n = y_dims[0];
-  int output_c = y_dims[1];
-  int output_h = y_dims[2];
-  int output_w = y_dims[3];
-
-  for (int n = 0; n < input_n; ++n) {
-    for (int c = 0; c < input_c; ++c) {
-      for (int h = 0; h < input_h; ++h) {
-        for (int w = 0; w < input_w; ++w) {
-          std::vector<int> in_pos{n, c, h, w};
-          std::vector<int> out_pos = pos_trans(in_pos, axis);
-          int in_index = data_index(in_pos, x_dims);
-          int out_index = data_index(out_pos, y_dims);
-          output_data[out_index] = input_data[in_index];
-        }
-      }
-    }
-  }
-}
-
-void test_transpose(int bs, int ic, int ih, int iw, std::vector<int> axis) {
-  // prepare input&output variables
-  Scope scope;
-  std::string x_var_name = "x";
-  std::string out_var_name = "out";
-  std::string out_ref_var_name = "out_ref";
-  auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
-  auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
-  auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
-  x->Resize({bs, ic, ih, iw});
-
-  // initialize input&output data
-  FillTensor<float>(x);
-
-  // initialize op desc
-  cpp::OpDesc opdesc;
-  opdesc.SetType("transpose");
-  opdesc.SetInput("X", {x_var_name});
-  opdesc.SetOutput("Out", {out_var_name});
-  opdesc.SetAttr("axis", axis);
-
-  // create and convert op to NPU model, then run it on NPU
-  auto op = CreateOp<operators::TransposeOp>(opdesc, &scope);
-  LauchOp(op, {x_var_name}, {out_var_name});
-  out_ref->CopyDataFrom(*out);
-
-  // execute reference implementation and save to output tensor
-  transpose_ref(op);
-
-  // compare results
-  auto* out_data = out->mutable_data<float>();
-  auto* out_ref_data = out_ref->mutable_data<float>();
-  for (int i = 0; i < out->dims().production(); i++) {
-    EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
-  }
-}
-
-TEST(NPUBridges, transpose) {
-#if 0
-  for (auto bs : {1, 4, 7}) {
-    for (auto ic : {1, 4, 7}) {
-      for (auto ih : {1, 4, 7}) {
-        for (auto iw : {1, 4, 7}) {
-          for (auto axis : {std::vector<int>{0, 1, 2, 3},
-                            std::vector<int>{0, 1, 3, 2},
-                            std::vector<int>{0, 3, 1, 2},
-                            std::vector<int>{1, 2, 3, 0},
-                            std::vector<int>{3, 2, 1, 0},
-                            std::vector<int>{2, 3, 1, 0}}) {
-            test_transpose(bs, ic, ih, iw, axis);
-          }
-        }
-      }
-    }
-  }
-#endif
-  test_transpose(2, 3, 4, 5, std::vector<int>{0, 1, 3, 2});
-  // test_transpose(2, 3, 4, 5, std::vector<int>{0, 1, 2, 3});
-  // test_transpose(2, 2, 2, 2, std::vector<int>{0,1,3,2});
-  // test_transpose(1, 1, 2, 2, std::vector<int>{0,1,3,2});
-  // test_transpose(1, 1, 1, 2, std::vector<int>{0,1,2,3});
-}
-
-}  // namespace bridges
-}  // namespace npu
-}  // namespace kernels
-}  // namespace lite
-}  // namespace paddle
-
-USE_LITE_OP(transpose);
-USE_NPU_BRIDGE(transpose);
-
-USE_LITE_OP(transpose2);
-USE_NPU_BRIDGE(transpose2);
diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt
index 7eda7acd69b6bfdc6be92857f23adb35c71f0487..0a1ff3906f03065c93759926fd3e3262697d48e9 100644
--- a/lite/tests/kernels/CMakeLists.txt
+++ b/lite/tests/kernels/CMakeLists.txt
@@ -25,13 +25,13 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
   #lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   #lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_transpose_compute SRCS transpose_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_transpose_compute SRCS transpose_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${npu_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_reshape_compute SRCS reshape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_layer_norm_compute SRCS layer_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_dropout_compute SRCS dropout_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_batch_norm_compute SRCS batch_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_batch_norm_compute SRCS batch_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${npu_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
 
   if(LITE_BUILD_EXTRA)
     lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
diff --git a/lite/tests/kernels/batch_norm_compute_test.cc b/lite/tests/kernels/batch_norm_compute_test.cc
index c462dd011f1db1f60046b54b892ea7973d7beb7c..ae65e0e3c320ff153a99d2a1656227bad34428d4 100644
--- a/lite/tests/kernels/batch_norm_compute_test.cc
+++ b/lite/tests/kernels/batch_norm_compute_test.cc
@@ -159,6 +159,8 @@ TEST(BatchNorm, precision) {
   Place place;
 #if defined(LITE_WITH_XPU)
   place = TARGET(kXPU);
+#elif defined(LITE_WITH_NPU)
+  place = TARGET(kNPU);
 #else
   return;
 #endif
diff --git a/lite/tests/kernels/transpose_compute_test.cc b/lite/tests/kernels/transpose_compute_test.cc
index 62e0fc8e410092975eed3ea5fec441a7859de81f..b4407bb5690fe8c1f4305cea584f9abf5af121bb 100644
--- a/lite/tests/kernels/transpose_compute_test.cc
+++ b/lite/tests/kernels/transpose_compute_test.cc
@@ -16,6 +16,7 @@
 #include "lite/api/paddle_use_kernels.h"
 #include "lite/api/paddle_use_ops.h"
 #include "lite/core/arena/framework.h"
+#include "lite/tests/utils/fill_data.h"
 
 namespace paddle {
 namespace lite {
@@ -24,13 +25,13 @@ int data_index(std::vector<int> pos, DDimLite dims) {
   int d1 = dims[1];
   int d2 = dims[2];
   int d3 = dims[3];
-  return pos[3] + pos[2] * d3 + pos[1] * d3 * d2 + pos[0] * d3 * d2 * d1;
+  return pos[0] * d1 * d2 * d3 + pos[1] * d2 * d3 + pos[2] * d3 + pos[3];
 }
 
 std::vector<int> pos_trans(std::vector<int> in_pos, std::vector<int> axis) {
   std::vector<int> out_pos(in_pos.size());
   for (int i = 0; i < axis.size(); i++) {
-    out_pos[axis[i]] = in_pos[i];
+    out_pos[i] = in_pos[axis[i]];
   }
   return out_pos;
 }
@@ -42,35 +43,34 @@ class TransposeComputeTester : public arena::TestCase {
   std::string input_ = "x";
   std::string output_ = "out";
   std::string xshape_ = "xshape";
-  DDim x_dims_;
+  DDim dims_;
   std::vector<int> axis_;
 
  public:
   TransposeComputeTester(const Place& place,
                          const std::string& alias,
-                         DDim x_dims,
+                         DDim dims,
                          std::vector<int> axis)
-      : TestCase(place, alias), x_dims_(x_dims), axis_(axis) {}
+      : TestCase(place, alias), dims_(dims), axis_(axis) {}
 
   void RunBaseline(Scope* scope) override {
     auto* out = scope->NewTensor(output_);
     CHECK(out);
 
     auto* x = scope->FindTensor(input_);
-    auto x_dims = x->dims();
 
-    std::vector<int64_t> out_shape(x_dims.size(), 0);
-    for (size_t i = 0; i < x_dims.size(); i++) {
-      out_shape[i] = x_dims[axis_[i]];
+    std::vector<int64_t> out_shape(dims_.size(), 0);
+    for (size_t i = 0; i < dims_.size(); i++) {
+      out_shape[i] = dims_[axis_[i]];
     }
     out->Resize(out_shape);
 
     auto y_dims = out->dims();
 
-    int input_n = x_dims[0];
-    int input_c = x_dims[1];
-    int input_h = x_dims[2];
-    int input_w = x_dims[3];
+    int input_n = dims_[0];
+    int input_c = dims_[1];
+    int input_h = dims_[2];
+    int input_w = dims_[3];
 
     auto input_data = x->data<float>();
     auto output_data = out->mutable_data<float>();
@@ -81,7 +81,7 @@ class TransposeComputeTester : public arena::TestCase {
         for (int w = 0; w < input_w; ++w) {
           std::vector<int> in_pos{n, c, h, w};
           std::vector<int> out_pos = pos_trans(in_pos, axis_);
-          int in_index = data_index(in_pos, x_dims);
+          int in_index = data_index(in_pos, dims_);
           int out_index = data_index(out_pos, y_dims);
           output_data[out_index] = input_data[in_index];
         }
@@ -91,7 +91,7 @@ class TransposeComputeTester : public arena::TestCase {
 
     if (op_type_ == "transpose2") {
       auto* xshape = scope->NewTensor(xshape_);
-      auto xshape_dims = x_dims.Vectorize();
+      auto xshape_dims = dims_.Vectorize();
       xshape_dims.insert(xshape_dims.begin(), 0);
       xshape->Resize(xshape_dims);
     }
@@ -108,11 +108,9 @@ class TransposeComputeTester : public arena::TestCase {
   }
 
   void PrepareData() override {
-    std::vector<float> data(x_dims_.production());
-    for (int i = 0; i < x_dims_.production(); i++) {
-      data[i] = i * 1.1;
-    }
-    SetCommonTensor(input_, x_dims_, data.data());
+    std::vector<float> din(dims_.production());
+    fill_data_rand(din.data(), -1.f, 1.f, dims_.production());
+    SetCommonTensor(input_, dims_, din.data());
   }
 };
@@ -122,14 +120,16 @@ TEST(Transpose, precision) {
   Place place;
 #ifdef LITE_WITH_XPU
   place = TARGET(kXPU);
+#elif defined(LITE_WITH_NPU)
+  place = TARGET(kNPU);
+  abs_error = 1e-2;  // Using fp16 in NPU
 #else
   return;
 #endif
 
   DDim x_dims{{2, 3, 4, 5}};
-  // [XPU]: {3, 1, 0, 2} is unsupported
   std::vector<std::vector<int>> axes{
-      {0, 1, 2, 3}, {0, 1, 3, 2}, {0, 2, 1, 3}, {3, 1, 2, 0}};
+      {0, 1, 2, 3}, {0, 1, 3, 2}, {0, 2, 1, 3}, {3, 1, 2, 0}, {3, 1, 0, 2}};
   for (auto axis : axes) {
     std::unique_ptr<arena::TestCase> tester(
         new TransposeComputeTester(place, "def", x_dims, axis));