diff --git a/mace/core/operator.cc b/mace/core/operator.cc index e759d89d6ec7d187ed50b688c84c6bd4dd8f1ddb..eca09f3b102b00ea530a34230a254c24f6a103ce 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -76,6 +76,7 @@ extern void Register_Relu(OperatorRegistry *op_registry); extern void Register_ResizeBilinear(OperatorRegistry *op_registry); extern void Register_SpaceToBatchND(OperatorRegistry *op_registry); extern void Register_Softmax(OperatorRegistry *op_registry); +extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry); OperatorRegistry::OperatorRegistry() { Register_AddN(this); @@ -95,6 +96,7 @@ OperatorRegistry::OperatorRegistry() { Register_ResizeBilinear(this); Register_SpaceToBatchND(this); Register_Softmax(this); + Register_FoldedBatchNorm(this); } } // namespace mace diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h index a8cbe58ac569a0aedbacd48413bd7735a7185bca..5c62e33da27639854f4b5d70867eebe75cfd9652 100644 --- a/mace/kernels/batch_norm.h +++ b/mace/kernels/batch_norm.h @@ -12,15 +12,26 @@ namespace mace { namespace kernels { +struct BatchNormFunctorBase { + BatchNormFunctorBase(bool folded_constant, bool fused_relu) : + folded_constant_(folded_constant), + fused_relu_(fused_relu){} + + const bool folded_constant_; + const bool fused_relu_; +}; + template -struct BatchNormFunctor { - float epsilon_; +struct BatchNormFunctor : BatchNormFunctorBase{ + BatchNormFunctor(const bool folded_constant, const bool fused_relu) : + BatchNormFunctorBase(folded_constant, fused_relu) {} void operator()(const Tensor *input, const Tensor *scale, const Tensor *offset, const Tensor *mean, const Tensor *var, + const float epsilon, Tensor *output, StatsFuture *future) { // Batch normalization in the paper https://arxiv.org/abs/1502.03167 . 
@@ -39,24 +50,27 @@ struct BatchNormFunctor { Tensor::MappingGuard input_mapper(input); Tensor::MappingGuard scale_mapper(scale); Tensor::MappingGuard offset_mapper(offset); - Tensor::MappingGuard mean_mapper(mean); - Tensor::MappingGuard var_mapper(var); Tensor::MappingGuard output_mapper(output); const T *input_ptr = input->data(); const T *scale_ptr = scale->data(); const T *offset_ptr = offset->data(); - const T *mean_ptr = mean->data(); - const T *var_ptr = var->data(); T *output_ptr = output->mutable_data(); - vector new_scale(channels); - vector new_offset(channels); - + vector new_scale; + vector new_offset; + if (!folded_constant_) { + new_scale.resize(channels); + new_offset.resize(channels); + Tensor::MappingGuard mean_mapper(mean); + Tensor::MappingGuard var_mapper(var); + const T *mean_ptr = mean->data(); + const T *var_ptr = var->data(); #pragma omp parallel for - for (index_t c = 0; c < channels; ++c) { - new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon_); - new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c]; + for (index_t c = 0; c < channels; ++c) { + new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon); + new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c]; + } } index_t pos = 0; @@ -66,7 +80,14 @@ struct BatchNormFunctor { for (index_t h = 0; h < height; ++h) { for (index_t w = 0; w < width; ++w) { for (index_t c = 0; c < channels; ++c) { - output_ptr[pos] = new_scale[c] * input_ptr[pos] + new_offset[c]; + if (folded_constant_) { + output_ptr[pos] = scale_ptr[c] * input_ptr[pos] + offset_ptr[c]; + } else { + output_ptr[pos] = new_scale[c] * input_ptr[pos] + new_offset[c]; + } + if (fused_relu_) { + output_ptr[pos] = std::max(output_ptr[pos], static_cast(0)); + } ++pos; } } @@ -82,18 +103,20 @@ void BatchNormFunctor::operator()( const Tensor *offset, const Tensor *mean, const Tensor *var, + const float epsilon, Tensor *output, StatsFuture *future); template -struct BatchNormFunctor { - float epsilon_; - +struct 
BatchNormFunctor : BatchNormFunctorBase { + BatchNormFunctor(const bool folded_constant, const bool fused_relu) : + BatchNormFunctorBase(folded_constant, fused_relu) {} void operator()(const Tensor *input, const Tensor *scale, const Tensor *offset, const Tensor *mean, const Tensor *var, + const float epsilon, Tensor *output, StatsFuture *future); }; diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc index 6cd46e76a8490632227457854152adf6da92242a..c3dc8445f80e7b51b6956d7d789c5a46e4052dc7 100644 --- a/mace/kernels/opencl/batch_norm_opencl.cc +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -19,8 +19,11 @@ void BatchNormFunctor::operator()( const Tensor *offset, const Tensor *mean, const Tensor *var, + const float epsilon, Tensor *output, StatsFuture *future) { + MACE_CHECK(folded_constant_ || (mean != nullptr && var != nullptr)); + const index_t batch = input->dim(0); const index_t height = input->dim(1); const index_t width = input->dim(2); @@ -33,15 +36,23 @@ void BatchNormFunctor::operator()( auto dt = DataTypeToEnum::value; built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); + if (folded_constant_) { + built_options.emplace("-DFOLDED_CONSTANT"); + } + if (fused_relu_) { + built_options.emplace("-DFUSED_RELU"); + } auto bm_kernel = runtime->BuildKernel("batch_norm", "batch_norm", built_options); uint32_t idx = 0; bm_kernel.setArg(idx++, *(static_cast(input->buffer()))); bm_kernel.setArg(idx++, *(static_cast(scale->buffer()))); bm_kernel.setArg(idx++, *(static_cast(offset->buffer()))); - bm_kernel.setArg(idx++, *(static_cast(mean->buffer()))); - bm_kernel.setArg(idx++, *(static_cast(var->buffer()))); - bm_kernel.setArg(idx++, epsilon_); + if (!folded_constant_) { + bm_kernel.setArg(idx++, *(static_cast(mean->buffer()))); + bm_kernel.setArg(idx++, *(static_cast(var->buffer()))); + bm_kernel.setArg(idx++, epsilon); + } bm_kernel.setArg(idx++, 
*(static_cast(output->buffer()))); const uint32_t gws[3] = {static_cast(channel_blocks), @@ -89,7 +100,8 @@ void BatchNormFunctor::operator()( << output->dim(0) << "_" << output->dim(1) << "_" << output->dim(2) << "_" - << output->dim(3); + << output->dim(3) << "_" + << folded_constant_; OpenCLProfilingTimer timer(&event); Tuner::Get()->template TuneOrRun(ss.str(), lws, diff --git a/mace/kernels/opencl/cl/batch_norm.cl b/mace/kernels/opencl/cl/batch_norm.cl index 027b678bfeac4345155dca29eea17ab0212ec7d6..f40609664de78c01a9dac3d66aec3f3c5b90d99e 100644 --- a/mace/kernels/opencl/cl/batch_norm.cl +++ b/mace/kernels/opencl/cl/batch_norm.cl @@ -3,27 +3,39 @@ __kernel void batch_norm(__read_only image2d_t input, __read_only image2d_t scale, __read_only image2d_t offset, +#ifndef FOLDED_CONSTANT __read_only image2d_t mean, __read_only image2d_t var, __private const float epsilon, +#endif __write_only image2d_t output) { const int ch_blk = get_global_id(0); const int w = get_global_id(1); const int hb = get_global_id(2); const int width = get_global_size(1); +#ifdef FOLDED_CONSTANT + DATA_TYPE4 bn_scale = READ_IMAGET(scale, SAMPLER, (int2)(ch_blk, 0)); + DATA_TYPE4 bn_offset = READ_IMAGET(offset, SAMPLER, (int2)(ch_blk, 0)); +#else DATA_TYPE4 scale_value = READ_IMAGET(scale, SAMPLER, (int2)(ch_blk, 0)); DATA_TYPE4 offset_value = READ_IMAGET(offset, SAMPLER, (int2)(ch_blk, 0)); DATA_TYPE4 mean_value = READ_IMAGET(mean, SAMPLER, (int2)(ch_blk, 0)); DATA_TYPE4 var_value = READ_IMAGET(var, SAMPLER, (int2)(ch_blk, 0)); // native_rsqrt seems not faster than rsqrt - DATA_TYPE4 new_scale = scale_value * rsqrt(var_value + (DATA_TYPE4)epsilon); - DATA_TYPE4 new_offset = mad(0 - mean_value, new_scale, offset_value); + DATA_TYPE4 bn_scale = scale_value * rsqrt(var_value + (DATA_TYPE4)epsilon); + DATA_TYPE4 bn_offset = mad(0 - mean_value, bn_scale, offset_value); +#endif const int pos = mad24(ch_blk, width, w); DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb)); - DATA_TYPE4 
out = mad(in, new_scale, new_offset); + DATA_TYPE4 out = mad(in, bn_scale, bn_offset); + +#ifdef FUSED_RELU + out = fmax(out, 0); +#endif + WRITE_IMAGET(output, (int2)(pos, hb), out); } diff --git a/mace/ops/batch_norm.h b/mace/ops/batch_norm.h index 96c4a1fc1c42bc9596a37434b1660277ec11ea7d..c6d2dd27f7ee36dcf08bac799b199174094f7892 100644 --- a/mace/ops/batch_norm.h +++ b/mace/ops/batch_norm.h @@ -2,8 +2,8 @@ // Copyright (c) 2017 XiaoMi All rights reserved. // -#ifndef MACE_BATCH_NORM_H_ -#define MACE_BATCH_NORM_H_ +#ifndef MACE_OPS_BATCH_NORM_H_ +#define MACE_OPS_BATCH_NORM_H_ #include "mace/core/operator.h" #include "mace/kernels/batch_norm.h" @@ -14,9 +14,9 @@ template class BatchNormOp : public Operator { public: BatchNormOp(const OperatorDef &operator_def, Workspace *ws) - : Operator(operator_def, ws), functor_() { - functor_.epsilon_ = - OperatorBase::GetSingleArgument("epsilon", static_cast(1e-4)); + : Operator(operator_def, ws), functor_(false, false) { + epsilon_ = + OperatorBase::GetSingleArgument("epsilon", static_cast(1e-4)); } bool Run(StatsFuture *future) override { @@ -40,11 +40,12 @@ class BatchNormOp : public Operator { Tensor *output = this->Output(OUTPUT); output->ResizeLike(input); - functor_(input, scale, offset, mean, var, output, future); + functor_(input, scale, offset, mean, var, epsilon_, output, future); return true; } private: + float epsilon_; kernels::BatchNormFunctor functor_; protected: @@ -54,4 +55,4 @@ class BatchNormOp : public Operator { } // namespace mace -#endif // MACE_BATCH_NORM_H_ +#endif // MACE_OPS_BATCH_NORM_H_ diff --git a/mace/ops/folded_batch_norm.cc b/mace/ops/folded_batch_norm.cc new file mode 100644 index 0000000000000000000000000000000000000000..5a04c48dd8f2000c9a33b175ec5c67f4c4aebe81 --- /dev/null +++ b/mace/ops/folded_batch_norm.cc @@ -0,0 +1,37 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
+//
+
+#include "mace/ops/folded_batch_norm.h"
+
+namespace mace {
+
+void Register_FoldedBatchNorm(OperatorRegistry *op_registry) {
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
+                                     .Device(DeviceType::CPU)
+                                     .TypeConstraint("T")
+                                     .Build(),
+                    FoldedBatchNormOp);
+
+#if MACE_ENABLE_NEON
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
+                                     .Device(DeviceType::NEON)
+                                     .TypeConstraint("T")
+                                     .Build(),
+                    FoldedBatchNormOp);
+#endif  // MACE_ENABLE_NEON
+
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
+                                     .Device(DeviceType::OPENCL)
+                                     .TypeConstraint("T")
+                                     .Build(),
+                    FoldedBatchNormOp);
+
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
+                                     .Device(DeviceType::OPENCL)
+                                     .TypeConstraint("T")
+                                     .Build(),
+                    FoldedBatchNormOp);
+}
+
+}  // namespace mace
diff --git a/mace/ops/folded_batch_norm.h b/mace/ops/folded_batch_norm.h
new file mode 100644
index 0000000000000000000000000000000000000000..390a30c7119fb20efa5d733dedade9e399b00e32
--- /dev/null
+++ b/mace/ops/folded_batch_norm.h
@@ -0,0 +1,52 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_OPS_FOLDED_BATCH_NORM_H_
+#define MACE_OPS_FOLDED_BATCH_NORM_H_
+
+#include "mace/core/operator.h"
+#include "mace/kernels/batch_norm.h"
+
+namespace mace {
+
+// Batch norm whose statistics are already folded into `scale`/`offset`:
+// the functor computes scale * x + offset, optionally followed by ReLU.
+template
+class FoldedBatchNormOp : public Operator {
+ public:
+  FoldedBatchNormOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator(operator_def, ws),
+        functor_(true, OperatorBase::GetSingleArgument("fused_relu", false)) {
+  }
+
+  bool Run(StatsFuture *future) override {
+    const Tensor *input = this->Input(INPUT);
+    const Tensor *scale = this->Input(SCALE);
+    const Tensor *offset = this->Input(OFFSET);
+
+    MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
+               input->dim_size());
+    MACE_CHECK(scale->dim_size() == 1, "scale must be 1-dimensional. ",
+               scale->dim_size());
+    MACE_CHECK(offset->dim_size() == 1, "offset must be 1-dimensional. ",
+               offset->dim_size());
+
+    Tensor *output = this->Output(OUTPUT);
+    output->ResizeLike(input);
+
+    functor_(input, scale, offset, nullptr, nullptr, 0, output, future);
+    return true;
+  }
+
+ private:
+  kernels::BatchNormFunctor functor_;
+
+ protected:
+  OP_INPUT_TAGS(INPUT, SCALE, OFFSET);  // mean/var are folded away, so only three inputs
+  OP_OUTPUT_TAGS(OUTPUT);
+};
+
+}  // namespace mace
+
+#endif  // MACE_OPS_FOLDED_BATCH_NORM_H_
diff --git a/mace/ops/folded_batch_norm_test.cc b/mace/ops/folded_batch_norm_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5ee0a9473b196cc2747ae748dbdd37cbcbc5a5c3
--- /dev/null
+++ b/mace/ops/folded_batch_norm_test.cc
@@ -0,0 +1,397 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/operator.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+
+class FoldedBatchNormOpTest : public OpsTestBase {};
+
+// Folds batch-norm statistics into a per-channel affine transform:
+//   scale[i]  = gamma[i] / sqrt(var[i] + epsilon)
+//   offset[i] = beta[i] - mean[i] * scale[i]
+// `scale` and `offset` must already be sized to gamma.size().
+void CalculateScaleOffset(const std::vector<float> &gamma,
+                          const std::vector<float> &beta,
+                          const std::vector<float> &mean,
+                          const std::vector<float> &var,
+                          const float epsilon,
+                          std::vector<float> &scale,
+                          std::vector<float> &offset) {
+  const size_t size = gamma.size();
+  for (size_t i = 0; i < size; ++i) {
+    scale[i] = gamma[i] / std::sqrt(var[i] + epsilon);
+    // Start from beta, not from whatever `offset` held on entry.
+    offset[i] = beta[i] - mean[i] * scale[i];
+  }
+}
+
+template
+void Simple() {
+  OpsTestNet net;
+
+  // Add input data
+  net.AddInputFromArray("Input", {1, 6, 2, 1},
+                        {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
+  std::vector<float> scale(1);
+  std::vector<float> offset(1);
+  CalculateScaleOffset({4.0f}, {2.0}, {10}, {11.67f}, 1e-3, scale, offset);
+  net.AddInputFromArray("Scale", {1}, scale);
+  net.AddInputFromArray("Offset", {1}, offset);
+
+  if (D == DeviceType::OPENCL) {
+    BufferToImage(net, "Input", "InputImage",
+                  kernels::BufferType::IN_OUT);
+    BufferToImage(net, "Scale", "ScaleImage",
+                  kernels::BufferType::ARGUMENT);
+    BufferToImage(net, "Offset", "OffsetImage",
+                  kernels::BufferType::ARGUMENT);
+
+    OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest")
+        .Input("InputImage")
+
.Input("ScaleImage") + .Input("OffsetImage") + .Output("OutputImage") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + + // Transfer output + ImageToBuffer(net, "OutputImage", "Output", + kernels::BufferType::IN_OUT); + } else { + OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest") + .Input("Input") + .Input("Scale") + .Input("Offset") + .Output("Output") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + } + + // Check + auto expected = + CreateTensor({1, 6, 2, 1}, {-3.86, -3.86, -1.51, -1.51, 0.83, 0.83, + 3.17, 3.17, 5.51, 5.51, 7.86, 7.86}); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-2); +} + +TEST_F(FoldedBatchNormOpTest, SimpleCPU) { Simple(); } + +/* +TEST_F(FoldedBatchNormOpTest, SimpleNEON) { + Simple(); +} +*/ + +TEST_F(FoldedBatchNormOpTest, SimpleOPENCL) { Simple(); } + +/* +TEST_F(FoldedBatchNormOpTest, SimpleRandomNeon) { + srand(time(NULL)); + + // generate random input + index_t batch = 1 + rand() % 10; + index_t channels = 3 + rand() % 50; + index_t height = 64; + index_t width = 64; + // Construct graph + OpsTestNet net; + OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest") + .Input("Input") + .Input("Scale") + .Input("Offset") + .Input("Mean") + .Input("Var") + .Input("Epsilon") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", {batch, channels, height, +width}); + net.AddRandomInput("Scale", {channels}); + net.AddRandomInput("Offset", {channels}); + net.AddRandomInput("Mean", {channels}); + net.AddRandomInput("Var", {channels}, true); + net.AddInputFromArray("Epsilon", {}, {1e-3}); + + // run cpu + net.RunOp(); + + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // Run NEON + net.RunOp(DeviceType::NEON); + + ExpectTensorNear(expected, *net.GetOutput("Output"), 1e-2); +} + +TEST_F(FoldedBatchNormOpTest, ComplexRandomNeon) { + srand(time(NULL)); + + // generate random input + index_t batch = 1 + rand() % 10; + index_t 
channels = 3 + rand() % 50; + index_t height = 103; + index_t width = 113; + // Construct graph + OpsTestNet net; + OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest") + .Input("Input") + .Input("Scale") + .Input("Offset") + .Input("Mean") + .Input("Var") + .Input("Epsilon") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", {batch, channels, height, +width}); + net.AddRandomInput("Scale", {channels}); + net.AddRandomInput("Offset", {channels}); + net.AddRandomInput("Mean", {channels}); + net.AddRandomInput("Var", {channels}, true); + net.AddInputFromArray("Epsilon", {}, {1e-3}); + + // run cpu + net.RunOp(); + + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // Run NEON + net.RunOp(DeviceType::NEON); + + ExpectTensorNear(expected, *net.GetOutput("Output"), 1e-2); +} +*/ + +TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) { + srand(time(NULL)); + + // generate random input + index_t batch = 1 + rand() % 10; + index_t channels = 3 + rand() % 50; + index_t height = 64; + index_t width = 64; + + // Construct graph + OpsTestNet net; + OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest") + .Input("Input") + .Input("Scale") + .Input("Offset") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput( + "Input", {batch, height, width, channels}); + net.AddRandomInput("Scale", {channels}); + net.AddRandomInput("Offset", {channels}); + + // run cpu + net.RunOp(); + + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // Run on opencl + BufferToImage(net, "Input", "InputImage", + kernels::BufferType::IN_OUT); + BufferToImage(net, "Scale", "ScaleImage", + kernels::BufferType::ARGUMENT); + BufferToImage(net, "Offset", "OffsetImage", + kernels::BufferType::ARGUMENT); + + OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest") + .Input("InputImage") + .Input("ScaleImage") + .Input("OffsetImage") + .Output("OutputImage") + 
.Finalize(net.NewOperatorDef()); + + // Run on opencl + net.RunOp(DeviceType::OPENCL); + net.Sync(); + + ImageToBuffer(net, "OutputImage", "OPENCLOutput", + kernels::BufferType::IN_OUT); + ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-2); +} + +TEST_F(FoldedBatchNormOpTest, SimpleRandomHalfOPENCL) { + srand(time(NULL)); + + // generate random input + index_t batch = 1 + rand() % 10; + index_t channels = 3 + rand() % 50; + index_t height = 64; + index_t width = 64; + + // Construct graph + OpsTestNet net; + OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest") + .Input("Input") + .Input("Scale") + .Input("Offset") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput( + "Input", {batch, height, width, channels}); + net.AddRandomInput("Scale", {channels}); + net.AddRandomInput("Offset", {channels}); + + // run cpu + net.RunOp(); + + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // Run on opencl + BufferToImage(net, "Input", "InputImage", + kernels::BufferType::IN_OUT); + BufferToImage(net, "Scale", "ScaleImage", + kernels::BufferType::ARGUMENT); + BufferToImage(net, "Offset", "OffsetImage", + kernels::BufferType::ARGUMENT); + + OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest") + .Input("InputImage") + .Input("ScaleImage") + .Input("OffsetImage") + .Output("OutputImage") + .AddIntArg("T", static_cast(DataType::DT_HALF)) + .Finalize(net.NewOperatorDef()); + + // Run on opencl + net.RunOp(DeviceType::OPENCL); + net.Sync(); + + ImageToBuffer(net, "OutputImage", "OPENCLOutput", + kernels::BufferType::IN_OUT); + ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.5); +} + +TEST_F(FoldedBatchNormOpTest, ComplexRandomOPENCL) { + srand(time(NULL)); + + // generate random input + index_t batch = 1 + rand() % 10; + index_t channels = 3 + rand() % 50; + index_t height = 103; + index_t width = 113; + + // Construct graph + OpsTestNet net; + OpDefBuilder("FoldedBatchNorm", 
"FoldedBatchNormTest") + .Input("Input") + .Input("Scale") + .Input("Offset") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput( + "Input", {batch, height, width, channels}); + net.AddRandomInput("Scale", {channels}); + net.AddRandomInput("Offset", {channels}); + + // run cpu + net.RunOp(); + + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // Run on opencl + BufferToImage(net, "Input", "InputImage", + kernels::BufferType::IN_OUT); + BufferToImage(net, "Scale", "ScaleImage", + kernels::BufferType::ARGUMENT); + BufferToImage(net, "Offset", "OffsetImage", + kernels::BufferType::ARGUMENT); + + OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest") + .Input("InputImage") + .Input("ScaleImage") + .Input("OffsetImage") + .Output("OutputImage") + .Finalize(net.NewOperatorDef()); + + // Run on opencl + net.RunOp(DeviceType::OPENCL); + + ImageToBuffer(net, "OutputImage", "OPENCLOutput", + kernels::BufferType::IN_OUT); + ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 1e-2); +} + +TEST_F(FoldedBatchNormOpTest, ComplexRandomHalfOPENCL) { + srand(time(NULL)); + + // generate random input + index_t batch = 1 + rand() % 10; + index_t channels = 3 + rand() % 50; + index_t height = 103; + index_t width = 113; + + // Construct graph + OpsTestNet net; + OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest") + .Input("Input") + .Input("Scale") + .Input("Offset") + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput( + "Input", {batch, height, width, channels}); + net.AddRandomInput("Scale", {channels}); + net.AddRandomInput("Offset", {channels}); + + // run cpu + net.RunOp(); + + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // Run on opencl + BufferToImage(net, "Input", "InputImage", + kernels::BufferType::IN_OUT); + BufferToImage(net, "Scale", "ScaleImage", + kernels::BufferType::ARGUMENT); + BufferToImage(net, "Offset", 
"OffsetImage", + kernels::BufferType::ARGUMENT); + + OpDefBuilder("FoldedBatchNorm", "FoldedBatchNormTest") + .Input("InputImage") + .Input("ScaleImage") + .Input("OffsetImage") + .Output("OutputImage") + .AddIntArg("T", static_cast(DataType::DT_HALF)) + .Finalize(net.NewOperatorDef()); + + // Run on opencl + net.RunOp(DeviceType::OPENCL); + + ImageToBuffer(net, "OutputImage", "OPENCLOutput", + kernels::BufferType::IN_OUT); + ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.5); +} +} diff --git a/mace/python/tools/tf_converter.py b/mace/python/tools/tf_converter.py index 2b182fe67eb2a836a2b09752804b8fe27c20160a..99383535f20ea2bcfb70bf4444599acc039b5e2c 100644 --- a/mace/python/tools/tf_converter.py +++ b/mace/python/tools/tf_converter.py @@ -39,6 +39,13 @@ def main(unused_args): f.write(str(output_graph_def)) print("Model conversion is completed.") +def str2bool(v): + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') def parse_args(): """Parses command line arguments.""" @@ -91,7 +98,9 @@ def parse_args(): help="template path") parser.add_argument( "--confuse", - type=bool, + type=str2bool, + nargs='?', + const=False, default=False, help="confuse model names") parser.add_argument( diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 006d41fe8851637a731e15e22c0b9cc36e69be13..64bd7b9e41dcaf6770b109f8f0bc4598da885105 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -1,6 +1,7 @@ from mace.proto import mace_pb2 import tensorflow as tf import numpy as np +import math from mace.python.tools import memory_optimizer # TODO: support NCHW formt, now only support NHWC. 
@@ -136,6 +137,23 @@ class TFConverter(object):
       output_shapes.append(output_shape)
     op.output_shape.extend(output_shapes)
 
+  def add_tensor(self, name, shape, tf_dt, value):
+    """Appends a constant tensor (float32 or int32 data) to net_def."""
+    tensor = self.net_def.tensors.add()
+    tensor.name = name
+
+    shape = list(shape)
+    tensor.dims.extend(shape)
+
+    if tf_dt == tf.float32:
+      tensor.data_type = mace_pb2.DT_FLOAT
+      tensor.float_data.extend(value.flat)
+    elif tf_dt == tf.int32:
+      tensor.data_type = mace_pb2.DT_INT32
+      tensor.int32_data.extend(value.flat)
+    else:
+      raise Exception("Not supported tensor type: " + tf_dt.name)
+
   def convert_tensor(self, op):
     if op.outputs[0].name not in self.unused_tensor:
       tensor = self.net_def.tensors.add()
@@ -211,26 +229,60 @@ class TFConverter(object):
     arg = op_def.arg.add()
     arg.name = 'T'
     arg.i = self.dt
+    data_format_arg = op_def.arg.add()
+    data_format_arg.name = 'data_format'
+    data_format_arg.s = 'NHWC'
     op_def.name = op.name
-    op_def.type = 'BatchNorm'
+    op_def.type = 'FoldedBatchNorm'
+
+    gamma_tensor = get_input_tensor(op, 1)
+    for i in range(1, 5):
+      input_tensor = get_input_tensor(op, i)
+      assert input_tensor.shape == gamma_tensor.shape
+      self.unused_tensor.add(input_tensor.name)
+
+    gamma_value = get_input_tensor(op, 1).eval().astype(np.float32)
+    beta_value = get_input_tensor(op, 2).eval().astype(np.float32)
+    mean_value = get_input_tensor(op, 3).eval().astype(np.float32)
+    var_value = get_input_tensor(op, 4).eval().astype(np.float32)
+    epsilon_value = op.get_attr('epsilon')
+
+    # Fold gamma/beta/mean/var into a per-channel affine: scale * x + offset.
+    scale_value = (
+        (1.0 / np.vectorize(math.sqrt)(var_value + epsilon_value)) *
+        gamma_value)
+    offset_value = (-mean_value * scale_value) + beta_value
+    idx = gamma_tensor.name.rfind('/')
+    name_prefix = gamma_tensor.name[:idx] + '/'
+    input_names = [name_prefix+'scale:0', name_prefix+'offset:0']
+    self.add_tensor(input_names[0], gamma_value.shape,
+                    gamma_tensor.dtype, scale_value)
+    self.add_tensor(input_names[1], gamma_value.shape,
+                    gamma_tensor.dtype, offset_value)
+
     if self.device == 'gpu':
       op_def.input.extend([op.inputs[0].name])
-      for i in range(1, len(op.inputs)):
-        output_name = self.add_buffer_to_image(op.inputs[i].name, "ARGUMENT")
+      for name in input_names:
+        output_name = self.add_buffer_to_image(name, "ARGUMENT")
         op_def.input.extend([output_name])
     else:
-      op_def.input.extend([input.name for input in op.inputs])
-    op_def.output.extend([op.outputs[0].name])
-
-    self.add_output_shape(op.outputs, op_def)
+      # CPU path: data input first, then the folded scale/offset names (plain strings).
+      op_def.input.extend([op.inputs[0].name] + input_names)
 
-    epsilon_arg = op_def.arg.add()
-    epsilon_arg.name = 'epsilon'
-    epsilon_arg.f = op.get_attr('epsilon')
-    data_format_arg = op_def.arg.add()
-    data_format_arg.name = 'data_format'
-    data_format_arg.s = 'NHWC'
     self.resolved_ops[op.name] = 1
+    final_op = op
+
+    if len(self.tf_graph[op.name]) == 1 and self.tf_graph[op.name][0].type == 'Relu':
+      relu_op = self.tf_graph[op.name][0]
+      final_op = relu_op
+      fused_relu_arg = op_def.arg.add()
+      fused_relu_arg.name = 'fused_relu'
+      fused_relu_arg.i = 1
+      self.resolved_ops[relu_op.name] = 1
+
+    op_def.output.extend([final_op.outputs[0].name])
+    self.add_output_shape(final_op.outputs, op_def)
+
     self.net_def.op.extend([op_def])
 
   def convert_batchnorm(self, op):
diff --git a/tools/validate_gcn.sh b/tools/validate_gcn.sh
index 7cbfc4d5104e22299cbea4f8bd47a0b7881921f1..8d01b110122366a7839aa4fb49c7685784369667 100755
--- a/tools/validate_gcn.sh
+++ b/tools/validate_gcn.sh
@@ -96,7 +96,7 @@ bazel-bin/mace/python/tools/tf_converter --input=${TF_MODEL_FILE_PATH} \
                                          --output_type=source \
                                          --template=${MACE_SOURCE_DIR}/mace/python/tools/model.template \
                                          --model_tag=${MODEL_TAG} \
-                                         --confuse=False || exit -1
+                                         --confuse=True || exit -1
 
 echo "Step 3: Generate version source"
 rm -rf ${VERSION_SOURCE_PATH}