提交 dd4b992a 编写于 作者: L Liangliang He

Merge branch 'master' into 'master'

Implement fc and eltwise for neon

See merge request !347
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include "mace/kernels/fully_connected.h"
#include "mace/kernels/gemm.h"
namespace mace {
namespace kernels {
void FullyConnectedFunctor<DeviceType::NEON,
float>::operator()(const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
std::vector<index_t> output_shape = {input->dim(0), weight->dim(0), 1, 1};
output->Resize(output_shape);
const index_t N = output->dim(0);
const index_t input_size = weight->dim(1);
const index_t output_size = weight->dim(0);
const float *input_ptr = input->data<float>();
const float *weight_ptr = weight->data<float>();
const float *bias_ptr = bias == nullptr ? nullptr : bias->data<float>();
float *output_ptr = output->mutable_data<float>();
for (int i = 0; i < N; ++i) {
Gemm(weight_ptr, input_ptr, 1, output_size, input_size, 1, output_ptr);
for (int j = 0; j < output_size; ++j) {
output_ptr[j] += bias_ptr[j];
}
}
DoActivation(output_ptr, output_ptr, output->size(), activation_,
relux_max_limit_);
}
} // namespace kernels
} // namespace mace
......@@ -76,6 +76,20 @@ struct FullyConnectedFunctor : FullyConnectedBase {
}
};
template <>
struct FullyConnectedFunctor<DeviceType::NEON, float> : FullyConnectedBase {
FullyConnectedFunctor(const BufferType weight_type,
const ActivationType activation,
const float relux_max_limit)
: FullyConnectedBase(weight_type, activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
};
template <typename T>
struct FullyConnectedFunctor<DeviceType::OPENCL, T> : FullyConnectedBase {
FullyConnectedFunctor(const BufferType weight_type,
......
......@@ -25,6 +25,11 @@ void Register_Eltwise(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
EltwiseOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
EltwiseOp<DeviceType::NEON, float>);
}
} // namespace ops
......
......@@ -25,6 +25,12 @@ void Register_FullyConnected(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
FullyConnectedOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FC")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
FullyConnectedOp<DeviceType::NEON, float>);
}
} // namespace ops
......
......@@ -13,7 +13,7 @@ namespace test {
class FullyConnectedOpTest : public OpsTestBase {};
template <DeviceType D>
template<DeviceType D>
void Simple(const std::vector<index_t> &input_shape,
const std::vector<float> &input_value,
const std::vector<index_t> &weight_shape,
......@@ -38,12 +38,12 @@ void Simple(const std::vector<index_t> &input_shape,
kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest")
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT)
.Finalize(net.NewOperatorDef());
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
......@@ -52,11 +52,11 @@ void Simple(const std::vector<index_t> &input_shape,
kernels::BufferType::IN_OUT_CHANNEL);
} else {
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
}
......@@ -72,14 +72,14 @@ TEST_F(FullyConnectedOpTest, SimpleCPU) {
{1, 2, 3, 4, 5, 6, 7, 8}, {1}, {2}, {1, 1, 1, 1},
{206});
Simple<DeviceType::CPU>(
{1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
{2}, {2, 3}, {1, 1, 1, 2}, {387, 3853});
{1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
{2}, {2, 3}, {1, 1, 1, 2}, {387, 3853});
Simple<DeviceType::CPU>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 6},
{1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3,
4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6},
{5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, {92, 912, 94, 914, 96});
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 6},
{1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3,
4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6},
{5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, {92, 912, 94, 914, 96});
}
TEST_F(FullyConnectedOpTest, SimpleCPUWithBatch) {
......@@ -92,14 +92,14 @@ TEST_F(FullyConnectedOpTest, SimpleOPENCL) {
{1, 2, 3, 4, 5, 6, 7, 8}, {1}, {2}, {1, 1, 1, 1},
{206});
Simple<DeviceType::OPENCL>(
{1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
{2}, {2, 3}, {1, 1, 1, 2}, {387, 3853});
{1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
{2}, {2, 3}, {1, 1, 1, 2}, {387, 3853});
Simple<DeviceType::OPENCL>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 6},
{1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3,
4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6},
{5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, {92, 912, 94, 914, 96});
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 6},
{1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3,
4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6},
{5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, {92, 912, 94, 914, 96});
}
TEST_F(FullyConnectedOpTest, SimpleGPUWithBatch) {
......@@ -107,7 +107,7 @@ TEST_F(FullyConnectedOpTest, SimpleGPUWithBatch) {
{1, 2, 3, 4}, {1}, {2}, {2, 1, 1, 1}, {32, 72});
}
template <typename T>
template<typename T>
void Complex(const index_t batch,
const index_t height,
const index_t width,
......@@ -118,17 +118,17 @@ void Complex(const index_t batch,
// Construct graph
OpsTestNet net;
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
"Input", {batch, height, width, channels});
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>(
"Weight", {out_channel, height * width * channels});
"Weight", {out_channel, height * width * channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {out_channel});
// run cpu
......@@ -147,13 +147,13 @@ void Complex(const index_t batch,
kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest")
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run on opencl
net.RunOp(DeviceType::OPENCL);
......@@ -189,7 +189,7 @@ TEST_F(FullyConnectedOpTest, OPENCLHalfUnAlignedWithBatch) {
Complex<half>(31, 21, 11, 23, 103);
}
template <typename T>
template<typename T>
void TestWXFormat(const index_t batch,
const index_t height,
const index_t width,
......@@ -200,17 +200,17 @@ void TestWXFormat(const index_t batch,
// Construct graph
OpsTestNet net;
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
"Input", {batch, height, width, channels});
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>(
"Weight", {out_channel, height * width * channels});
"Weight", {out_channel, height * width * channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Bias", {out_channel});
// run cpu
......@@ -229,12 +229,12 @@ void TestWXFormat(const index_t batch,
kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest")
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(DeviceType::OPENCL);
......@@ -266,6 +266,57 @@ TEST_F(FullyConnectedOpTest, OPENCLHalfWidthFormatAligned) {
TestWXFormat<half>(1, 16, 32, 32, 32);
}
void FullyConnectedTestNEON(const index_t batch,
const index_t height,
const index_t width,
const index_t channels,
const index_t out_channel) {
srand(time(NULL));
// Construct graph
OpsTestNet net;
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::CPU, float>(
"Weight", {out_channel, height * width * channels});
net.AddRandomInput<DeviceType::CPU, float>("Bias", {out_channel});
// run cpu
net.RunOp();
// Run on neon
OpDefBuilder("FC", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("OutputNeon")
.Finalize(net.NewOperatorDef());
// Run on device
net.RunOp(DeviceType::NEON);
net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("OutputExptected",
"Output");
ExpectTensorNear<float>(*net.GetOutput("OutputExptected"),
*net.GetOutput("OutputNeon"),
0.001);
}
TEST_F(FullyConnectedOpTest, TestNEON) {
FullyConnectedTestNEON(1, 7, 7, 32, 16);
FullyConnectedTestNEON(1, 7, 7, 512, 128);
FullyConnectedTestNEON(1, 1, 1, 2048, 1024);
}
} // namespace test
} // namespace ops
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册