提交 963de04e 编写于 作者: 李寅

Merge branch 'add_op_lrn' into 'master'

add op : local response norm

See merge request !394
......@@ -92,6 +92,7 @@ extern void Register_FullyConnected(OperatorRegistry *op_registry);
extern void Register_FusedConv2D(OperatorRegistry *op_registry);
extern void Register_GlobalAvgPooling(OperatorRegistry *op_registry);
extern void Register_ImageToBuffer(OperatorRegistry *op_registry);
extern void Register_LocalResponseNorm(OperatorRegistry *op_registry);
extern void Register_MatMul(OperatorRegistry *op_registry);
extern void Register_Pad(OperatorRegistry *op_registry);
extern void Register_Pooling(OperatorRegistry *op_registry);
......@@ -129,6 +130,7 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_FusedConv2D(this);
ops::Register_GlobalAvgPooling(this);
ops::Register_ImageToBuffer(this);
ops::Register_LocalResponseNorm(this);
ops::Register_MatMul(this);
ops::Register_Pad(this);
ops::Register_Pooling(this);
......
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include "mace/kernels/local_response_norm.h"
namespace mace {
namespace kernels {
// Local response normalization across channels for a 4-D float tensor whose
// dimensions are read as (batch, channels, height, width).
//
// Every element x located in channel c is scaled to
//   x * (bias + alpha * sum(v * v)) ^ (-beta)
// where the sum runs over the values at the same batch/spatial position in
// channels [c - depth_radius, c + depth_radius], clipped to the valid channel
// range. `output` is written element-for-element at the same offsets as
// `input`; `future` is not used by this implementation.
void LocalResponseNormFunctor<DeviceType::NEON, float>::operator()(
    const Tensor *input,
    int depth_radius,
    float bias,
    float alpha,
    float beta,
    Tensor *output,
    StatsFuture *future) {
  const index_t batch = input->dim(0);
  const index_t channels = input->dim(1);
  const index_t height = input->dim(2);
  const index_t width = input->dim(3);

  const float *src = input->data<float>();
  float *dst = output->mutable_data<float>();

  // Number of elements in one channel plane and in one whole batch item.
  const index_t plane = height * width;
  const index_t batch_stride = channels * plane;

#pragma omp parallel for collapse(2)
  for (index_t n = 0; n < batch; ++n) {
    for (index_t ch = 0; ch < channels; ++ch) {
      // Half-open channel window [window_begin, window_end) around ch.
      const int window_begin = static_cast<int>(
          std::max(static_cast<index_t>(0), ch - depth_radius));
      const int window_end =
          static_cast<int>(std::min(channels, ch + depth_radius + 1));

      const index_t batch_base = n * batch_stride;
      const index_t center_base = batch_base + ch * plane;

      for (index_t px = 0; px < plane; ++px) {
        // Sum of squares over the window at this spatial position.
        float sum_sq = 0.f;
        for (int k = window_begin; k < window_end; ++k) {
          const float v = src[batch_base + px + k * plane];
          sum_sq += v * v;
        }
        const float scale = std::pow(bias + alpha * sum_sq, -beta);
        dst[center_base + px] = src[center_base + px] * scale;
      }
    }
  }
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_LOCAL_RESPONSE_NORM_H_
#define MACE_KERNELS_LOCAL_RESPONSE_NORM_H_
#include <algorithm>
#include <cmath>
#include <memory>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"
namespace mace {
namespace kernels {
// Generic local response normalization functor.
//
// The input is a 4-D tensor whose dimensions are read as
// (batch, height, width, channels). Every element x in channel c is scaled to
//   x * (bias + alpha * sum(v * v)) ^ (-beta)
// where the sum runs over the values at the same batch/spatial position in
// channels [c - depth_radius, c + depth_radius], clipped to the valid channel
// range.
//
// Parameters:
//   input         tensor to normalize; read through a MappingGuard.
//   depth_radius  half-width of the channel window.
//   bias, alpha, beta  coefficients of the formula above.
//   output        receives the result at the same offsets as the input;
//                 the functor does not resize it.
//   future        not used by this implementation.
template <DeviceType D, typename T>
struct LocalResponseNormFunctor {
  void operator()(const Tensor *input,
                  int depth_radius,
                  float bias,
                  float alpha,
                  float beta,
                  Tensor *output,
                  StatsFuture *future) {
    const index_t batch = input->dim(0);
    const index_t height = input->dim(1);
    const index_t width = input->dim(2);
    const index_t channels = input->dim(3);

    Tensor::MappingGuard input_mapper(input);
    Tensor::MappingGuard output_mapper(output);
    const T *input_ptr = input->data<T>();
    T *output_ptr = output->mutable_data<T>();

    // Keep the position count in index_t (the type of the dimensions) so the
    // product is not narrowed to int before it is used as a loop bound.
    const index_t elements = batch * height * width;

#pragma omp parallel for collapse(2)
    for (index_t i = 0; i < elements; ++i) {
      for (index_t c = 0; c < channels; ++c) {
        // Half-open channel window [begin_input_c, end_input_c) around c.
        const index_t begin_input_c =
            std::max(static_cast<index_t>(0), c - depth_radius);
        const index_t end_input_c =
            std::min(channels, c + depth_radius + 1);

        // Offset of channel 0 for this (batch, y, x) position.
        const index_t pos = i * channels;

        float accum = 0.f;
        for (index_t input_c = begin_input_c; input_c < end_input_c;
             ++input_c) {
          const float input_val = input_ptr[pos + input_c];
          accum += input_val * input_val;
        }

        const float multiplier = std::pow(bias + alpha * accum, -beta);
        output_ptr[pos + c] = input_ptr[pos + c] * multiplier;
      }
    }
  }
};
// Specialization for float tensors on the NEON device type. Only the call
// operator is declared here; it takes the same arguments as the generic
// functor's operator() and is defined out of line in a separate source file.
template <>
struct LocalResponseNormFunctor<DeviceType::NEON, float> {
// input: tensor to normalize; depth_radius: half-width of the channel window;
// bias/alpha/beta: normalization coefficients; output: destination tensor;
// future: passed through from the operator's Run().
void operator()(const Tensor *input,
int depth_radius,
float bias,
float alpha,
float beta,
Tensor *output,
StatsFuture *future);
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_LOCAL_RESPONSE_NORM_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/local_response_norm.h"
namespace mace {
namespace ops {
// Registers the "LocalResponseNorm" operator with `op_registry` for the float
// type constraint "T", once for DeviceType::CPU and once for DeviceType::NEON.
void Register_LocalResponseNorm(OperatorRegistry *op_registry) {
// CPU / float registration.
REGISTER_OPERATOR(op_registry, OpKeyBuilder("LocalResponseNorm")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
LocalResponseNormOp<DeviceType::CPU, float>);
// NEON / float registration.
REGISTER_OPERATOR(op_registry, OpKeyBuilder("LocalResponseNorm")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
LocalResponseNormOp<DeviceType::NEON, float>);
}
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_LOCAL_RESPONSE_NORM_H_
#define MACE_OPS_LOCAL_RESPONSE_NORM_H_
#include "mace/core/operator.h"
#include "mace/kernels/local_response_norm.h"
namespace mace {
namespace ops {
// Operator that applies local response normalization to its single input.
//
// Arguments read from the operator definition (with their defaults):
//   "depth_radius" (int,   5)    half-width of the channel window
//   "bias"         (float, 1.0)  additive term inside the power
//   "alpha"        (float, 1.0)  scale applied to the sum of squares
//   "beta"         (float, 0.5)  exponent
template <DeviceType D, class T>
class LocalResponseNormOp : public Operator<D, T> {
 public:
  LocalResponseNormOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        depth_radius_(
            OperatorBase::GetSingleArgument<int>("depth_radius", 5)),
        bias_(OperatorBase::GetSingleArgument<float>("bias", 1.0f)),
        alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0f)),
        beta_(OperatorBase::GetSingleArgument<float>("beta", 0.5f)),
        functor_() {}

  // Checks that the input is 4-dimensional, resizes the output to the input's
  // shape and invokes the device-specific functor. Always returns true.
  bool Run(StatsFuture *future) override {
    const Tensor *input = this->Input(INPUT);

    MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
               input->dim_size());

    Tensor *output = this->Output(OUTPUT);
    output->ResizeLike(input);

    functor_(input, depth_radius_, bias_, alpha_, beta_, output, future);
    return true;
  }

 private:
  // Normalization parameters captured at construction time.
  int depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
  // Kernel implementation selected by device and element type.
  kernels::LocalResponseNormFunctor<D, T> functor_;

 protected:
  OP_INPUT_TAGS(INPUT);
  OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_LOCAL_RESPONSE_NORM_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
// Benchmark body for the LocalResponseNorm operator on device D with element
// type T.
//
// Builds a random input with the given batch/channels/height/width, runs the
// op once with MACE_TUNING set, runs five untimed warm-up iterations, then
// times `iters` runs. The operator is created without explicit arguments, so
// it uses its default parameters.
template <DeviceType D, typename T>
static void LocalResponseNorm(
    int iters, int batch, int channels, int height, int width) {
  mace::testing::StopTiming();

  OpsTestNet net;

  // Add input data.
  // The NEON kernel reads dimensions as (batch, channels, height, width),
  // the generic kernel as (batch, height, width, channels). Feed each device
  // the layout it expects so `channels` really is the normalized axis.
  if (D == DeviceType::NEON) {
    net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
  } else {
    net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
  }

  OpDefBuilder("LocalResponseNorm", "LocalResponseNormBM")
      .Input("Input")
      .Output("Output")
      .Finalize(net.NewOperatorDef());

  // tuning
  setenv("MACE_TUNING", "1", 1);
  net.RunOp(D);
  unsetenv("MACE_TUNING");

  // Warm-up
  for (int i = 0; i < 5; ++i) {
    net.RunOp(D);
  }
  net.Sync();

  mace::testing::StartTiming();
  while (iters--) {
    net.RunOp(D);
  }
  net.Sync();
}
// Defines and registers one benchmark function named
// BM_LOCAL_RESPONSE_NORM_<N>_<C>_<H>_<W>_<TYPE>_<DEVICE>.
// The function reports iters * N * C * H * W to MaccProcessed, that count
// times sizeof(TYPE) to BytesProcessed, and then runs
// LocalResponseNorm<DEVICE, TYPE>(iters, N, C, H, W).
#define BM_LOCAL_RESPONSE_NORM_MACRO(N, C, H, W, TYPE, DEVICE) \
static void BM_LOCAL_RESPONSE_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE(\
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
LocalResponseNorm<DEVICE, TYPE>(iters, N, C, H, W); \
} \
BENCHMARK(BM_LOCAL_RESPONSE_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
// Instantiates the float benchmark for both the CPU and the NEON device for
// one (N, C, H, W) configuration.
#define BM_LOCAL_RESPONSE_NORM(N, C, H, W) \
BM_LOCAL_RESPONSE_NORM_MACRO(N, C, H, W, float, CPU); \
BM_LOCAL_RESPONSE_NORM_MACRO(N, C, H, W, float, NEON);
// Benchmarked configurations, given as (batch, channels, height, width).
BM_LOCAL_RESPONSE_NORM(1, 1, 512, 512);
BM_LOCAL_RESPONSE_NORM(1, 3, 128, 128);
BM_LOCAL_RESPONSE_NORM(1, 3, 512, 512);
BM_LOCAL_RESPONSE_NORM(1, 32, 112, 112);
BM_LOCAL_RESPONSE_NORM(1, 64, 256, 256);
BM_LOCAL_RESPONSE_NORM(1, 64, 512, 512);
BM_LOCAL_RESPONSE_NORM(1, 128, 56, 56);
BM_LOCAL_RESPONSE_NORM(1, 128, 256, 256);
BM_LOCAL_RESPONSE_NORM(1, 256, 14, 14);
BM_LOCAL_RESPONSE_NORM(1, 512, 14, 14);
BM_LOCAL_RESPONSE_NORM(1, 1024, 7, 7);
BM_LOCAL_RESPONSE_NORM(32, 1, 256, 256);
BM_LOCAL_RESPONSE_NORM(32, 3, 256, 256);
} // namespace test
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
// Test fixture for the LocalResponseNorm operator tests; adds nothing on top
// of OpsTestBase.
class LocalResponseNormOpTest : public OpsTestBase {};
template<DeviceType D>
void Simple() {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", {1, 1, 2, 6},
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
OpDefBuilder("LocalResponseNorm", "LocalResponseNormTest")
.Input("Input")
.AddIntArg("depth_radius", 5)
.AddFloatArg("bias", 1.0f)
.AddFloatArg("alpha", 1.0f)
.AddFloatArg("beta", 0.5f)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
// Check
auto expected =
CreateTensor<float>({1, 1, 2, 6}, {0.28, 0.28, 0.39, 0.39, 0.51, 0.51,
0.34, 0.34, 0.40, 0.40, 0.47, 0.47});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0, 1e-2);
}
TEST_F(LocalResponseNormOpTest, SimpleCPU) { Simple<DeviceType::CPU>(); }
// Cross-checks the NEON kernel against the default-device run on a randomly
// sized input: the first result is converted with FillNHWCInputToNCHWInput
// and compared with the NEON output using an absolute tolerance of 0.001.
TEST_F(LocalResponseNormOpTest, NEONTest) {
  // rand_r() keeps its state in `seed`, so the variable must be initialized
  // before the first call; srand() alone does not seed rand_r().
  unsigned int seed = static_cast<unsigned int>(time(NULL));
  srand(seed);

  // generate random input
  index_t batch = 1 + rand_r(&seed) % 10;
  index_t channels = 3 + rand_r(&seed) % 50;
  index_t height = 64;
  index_t width = 64;

  // Construct graph
  OpsTestNet net;
  OpDefBuilder("LocalResponseNorm", "LocalResponseNormTest")
      .Input("Input")
      .AddIntArg("depth_radius", 5)
      .AddFloatArg("bias", 1.0f)
      .AddFloatArg("alpha", 1.0f)
      .AddFloatArg("beta", 0.5f)
      .Output("Output")
      .Finalize(net.NewOperatorDef());

  // Add input data
  net.AddRandomInput<DeviceType::CPU, float>(
      "Input", {batch, height, width, channels});

  // run cpu
  net.RunOp();

  // Same operator again, this time reading and writing the NEON-side tensors.
  OpDefBuilder("LocalResponseNorm", "LocalResponseNormTest")
      .Input("InputNeon")
      .AddIntArg("depth_radius", 5)
      .AddFloatArg("bias", 1.0f)
      .AddFloatArg("alpha", 1.0f)
      .AddFloatArg("beta", 0.5f)
      .Output("OutputNeon")
      .Finalize(net.NewOperatorDef());

  net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("InputNeon", "Input");

  // Run on neon
  net.RunOp(DeviceType::NEON);
  net.Sync();

  // Bring the reference output into the layout of the NEON output.
  net.FillNHWCInputToNCHWInput<DeviceType::CPU, float>("OutputExpected",
                                                       "Output");

  ExpectTensorNear<float>(*net.GetOutput("OutputExpected"),
                          *net.GetOutput("OutputNeon"),
                          0, 0.001);
}
} // namespace test
} // namespace ops
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册