提交 cf7b0ec7 编写于 作者: L Liangliang He

Merge branch 'fused-bn' into 'master'

Add folded batch norm to combine batchnorm and relu.

See merge request !199
......@@ -76,6 +76,7 @@ extern void Register_Relu(OperatorRegistry *op_registry);
extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
extern void Register_Softmax(OperatorRegistry *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
OperatorRegistry::OperatorRegistry() {
Register_AddN(this);
......@@ -95,6 +96,7 @@ OperatorRegistry::OperatorRegistry() {
Register_ResizeBilinear(this);
Register_SpaceToBatchND(this);
Register_Softmax(this);
Register_FoldedBatchNorm(this);
}
} // namespace mace
......@@ -12,15 +12,26 @@
namespace mace {
namespace kernels {
struct BatchNormFunctorBase {
BatchNormFunctorBase(bool folded_constant, bool fused_relu) :
folded_constant_(folded_constant),
fused_relu_(fused_relu){}
const bool folded_constant_;
const bool fused_relu_;
};
template <DeviceType D, typename T>
struct BatchNormFunctor {
float epsilon_;
struct BatchNormFunctor : BatchNormFunctorBase{
BatchNormFunctor(const bool folded_constant, const bool fused_relu) :
BatchNormFunctorBase(folded_constant, fused_relu) {}
void operator()(const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future) {
// Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
......@@ -39,25 +50,28 @@ struct BatchNormFunctor {
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard scale_mapper(scale);
Tensor::MappingGuard offset_mapper(offset);
Tensor::MappingGuard mean_mapper(mean);
Tensor::MappingGuard var_mapper(var);
Tensor::MappingGuard output_mapper(output);
const T *input_ptr = input->data<T>();
const T *scale_ptr = scale->data<T>();
const T *offset_ptr = offset->data<T>();
const T *mean_ptr = mean->data<T>();
const T *var_ptr = var->data<T>();
T *output_ptr = output->mutable_data<T>();
vector<T> new_scale(channels);
vector<T> new_offset(channels);
vector<T> new_scale;
vector<T> new_offset;
if (!folded_constant_) {
new_scale.resize(channels);
new_offset.resize(channels);
Tensor::MappingGuard mean_mapper(mean);
Tensor::MappingGuard var_mapper(var);
const T *mean_ptr = mean->data<T>();
const T *var_ptr = var->data<T>();
#pragma omp parallel for
for (index_t c = 0; c < channels; ++c) {
new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon_);
new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon);
new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c];
}
}
index_t pos = 0;
......@@ -66,7 +80,14 @@ struct BatchNormFunctor {
for (index_t h = 0; h < height; ++h) {
for (index_t w = 0; w < width; ++w) {
for (index_t c = 0; c < channels; ++c) {
if (folded_constant_) {
output_ptr[pos] = scale_ptr[c] * input_ptr[pos] + offset_ptr[c];
} else {
output_ptr[pos] = new_scale[c] * input_ptr[pos] + new_offset[c];
}
if (fused_relu_) {
output_ptr[pos] = std::max(output_ptr[pos], static_cast<T>(0));
}
++pos;
}
}
......@@ -82,18 +103,20 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future);
template <typename T>
struct BatchNormFunctor<DeviceType::OPENCL, T> {
float epsilon_;
struct BatchNormFunctor<DeviceType::OPENCL, T> : BatchNormFunctorBase {
BatchNormFunctor(const bool folded_constant, const bool fused_relu) :
BatchNormFunctorBase(folded_constant, fused_relu) {}
void operator()(const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future);
};
......
......@@ -19,8 +19,11 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future) {
MACE_CHECK(folded_constant_ || (mean != nullptr && var != nullptr));
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
......@@ -33,15 +36,23 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
if (folded_constant_) {
built_options.emplace("-DFOLDED_CONSTANT");
}
if (fused_relu_) {
built_options.emplace("-DFUSED_RELU");
}
auto bm_kernel = runtime->BuildKernel("batch_norm", "batch_norm", built_options);
uint32_t idx = 0;
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(scale->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(offset->buffer())));
if (!folded_constant_) {
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(mean->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(var->buffer())));
bm_kernel.setArg(idx++, epsilon_);
bm_kernel.setArg(idx++, epsilon);
}
bm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
......@@ -89,7 +100,8 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
<< output->dim(0) << "_"
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
<< output->dim(3) << "_"
<< folded_constant_;
OpenCLProfilingTimer timer(&event);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
......
......@@ -3,27 +3,39 @@
__kernel void batch_norm(__read_only image2d_t input,
__read_only image2d_t scale,
__read_only image2d_t offset,
#ifndef FOLDED_CONSTANT
__read_only image2d_t mean,
__read_only image2d_t var,
__private const float epsilon,
#endif
__write_only image2d_t output) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
const int width = get_global_size(1);
#ifdef FOLDED_CONSTANT
DATA_TYPE4 bn_scale = READ_IMAGET(scale, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 bn_offset = READ_IMAGET(offset, SAMPLER, (int2)(ch_blk, 0));
#else
DATA_TYPE4 scale_value = READ_IMAGET(scale, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 offset_value = READ_IMAGET(offset, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 mean_value = READ_IMAGET(mean, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 var_value = READ_IMAGET(var, SAMPLER, (int2)(ch_blk, 0));
// native_rsqrt seems not faster than rsqrt
DATA_TYPE4 new_scale = scale_value * rsqrt(var_value + (DATA_TYPE4)epsilon);
DATA_TYPE4 new_offset = mad(0 - mean_value, new_scale, offset_value);
DATA_TYPE4 bn_scale = scale_value * rsqrt(var_value + (DATA_TYPE4)epsilon);
DATA_TYPE4 bn_offset = mad(0 - mean_value, bn_scale, offset_value);
#endif
const int pos = mad24(ch_blk, width, w);
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
DATA_TYPE4 out = mad(in, new_scale, new_offset);
DATA_TYPE4 out = mad(in, bn_scale, bn_offset);
#ifdef FUSED_RELU
out = fmax(out, 0);
#endif
WRITE_IMAGET(output, (int2)(pos, hb), out);
}
......@@ -2,8 +2,8 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_BATCH_NORM_H_
#define MACE_BATCH_NORM_H_
#ifndef MACE_OPS_BATCH_NORM_H_
#define MACE_OPS_BATCH_NORM_H_
#include "mace/core/operator.h"
#include "mace/kernels/batch_norm.h"
......@@ -14,8 +14,8 @@ template <DeviceType D, class T>
class BatchNormOp : public Operator<D, T> {
public:
BatchNormOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), functor_() {
functor_.epsilon_ =
: Operator<D, T>(operator_def, ws), functor_(false, false) {
epsilon_ =
OperatorBase::GetSingleArgument<float>("epsilon", static_cast<float>(1e-4));
}
......@@ -40,11 +40,12 @@ class BatchNormOp : public Operator<D, T> {
Tensor *output = this->Output(OUTPUT);
output->ResizeLike(input);
functor_(input, scale, offset, mean, var, output, future);
functor_(input, scale, offset, mean, var, epsilon_, output, future);
return true;
}
private:
float epsilon_;
kernels::BatchNormFunctor<D, T> functor_;
protected:
......@@ -54,4 +55,4 @@ class BatchNormOp : public Operator<D, T> {
} // namespace mace
#endif // MACE_BATCH_NORM_H_
#endif // MACE_OPS_BATCH_NORM_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/folded_batch_norm.h"
namespace mace {
void Register_FoldedBatchNorm(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
FoldedBatchNormOp<DeviceType::CPU, float>);
#if MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
FoldedBatchNormOp<DeviceType::NEON, float>);
#endif // MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
FoldedBatchNormOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
FoldedBatchNormOp<DeviceType::OPENCL, half>);
}
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_FOLDED_BATCH_NORM_H_
#define MACE_OPS_FOLDED_BATCH_NORM_H_
#include "mace/core/operator.h"
#include "mace/kernels/batch_norm.h"
namespace mace {
template <DeviceType D, class T>
class FoldedBatchNormOp : public Operator<D, T> {
public:
FoldedBatchNormOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(true, OperatorBase::GetSingleArgument<bool>("fused_relu", false)) {
}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *scale = this->Input(SCALE);
const Tensor *offset = this->Input(OFFSET);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
input->dim_size());
MACE_CHECK(scale->dim_size() == 1, "scale must be 1-dimensional. ",
scale->dim_size());
MACE_CHECK(offset->dim_size() == 1, "offset must be 1-dimensional. ",
offset->dim_size());
Tensor *output = this->Output(OUTPUT);
output->ResizeLike(input);
functor_(input, scale, offset, nullptr, nullptr, 0, output, future);
return true;
}
private:
kernels::BatchNormFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT, SCALE, OFFSET, MEAN, VAR);
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace mace
#endif // MACE_OPS_FOLDED_BATCH_NORM_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
class FoldedBatchNormOpTest : public OpsTestBase {};
void CalculateScaleOffset(const std::vector<float> &gamma,
const std::vector<float> &beta,
const std::vector<float> &mean,
const std::vector<float> &var,
const float epsilon,
std::vector<float> &scale,
std::vector<float> &offset) {
size_t size = gamma.size();
for (int i = 0 ; i < size; ++i) {
scale[i] = gamma[i] / std::sqrt(var[i] + epsilon);
offset[i] = offset[i] - mean[i] * scale[i];
}
}
template <DeviceType D>
void Simple() {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", {1, 6, 2, 1},
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
std::vector<float> scale(1);
std::vector<float> offset(1);
CalculateScaleOffset({4.0f}, {2.0}, {10}, {11.67f}, 1e-3, scale, offset);
net.AddInputFromArray<D, float>("Scale", {1}, scale);
net.AddInputFromArray<D, float>("Offset", {1}, offset);
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
BufferToImage<D, float>(net, "Scale", "ScaleImage",
kernels::BufferType::ARGUMENT);
BufferToImage<D, float>(net, "Offset", "OffsetImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
.Input("InputImage")
.Input("ScaleImage")
.Input("OffsetImage")
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
} else {
OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
}
// Check
auto expected =
CreateTensor<float>({1, 6, 2, 1}, {-3.86, -3.86, -1.51, -1.51, 0.83, 0.83,
3.17, 3.17, 5.51, 5.51, 7.86, 7.86});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-2);
}
TEST_F(FoldedBatchNormOpTest, SimpleCPU) { Simple<DeviceType::CPU>(); }
/*
TEST_F(FoldedBatchNormOpTest, SimpleNEON) {
Simple<DeviceType::NEON>();
}
*/
TEST_F(FoldedBatchNormOpTest, SimpleOPENCL) { Simple<DeviceType::OPENCL>(); }
/*
TEST_F(FoldedBatchNormOpTest, SimpleRandomNeon) {
srand(time(NULL));
// generate random input
index_t batch = 1 + rand() % 10;
index_t channels = 3 + rand() % 50;
index_t height = 64;
index_t width = 64;
// Construct graph
OpsTestNet net;
OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.Input("Epsilon")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input", {batch, channels, height,
width});
net.AddRandomInput<DeviceType::CPU, float>("Scale", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Offset", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Mean", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Var", {channels}, true);
net.AddInputFromArray<DeviceType::CPU, float>("Epsilon", {}, {1e-3});
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run NEON
net.RunOp(DeviceType::NEON);
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-2);
}
TEST_F(FoldedBatchNormOpTest, ComplexRandomNeon) {
srand(time(NULL));
// generate random input
index_t batch = 1 + rand() % 10;
index_t channels = 3 + rand() % 50;
index_t height = 103;
index_t width = 113;
// Construct graph
OpsTestNet net;
OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.Input("Epsilon")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input", {batch, channels, height,
width});
net.AddRandomInput<DeviceType::CPU, float>("Scale", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Offset", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Mean", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Var", {channels}, true);
net.AddInputFromArray<DeviceType::CPU, float>("Epsilon", {}, {1e-3});
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run NEON
net.RunOp(DeviceType::NEON);
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-2);
}
*/
TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) {
srand(time(NULL));
// generate random input
index_t batch = 1 + rand() % 10;
index_t channels = 3 + rand() % 50;
index_t height = 64;
index_t width = 64;
// Construct graph
OpsTestNet net;
OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Scale", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Offset", {channels});
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run on opencl
BufferToImage<DeviceType::OPENCL, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
BufferToImage<DeviceType::OPENCL, float>(net, "Scale", "ScaleImage",
kernels::BufferType::ARGUMENT);
BufferToImage<DeviceType::OPENCL, float>(net, "Offset", "OffsetImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
.Input("InputImage")
.Input("ScaleImage")
.Input("OffsetImage")
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
// Run on opencl
net.RunOp(DeviceType::OPENCL);
net.Sync();
ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "OPENCLOutput",
kernels::BufferType::IN_OUT);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-2);
}
TEST_F(FoldedBatchNormOpTest, SimpleRandomHalfOPENCL) {
srand(time(NULL));
// generate random input
index_t batch = 1 + rand() % 10;
index_t channels = 3 + rand() % 50;
index_t height = 64;
index_t width = 64;
// Construct graph
OpsTestNet net;
OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Scale", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Offset", {channels});
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run on opencl
BufferToImage<DeviceType::OPENCL, half>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
BufferToImage<DeviceType::OPENCL, half>(net, "Scale", "ScaleImage",
kernels::BufferType::ARGUMENT);
BufferToImage<DeviceType::OPENCL, half>(net, "Offset", "OffsetImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
.Input("InputImage")
.Input("ScaleImage")
.Input("OffsetImage")
.Output("OutputImage")
.AddIntArg("T", static_cast<int>(DataType::DT_HALF))
.Finalize(net.NewOperatorDef());
// Run on opencl
net.RunOp(DeviceType::OPENCL);
net.Sync();
ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "OPENCLOutput",
kernels::BufferType::IN_OUT);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.5);
}
TEST_F(FoldedBatchNormOpTest, ComplexRandomOPENCL) {
srand(time(NULL));
// generate random input
index_t batch = 1 + rand() % 10;
index_t channels = 3 + rand() % 50;
index_t height = 103;
index_t width = 113;
// Construct graph
OpsTestNet net;
OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Scale", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Offset", {channels});
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run on opencl
BufferToImage<DeviceType::OPENCL, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
BufferToImage<DeviceType::OPENCL, float>(net, "Scale", "ScaleImage",
kernels::BufferType::ARGUMENT);
BufferToImage<DeviceType::OPENCL, float>(net, "Offset", "OffsetImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
.Input("InputImage")
.Input("ScaleImage")
.Input("OffsetImage")
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
// Run on opencl
net.RunOp(DeviceType::OPENCL);
ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "OPENCLOutput",
kernels::BufferType::IN_OUT);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 1e-2);
}
TEST_F(FoldedBatchNormOpTest, ComplexRandomHalfOPENCL) {
srand(time(NULL));
// generate random input
index_t batch = 1 + rand() % 10;
index_t channels = 3 + rand() % 50;
index_t height = 103;
index_t width = 113;
// Construct graph
OpsTestNet net;
OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::OPENCL, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Scale", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Offset", {channels});
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run on opencl
BufferToImage<DeviceType::OPENCL, half>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
BufferToImage<DeviceType::OPENCL, half>(net, "Scale", "ScaleImage",
kernels::BufferType::ARGUMENT);
BufferToImage<DeviceType::OPENCL, half>(net, "Offset", "OffsetImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
.Input("InputImage")
.Input("ScaleImage")
.Input("OffsetImage")
.Output("OutputImage")
.AddIntArg("T", static_cast<int>(DataType::DT_HALF))
.Finalize(net.NewOperatorDef());
// Run on opencl
net.RunOp(DeviceType::OPENCL);
ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "OPENCLOutput",
kernels::BufferType::IN_OUT);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.5);
}
}
......@@ -39,6 +39,13 @@ def main(unused_args):
f.write(str(output_graph_def))
print("Model conversion is completed.")
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def parse_args():
"""Parses command line arguments."""
......@@ -91,7 +98,9 @@ def parse_args():
help="template path")
parser.add_argument(
"--confuse",
type=bool,
type=str2bool,
nargs='?',
const=False,
default=False,
help="confuse model names")
parser.add_argument(
......
from mace.proto import mace_pb2
import tensorflow as tf
import numpy as np
import math
from mace.python.tools import memory_optimizer
# TODO: support NCHW formt, now only support NHWC.
......@@ -136,6 +137,22 @@ class TFConverter(object):
output_shapes.append(output_shape)
op.output_shape.extend(output_shapes)
def add_tensor(self, name, shape, tf_dt, value):
tensor = self.net_def.tensors.add()
tensor.name = name
shape = list(shape)
tensor.dims.extend(shape)
if tf_dt == tf.float32:
tensor.data_type = mace_pb2.DT_FLOAT
tensor.float_data.extend(value.flat)
elif tf_dt == tf.int32:
tensor.data_type = mace_pb2.DT_INT32
tensor.int32_data.extend(value.flat)
else:
raise Exception("Not supported tensor type: " + tf_dt.name)
def convert_tensor(self, op):
if op.outputs[0].name not in self.unused_tensor:
tensor = self.net_def.tensors.add()
......@@ -211,26 +228,58 @@ class TFConverter(object):
arg = op_def.arg.add()
arg.name = 'T'
arg.i = self.dt
data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format'
data_format_arg.s = 'NHWC'
op_def.name = op.name
op_def.type = 'BatchNorm'
op_def.type = 'FoldedBatchNorm'
gamma_tensor = get_input_tensor(op, 1)
for i in range(1, 5):
input_tensor = get_input_tensor(op, i)
assert input_tensor.shape == gamma_tensor.shape
self.unused_tensor.add(input_tensor.name)
gamma_value = get_input_tensor(op, 1).eval().astype(np.float32)
beta_value = get_input_tensor(op, 2).eval().astype(np.float32)
mean_value = get_input_tensor(op, 3).eval().astype(np.float32)
var_value = get_input_tensor(op, 4).eval().astype(np.float32)
epsilon_value = op.get_attr('epsilon')
scale_value = (
(1.0 / np.vectorize(math.sqrt)(var_value + epsilon_value)) *
gamma_value)
offset_value = (-mean_value * scale_value) + beta_value
idx = gamma_tensor.name.rfind('/')
name_prefix = gamma_tensor.name[:idx] + '/'
input_names = [name_prefix+'scale:0', name_prefix+'offset:0']
self.add_tensor(input_names[0], gamma_value.shape,
gamma_tensor.dtype, scale_value)
self.add_tensor(input_names[1], gamma_value.shape,
gamma_tensor.dtype, offset_value)
if self.device == 'gpu':
op_def.input.extend([op.inputs[0].name])
for i in range(1, len(op.inputs)):
output_name = self.add_buffer_to_image(op.inputs[i].name, "ARGUMENT")
for name in input_names:
output_name = self.add_buffer_to_image(name, "ARGUMENT")
op_def.input.extend([output_name])
else:
op_def.input.extend([input.name for input in op.inputs])
op_def.output.extend([op.outputs[0].name])
self.add_output_shape(op.outputs, op_def)
op_def.input.extend([input.name for input in input_names])
epsilon_arg = op_def.arg.add()
epsilon_arg.name = 'epsilon'
epsilon_arg.f = op.get_attr('epsilon')
data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format'
data_format_arg.s = 'NHWC'
self.resolved_ops[op.name] = 1
final_op = op
if len(self.tf_graph[op.name]) == 1 and self.tf_graph[op.name][0].type == 'Relu':
relu_op = self.tf_graph[op.name][0]
final_op = relu_op
fused_relu_arg = op_def.arg.add()
fused_relu_arg.name = 'fused_relu'
fused_relu_arg.i = 1
self.resolved_ops[relu_op.name] = 1
op_def.output.extend([final_op.outputs[0].name])
self.add_output_shape(final_op.outputs, op_def)
self.net_def.op.extend([op_def])
def convert_batchnorm(self, op):
......
......@@ -96,7 +96,7 @@ bazel-bin/mace/python/tools/tf_converter --input=${TF_MODEL_FILE_PATH} \
--output_type=source \
--template=${MACE_SOURCE_DIR}/mace/python/tools/model.template \
--model_tag=${MODEL_TAG} \
--confuse=False || exit -1
--confuse=True || exit -1
echo "Step 3: Generate version source"
rm -rf ${VERSION_SOURCE_PATH}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册