// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/lite/kernels/arm/fc_compute.h" #include #include #include #include #include #include #include #include "paddle/fluid/lite/arm/math/funcs.h" #include "paddle/fluid/lite/core/op_registry.h" namespace paddle { namespace lite { namespace kernels { namespace arm { template void FillData(T* a, const int n, const T lower = static_cast(-2.f), const T upper = static_cast(2.f)) { static unsigned int seed = 100; std::mt19937 rng(seed++); std::uniform_real_distribution uniform_dist(0, 1); for (int i = 0; i < n; ++i) { a[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); } } TEST(fc_arm, retrive_op) { auto fc = KernelRegistry::Global().Create("fc"); ASSERT_FALSE(fc.empty()); ASSERT_TRUE(fc.front()); } TEST(fc_arm, init) { FcCompute fc; ASSERT_EQ(fc.precision(), PRECISION(kFloat)); ASSERT_EQ(fc.target(), TARGET(kARM)); } TEST(fc_arm, compare_test) { using T = float; for (int m : {1, 2, 3, 4}) { for (int n : {1, 2, 3, 4}) { for (int k : {1, 2, 3, 4}) { for (bool with_bias : {true, false}) { VLOG(3) << "m: " << m << ", n: " << n << ", k: " << k << (with_bias ? ", with bias" : ""); lite::Tensor x, w, b, out, ref; x.Resize({m, k}); w.Resize({k, n}); b.Resize({1, n}); out.Resize({m, n}); ref.Resize({m, n}); auto* x_data = x.mutable_data(); auto* w_data = w.mutable_data(); auto* b_data = with_bias ? b.mutable_data() : nullptr; auto* out_data = out.mutable_data(); auto* ref_data = ref.mutable_data(); FillData(x_data, x.dims().production()); FillData(w_data, w.dims().production()); FillData(out_data, out.dims().production(), 0, 0); FillData(ref_data, ref.dims().production(), 0, 0); if (with_bias) { FillData(b_data, b.dims().production()); } FcCompute fc; operators::FcParam param; param.input = &x; param.w = &w; param.bias = with_bias ? &b : nullptr; param.output = &out; param.in_num_col_dims = 1; param.in_mat_dims = x.dims(); DeviceInfo::Init(); std::unique_ptr ctx(new KernelContext); ctx->As(); fc.SetParam(param); fc.SetContext(std::move(ctx)); fc.PrepareForRun(); fc.Run(); lite::arm::math::fc_compute_eigen(x_data, m, k, w_data, k, n, b_data, ref_data); for (int i = 0; i < out.dims().production(); i++) { EXPECT_NEAR(out_data[i], ref_data[i], 1e-3); } } } } } } TEST(fc_arm, num_col_dims) { using T = float; for (bool with_bias : {true, false}) { lite::Tensor x, w, b, out, ref; x.Resize({1, 2, 3}); w.Resize({3, 4}); b.Resize({1, 4}); out.Resize({2, 4}); ref.Resize({2, 4}); auto* x_data = x.mutable_data(); auto* w_data = w.mutable_data(); auto* b_data = with_bias ? b.mutable_data() : nullptr; auto* out_data = out.mutable_data(); auto* ref_data = ref.mutable_data(); FillData(x_data, x.dims().production()); FillData(w_data, w.dims().production()); FillData(out_data, out.dims().production(), 0, 0); FillData(ref_data, ref.dims().production(), 0, 0); if (with_bias) { FillData(b_data, b.dims().production()); } FcCompute fc; operators::FcParam param; param.input = &x; param.w = &w; param.bias = with_bias ? &b : nullptr; param.output = &out; param.in_num_col_dims = 2; param.in_mat_dims = x.dims(); std::unique_ptr ctx(new KernelContext); ctx->As(); DeviceInfo::Init(); fc.SetParam(param); fc.SetContext(std::move(ctx)); fc.PrepareForRun(); fc.Run(); lite::arm::math::fc_compute_eigen(x_data, 2, 3, w_data, 3, 4, b_data, ref_data); for (int i = 0; i < out.dims().production(); i++) { EXPECT_NEAR(out_data[i], ref_data[i], 1e-3); } } } } // namespace arm } // namespace kernels } // namespace lite } // namespace paddle USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);