diff --git a/mace/kernels/arm/fully_connected.cc b/mace/kernels/arm/fully_connected.cc new file mode 100644 index 0000000000000000000000000000000000000000..e28c8f3e19d09257694dd4bd86d7998e88f67466 --- /dev/null +++ b/mace/kernels/arm/fully_connected.cc @@ -0,0 +1,39 @@ +// +// Copyright (c) 2018 XiaoMi All rights reserved. +// + +#include "mace/kernels/fully_connected.h" +#include "mace/kernels/gemm.h" + +namespace mace { +namespace kernels { + +void FullyConnectedFunctor::operator()(const Tensor *input, + const Tensor *weight, + const Tensor *bias, + Tensor *output, + StatsFuture *future) { + std::vector output_shape = {input->dim(0), weight->dim(0), 1, 1}; + output->Resize(output_shape); + const index_t N = output->dim(0); + const index_t input_size = weight->dim(1); + const index_t output_size = weight->dim(0); + const float *input_ptr = input->data(); + const float *weight_ptr = weight->data(); + const float *bias_ptr = bias == nullptr ? nullptr : bias->data(); + float *output_ptr = output->mutable_data(); + + for (int i = 0; i < N; ++i) { + Gemm(weight_ptr, input_ptr, 1, output_size, input_size, 1, output_ptr); + for (int j = 0; j < output_size; ++j) { + output_ptr[j] += bias_ptr[j]; + } + } + + DoActivation(output_ptr, output_ptr, output->size(), activation_, + relux_max_limit_); +} + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/fully_connected.h b/mace/kernels/fully_connected.h index 4ab385291da1854808f73cd0bdd926c7cc17c616..3c21fba7ba827ab78ac81467886672262bf8d391 100644 --- a/mace/kernels/fully_connected.h +++ b/mace/kernels/fully_connected.h @@ -76,6 +76,20 @@ struct FullyConnectedFunctor : FullyConnectedBase { } }; +template <> +struct FullyConnectedFunctor : FullyConnectedBase { + FullyConnectedFunctor(const BufferType weight_type, + const ActivationType activation, + const float relux_max_limit) + : FullyConnectedBase(weight_type, activation, relux_max_limit) {} + + void operator()(const Tensor *input, + const Tensor *weight, + const Tensor *bias, + Tensor *output, + StatsFuture *future); +}; + template struct FullyConnectedFunctor : FullyConnectedBase { FullyConnectedFunctor(const BufferType weight_type, diff --git a/mace/ops/eltwise.cc b/mace/ops/eltwise.cc index 5c49f3563f46dc30e634fc225dca262a32d3d682..07c775ad68c00bbc37e9ab2ee7ee7cec5082744a 100644 --- a/mace/ops/eltwise.cc +++ b/mace/ops/eltwise.cc @@ -25,6 +25,11 @@ void Register_Eltwise(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), EltwiseOp); + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise") + .Device(DeviceType::NEON) + .TypeConstraint("T") + .Build(), + EltwiseOp); } } // namespace ops diff --git a/mace/ops/fully_connected.cc b/mace/ops/fully_connected.cc index dd4c5b87228d8e1c0aacbf2b93e5cb1a03e3aa76..2558e4575cd762fc63f8b600e808920806a293f9 100644 --- a/mace/ops/fully_connected.cc +++ b/mace/ops/fully_connected.cc @@ -25,6 +25,12 @@ void Register_FullyConnected(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), FullyConnectedOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("FC") + .Device(DeviceType::NEON) + .TypeConstraint("T") + .Build(), + FullyConnectedOp); } } // namespace ops diff --git a/mace/ops/fully_connected_test.cc b/mace/ops/fully_connected_test.cc index f839f95c975f94545d3661642de0ffc762c2a27d..6d16f01812f6093b4a48a4eaf9b58b13f4aeb469 100644 --- a/mace/ops/fully_connected_test.cc +++ b/mace/ops/fully_connected_test.cc @@ -13,7 +13,7 @@ namespace test { class FullyConnectedOpTest : public OpsTestBase {}; -template +template void Simple(const std::vector &input_shape, const std::vector &input_value, const std::vector &weight_shape, @@ -38,12 +38,12 @@ void Simple(const std::vector &input_shape, kernels::BufferType::ARGUMENT); OpDefBuilder("FC", "FullyConnectedTest") - .Input("InputImage") - .Input("WeightImage") - .Input("BiasImage") - .Output("OutputImage") - .AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT) - .Finalize(net.NewOperatorDef()); + .Input("InputImage") + .Input("WeightImage") + .Input("BiasImage") + .Output("OutputImage") + .AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT) + .Finalize(net.NewOperatorDef()); // Run net.RunOp(D); @@ -52,11 +52,11 @@ void Simple(const std::vector &input_shape, kernels::BufferType::IN_OUT_CHANNEL); } else { OpDefBuilder("FC", "FullyConnectedTest") - .Input("Input") - .Input("Weight") - .Input("Bias") - .Output("Output") - .Finalize(net.NewOperatorDef()); + .Input("Input") + .Input("Weight") + .Input("Bias") + .Output("Output") + .Finalize(net.NewOperatorDef()); // Run net.RunOp(D); } @@ -72,14 +72,14 @@ TEST_F(FullyConnectedOpTest, SimpleCPU) { {1, 2, 3, 4, 5, 6, 7, 8}, {1}, {2}, {1, 1, 1, 1}, {206}); Simple( - {1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 10}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}, - {2}, {2, 3}, {1, 1, 1, 2}, {387, 3853}); + {1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 10}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}, + {2}, {2, 3}, {1, 1, 1, 2}, {387, 3853}); Simple( - {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 6}, - {1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, - 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6}, - {5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, {92, 912, 94, 914, 96}); + {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 6}, + {1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, + 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6}, + {5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, {92, 912, 94, 914, 96}); } TEST_F(FullyConnectedOpTest, SimpleCPUWithBatch) { @@ -92,14 +92,14 @@ TEST_F(FullyConnectedOpTest, SimpleOPENCL) { {1, 2, 3, 4, 5, 6, 7, 8}, {1}, {2}, {1, 1, 1, 1}, {206}); Simple( - {1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 10}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}, - {2}, {2, 3}, {1, 1, 1, 2}, {387, 3853}); + {1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 10}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}, + {2}, {2, 3}, {1, 1, 1, 2}, {387, 3853}); Simple( - {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 6}, - {1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, - 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6}, - {5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, {92, 912, 94, 914, 96}); + {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 6}, + {1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, + 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6}, + {5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, {92, 912, 94, 914, 96}); } TEST_F(FullyConnectedOpTest, SimpleGPUWithBatch) { @@ -107,7 +107,7 @@ TEST_F(FullyConnectedOpTest, SimpleGPUWithBatch) { {1, 2, 3, 4}, {1}, {2}, {2, 1, 1, 1}, {32, 72}); } -template +template void Complex(const index_t batch, const index_t height, const index_t width, @@ -118,17 +118,17 @@ void Complex(const index_t batch, // Construct graph OpsTestNet net; OpDefBuilder("FC", "FullyConnectedTest") - .Input("Input") - .Input("Weight") - .Input("Bias") - .Output("Output") - .Finalize(net.NewOperatorDef()); + .Input("Input") + .Input("Weight") + .Input("Bias") + .Output("Output") + .Finalize(net.NewOperatorDef()); // Add input data net.AddRandomInput( - "Input", {batch, height, width, channels}); + "Input", {batch, height, width, channels}); net.AddRandomInput( - "Weight", {out_channel, height * width * channels}); + "Weight", {out_channel, height * width * channels}); net.AddRandomInput("Bias", {out_channel}); // run cpu @@ -147,13 +147,13 @@ void Complex(const index_t batch, kernels::BufferType::ARGUMENT); OpDefBuilder("FC", "FullyConnectedTest") - .Input("InputImage") - .Input("WeightImage") - .Input("BiasImage") - .Output("OutputImage") - .AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT) - .AddIntArg("T", static_cast(DataTypeToEnum::value)) - .Finalize(net.NewOperatorDef()); + .Input("InputImage") + .Input("WeightImage") + .Input("BiasImage") + .Output("OutputImage") + .AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); // Run on opencl net.RunOp(DeviceType::OPENCL); @@ -189,7 +189,7 @@ TEST_F(FullyConnectedOpTest, OPENCLHalfUnAlignedWithBatch) { Complex(31, 21, 11, 23, 103); } -template +template void TestWXFormat(const index_t batch, const index_t height, const index_t width, @@ -200,17 +200,17 @@ void TestWXFormat(const index_t batch, // Construct graph OpsTestNet net; OpDefBuilder("FC", "FullyConnectedTest") - .Input("Input") - .Input("Weight") - .Input("Bias") - .Output("Output") - .Finalize(net.NewOperatorDef()); + .Input("Input") + .Input("Weight") + .Input("Bias") + .Output("Output") + .Finalize(net.NewOperatorDef()); // Add input data net.AddRandomInput( - "Input", {batch, height, width, channels}); + "Input", {batch, height, width, channels}); net.AddRandomInput( - "Weight", {out_channel, height * width * channels}); + "Weight", {out_channel, height * width * channels}); net.AddRandomInput("Bias", {out_channel}); // run cpu @@ -229,12 +229,12 @@ void TestWXFormat(const index_t batch, kernels::BufferType::ARGUMENT); OpDefBuilder("FC", "FullyConnectedTest") - .Input("InputImage") - .Input("WeightImage") - .Input("BiasImage") - .Output("OutputImage") - .AddIntArg("T", static_cast(DataTypeToEnum::value)) - .Finalize(net.NewOperatorDef()); + .Input("InputImage") + .Input("WeightImage") + .Input("BiasImage") + .Output("OutputImage") + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); // Run net.RunOp(DeviceType::OPENCL); @@ -266,6 +266,57 @@ TEST_F(FullyConnectedOpTest, OPENCLHalfWidthFormatAligned) { TestWXFormat(1, 16, 32, 32, 32); } +void FullyConnectedTestNEON(const index_t batch, + const index_t height, + const index_t width, + const index_t channels, + const index_t out_channel) { + srand(time(NULL)); + + // Construct graph + OpsTestNet net; + OpDefBuilder("FC", "FullyConnectedTest") + .Input("Input") + .Input("Weight") + .Input("Bias") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput( + "Input", {batch, height, width, channels}); + net.AddRandomInput( + "Weight", {out_channel, height * width * channels}); + net.AddRandomInput("Bias", {out_channel}); + + // run cpu + net.RunOp(); + + // Run on neon + OpDefBuilder("FC", "FullyConnectedTest") + .Input("Input") + .Input("Weight") + .Input("Bias") + .Output("OutputNeon") + .Finalize(net.NewOperatorDef()); + + // Run on device + net.RunOp(DeviceType::NEON); + + net.FillNHWCInputToNCHWInput("OutputExptected", + "Output"); + + ExpectTensorNear(*net.GetOutput("OutputExptected"), + *net.GetOutput("OutputNeon"), + 0.001); +} + +TEST_F(FullyConnectedOpTest, TestNEON) { + FullyConnectedTestNEON(1, 7, 7, 32, 16); + FullyConnectedTestNEON(1, 7, 7, 512, 128); + FullyConnectedTestNEON(1, 1, 1, 2048, 1024); +} + } // namespace test } // namespace ops } // namespace mace