提交 73057ea3 编写于 作者: 李寅

Implement fc and eltwise for neon

上级 fdf938ce
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include "mace/kernels/fully_connected.h"
#include "mace/kernels/gemm.h"
namespace mace {
namespace kernels {
// NEON fully-connected: for each batch item computes
//   output[i] = weight (output_size x input_size) * input[i] (+ bias),
// then applies the configured activation in place over the whole output.
//
// input:  flattened per-batch vectors, batch N x input_size elements.
// weight: row-major, shape (output_size, input_size).
// bias:   optional (may be nullptr), output_size elements.
// output: resized here to {N, output_size, 1, 1}.
void FullyConnectedFunctor<DeviceType::NEON,
float>::operator()(const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
std::vector<index_t> output_shape = {input->dim(0), weight->dim(0), 1, 1};
output->Resize(output_shape);
const index_t N = output->dim(0);
const index_t input_size = weight->dim(1);
const index_t output_size = weight->dim(0);
const float *input_ptr = input->data<float>();
const float *weight_ptr = weight->data<float>();
const float *bias_ptr = bias == nullptr ? nullptr : bias->data<float>();
float *output_ptr = output->mutable_data<float>();
for (index_t i = 0; i < N; ++i) {
// Advance to this batch item's slice; the original code reused the
// base pointers for every iteration, so batches 1..N-1 were never
// computed and batch 0 was overwritten N times.
const float *batch_input_ptr = input_ptr + i * input_size;
float *batch_output_ptr = output_ptr + i * output_size;
// Gemm(A, B, batch, height, width, K, C): one (output_size x input_size)
// by (input_size x 1) multiply per batch item.
Gemm(weight_ptr, batch_input_ptr, 1, output_size, input_size, 1,
batch_output_ptr);
// bias is optional; guard against a nullptr dereference.
if (bias_ptr != nullptr) {
for (index_t j = 0; j < output_size; ++j) {
batch_output_ptr[j] += bias_ptr[j];
}
}
}
// Activation is applied over the full N * output_size buffer at once.
DoActivation(output_ptr, output_ptr, output->size(), activation_,
relux_max_limit_);
}
} // namespace kernels
} // namespace mace
...@@ -76,6 +76,20 @@ struct FullyConnectedFunctor : FullyConnectedBase { ...@@ -76,6 +76,20 @@ struct FullyConnectedFunctor : FullyConnectedBase {
} }
}; };
// NEON specialization of the fully-connected functor for float tensors.
// The operator() body is defined out-of-line in the NEON kernel source;
// it produces output of shape {batch, out_channels, 1, 1}.
template <>
struct FullyConnectedFunctor<DeviceType::NEON, float> : FullyConnectedBase {
// Forwards configuration (weight layout, activation type and the ReLUX
// clamp limit) to the common base; holds no NEON-specific state.
FullyConnectedFunctor(const BufferType weight_type,
const ActivationType activation,
const float relux_max_limit)
: FullyConnectedBase(weight_type, activation, relux_max_limit) {}
// Runs the fully-connected computation. bias may be nullptr; future is
// the async-completion handle used by the framework's scheduling.
void operator()(const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
};
template <typename T> template <typename T>
struct FullyConnectedFunctor<DeviceType::OPENCL, T> : FullyConnectedBase { struct FullyConnectedFunctor<DeviceType::OPENCL, T> : FullyConnectedBase {
FullyConnectedFunctor(const BufferType weight_type, FullyConnectedFunctor(const BufferType weight_type,
......
...@@ -25,6 +25,11 @@ void Register_Eltwise(OperatorRegistry *op_registry) { ...@@ -25,6 +25,11 @@ void Register_Eltwise(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T") .TypeConstraint<half>("T")
.Build(), .Build(),
EltwiseOp<DeviceType::OPENCL, half>); EltwiseOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
EltwiseOp<DeviceType::NEON, float>);
} }
} // namespace ops } // namespace ops
......
...@@ -25,6 +25,12 @@ void Register_FullyConnected(OperatorRegistry *op_registry) { ...@@ -25,6 +25,12 @@ void Register_FullyConnected(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T") .TypeConstraint<half>("T")
.Build(), .Build(),
FullyConnectedOp<DeviceType::OPENCL, half>); FullyConnectedOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FC")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
FullyConnectedOp<DeviceType::NEON, float>);
} }
} // namespace ops } // namespace ops
......
...@@ -13,7 +13,7 @@ namespace test { ...@@ -13,7 +13,7 @@ namespace test {
class FullyConnectedOpTest : public OpsTestBase {}; class FullyConnectedOpTest : public OpsTestBase {};
template <DeviceType D> template<DeviceType D>
void Simple(const std::vector<index_t> &input_shape, void Simple(const std::vector<index_t> &input_shape,
const std::vector<float> &input_value, const std::vector<float> &input_value,
const std::vector<index_t> &weight_shape, const std::vector<index_t> &weight_shape,
...@@ -38,12 +38,12 @@ void Simple(const std::vector<index_t> &input_shape, ...@@ -38,12 +38,12 @@ void Simple(const std::vector<index_t> &input_shape,
kernels::BufferType::ARGUMENT); kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest") OpDefBuilder("FC", "FullyConnectedTest")
.Input("InputImage") .Input("InputImage")
.Input("WeightImage") .Input("WeightImage")
.Input("BiasImage") .Input("BiasImage")
.Output("OutputImage") .Output("OutputImage")
.AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT) .AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
...@@ -52,11 +52,11 @@ void Simple(const std::vector<index_t> &input_shape, ...@@ -52,11 +52,11 @@ void Simple(const std::vector<index_t> &input_shape,
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
} else { } else {
OpDefBuilder("FC", "FullyConnectedTest") OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input") .Input("Input")
.Input("Weight") .Input("Weight")
.Input("Bias") .Input("Bias")
.Output("Output") .Output("Output")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
} }
...@@ -72,14 +72,14 @@ TEST_F(FullyConnectedOpTest, SimpleCPU) { ...@@ -72,14 +72,14 @@ TEST_F(FullyConnectedOpTest, SimpleCPU) {
{1, 2, 3, 4, 5, 6, 7, 8}, {1}, {2}, {1, 1, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8}, {1}, {2}, {1, 1, 1, 1},
{206}); {206});
Simple<DeviceType::CPU>( Simple<DeviceType::CPU>(
{1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 10}, {1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
{2}, {2, 3}, {1, 1, 1, 2}, {387, 3853}); {2}, {2, 3}, {1, 1, 1, 2}, {387, 3853});
Simple<DeviceType::CPU>( Simple<DeviceType::CPU>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 6}, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 6},
{1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, {1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3,
4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6}, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6},
{5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, {92, 912, 94, 914, 96}); {5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, {92, 912, 94, 914, 96});
} }
TEST_F(FullyConnectedOpTest, SimpleCPUWithBatch) { TEST_F(FullyConnectedOpTest, SimpleCPUWithBatch) {
...@@ -92,14 +92,14 @@ TEST_F(FullyConnectedOpTest, SimpleOPENCL) { ...@@ -92,14 +92,14 @@ TEST_F(FullyConnectedOpTest, SimpleOPENCL) {
{1, 2, 3, 4, 5, 6, 7, 8}, {1}, {2}, {1, 1, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8}, {1}, {2}, {1, 1, 1, 1},
{206}); {206});
Simple<DeviceType::OPENCL>( Simple<DeviceType::OPENCL>(
{1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 10}, {1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
{2}, {2, 3}, {1, 1, 1, 2}, {387, 3853}); {2}, {2, 3}, {1, 1, 1, 2}, {387, 3853});
Simple<DeviceType::OPENCL>( Simple<DeviceType::OPENCL>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 6}, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 6},
{1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, {1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3,
4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6}, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6},
{5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, {92, 912, 94, 914, 96}); {5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, {92, 912, 94, 914, 96});
} }
TEST_F(FullyConnectedOpTest, SimpleGPUWithBatch) { TEST_F(FullyConnectedOpTest, SimpleGPUWithBatch) {
...@@ -107,7 +107,7 @@ TEST_F(FullyConnectedOpTest, SimpleGPUWithBatch) { ...@@ -107,7 +107,7 @@ TEST_F(FullyConnectedOpTest, SimpleGPUWithBatch) {
{1, 2, 3, 4}, {1}, {2}, {2, 1, 1, 1}, {32, 72}); {1, 2, 3, 4}, {1}, {2}, {2, 1, 1, 1}, {32, 72});
} }
template <typename T> template<typename T>
void Complex(const index_t batch, void Complex(const index_t batch,
const index_t height, const index_t height,
const index_t width, const index_t width,
...@@ -118,17 +118,17 @@ void Complex(const index_t batch, ...@@ -118,17 +118,17 @@ void Complex(const index_t batch,
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("FC", "FullyConnectedTest") OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input") .Input("Input")
.Input("Weight") .Input("Weight")
.Input("Bias") .Input("Bias")
.Output("Output") .Output("Output")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<DeviceType::OPENCL, float>( net.AddRandomInput<DeviceType::OPENCL, float>(
"Input", {batch, height, width, channels}); "Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>( net.AddRandomInput<DeviceType::OPENCL, float>(
"Weight", {out_channel, height * width * channels}); "Weight", {out_channel, height * width * channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {out_channel}); net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {out_channel});
// run cpu // run cpu
...@@ -147,13 +147,13 @@ void Complex(const index_t batch, ...@@ -147,13 +147,13 @@ void Complex(const index_t batch,
kernels::BufferType::ARGUMENT); kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest") OpDefBuilder("FC", "FullyConnectedTest")
.Input("InputImage") .Input("InputImage")
.Input("WeightImage") .Input("WeightImage")
.Input("BiasImage") .Input("BiasImage")
.Output("OutputImage") .Output("OutputImage")
.AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT) .AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run on opencl // Run on opencl
net.RunOp(DeviceType::OPENCL); net.RunOp(DeviceType::OPENCL);
...@@ -189,7 +189,7 @@ TEST_F(FullyConnectedOpTest, OPENCLHalfUnAlignedWithBatch) { ...@@ -189,7 +189,7 @@ TEST_F(FullyConnectedOpTest, OPENCLHalfUnAlignedWithBatch) {
Complex<half>(31, 21, 11, 23, 103); Complex<half>(31, 21, 11, 23, 103);
} }
template <typename T> template<typename T>
void TestWXFormat(const index_t batch, void TestWXFormat(const index_t batch,
const index_t height, const index_t height,
const index_t width, const index_t width,
...@@ -200,17 +200,17 @@ void TestWXFormat(const index_t batch, ...@@ -200,17 +200,17 @@ void TestWXFormat(const index_t batch,
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("FC", "FullyConnectedTest") OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input") .Input("Input")
.Input("Weight") .Input("Weight")
.Input("Bias") .Input("Bias")
.Output("Output") .Output("Output")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<DeviceType::OPENCL, float>( net.AddRandomInput<DeviceType::OPENCL, float>(
"Input", {batch, height, width, channels}); "Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>( net.AddRandomInput<DeviceType::OPENCL, float>(
"Weight", {out_channel, height * width * channels}); "Weight", {out_channel, height * width * channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {out_channel}); net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {out_channel});
// run cpu // run cpu
...@@ -229,12 +229,12 @@ void TestWXFormat(const index_t batch, ...@@ -229,12 +229,12 @@ void TestWXFormat(const index_t batch,
kernels::BufferType::ARGUMENT); kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest") OpDefBuilder("FC", "FullyConnectedTest")
.Input("InputImage") .Input("InputImage")
.Input("WeightImage") .Input("WeightImage")
.Input("BiasImage") .Input("BiasImage")
.Output("OutputImage") .Output("OutputImage")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(DeviceType::OPENCL); net.RunOp(DeviceType::OPENCL);
...@@ -266,6 +266,57 @@ TEST_F(FullyConnectedOpTest, OPENCLHalfWidthFormatAligned) { ...@@ -266,6 +266,57 @@ TEST_F(FullyConnectedOpTest, OPENCLHalfWidthFormatAligned) {
TestWXFormat<half>(1, 16, 32, 32, 32); TestWXFormat<half>(1, 16, 32, 32, 32);
} }
// Cross-checks the NEON FC kernel against the CPU reference implementation
// on random input of the given shape. The CPU result (NHWC) is transposed
// to NCHW before comparison, since the NEON kernel emits NCHW-shaped
// output ({batch, out_channel, 1, 1}).
void FullyConnectedTestNEON(const index_t batch,
const index_t height,
const index_t width,
const index_t channels,
const index_t out_channel) {
srand(time(NULL));
// Construct graph
OpsTestNet net;
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::CPU, float>(
"Weight", {out_channel, height * width * channels});
net.AddRandomInput<DeviceType::CPU, float>("Bias", {out_channel});
// Run the CPU reference implementation first.
net.RunOp();
// Build a second op over the same inputs targeting NEON.
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("OutputNeon")
.Finalize(net.NewOperatorDef());
// Run on device
net.RunOp(DeviceType::NEON);
// Transpose the CPU NHWC output into NCHW so the layouts match.
// (Fixed tensor-name typo: "OutputExptected" -> "OutputExpected".)
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("OutputExpected",
"Output");
ExpectTensorNear<float>(*net.GetOutput("OutputExpected"),
*net.GetOutput("OutputNeon"),
0.001);
}
// NEON-vs-CPU parity checks covering a typical conv-feature shape, a
// wider channel count, and a large flat (1x1 spatial) input.
TEST_F(FullyConnectedOpTest, TestNEON) {
FullyConnectedTestNEON(1, 7, 7, 32, 16);
FullyConnectedTestNEON(1, 7, 7, 512, 128);
FullyConnectedTestNEON(1, 1, 1, 2048, 1024);
}
} // namespace test } // namespace test
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册