From 622173e7eac782b12fe0a27b97475708cc2a2b6d Mon Sep 17 00:00:00 2001 From: yejianwu Date: Wed, 29 Aug 2018 11:32:44 +0800 Subject: [PATCH] add lstm gpu, unstack cpu --- mace/kernels/BUILD | 2 + mace/kernels/lstmcell.h | 60 +++++ mace/kernels/opencl/cl/common.h | 7 +- mace/kernels/opencl/cl/lstmcell.cl | 157 ++++++++++++ mace/kernels/opencl/lstmcell.cc | 100 ++++++++ mace/kernels/unstack.h | 83 +++++++ mace/ops/BUILD | 2 + mace/ops/lstmcell.cc | 35 +++ mace/ops/lstmcell.h | 60 +++++ mace/ops/lstmcell_benchmark.cc | 105 +++++++++ mace/ops/lstmcell_test.cc | 223 ++++++++++++++++++ mace/ops/ops_register.cc | 4 + mace/ops/ops_test_util.h | 15 ++ mace/ops/unstack.cc | 34 +++ mace/ops/unstack.h | 49 ++++ mace/ops/unstack_test.cc | 77 ++++++ mace/python/tools/convert_util.py | 20 +- .../tools/converter_tool/base_converter.py | 7 + .../converter_tool/tensorflow_converter.py | 15 ++ .../tools/converter_tool/transformer.py | 162 +++++++++++++ mace/python/tools/memory_optimizer.py | 2 +- .../opencl-kernel/opencl_kernel_configure.bzl | 1 + 22 files changed, 1212 insertions(+), 8 deletions(-) create mode 100644 mace/kernels/lstmcell.h create mode 100644 mace/kernels/opencl/cl/lstmcell.cl create mode 100644 mace/kernels/opencl/lstmcell.cc create mode 100644 mace/kernels/unstack.h create mode 100644 mace/ops/lstmcell.cc create mode 100644 mace/ops/lstmcell.h create mode 100644 mace/ops/lstmcell_benchmark.cc create mode 100644 mace/ops/lstmcell_test.cc create mode 100644 mace/ops/unstack.cc create mode 100644 mace/ops/unstack.h create mode 100644 mace/ops/unstack_test.cc diff --git a/mace/kernels/BUILD b/mace/kernels/BUILD index 5706f94e..3491b743 100644 --- a/mace/kernels/BUILD +++ b/mace/kernels/BUILD @@ -45,11 +45,13 @@ cc_library( exclude = [ "buffer_to_image.h", "image_to_buffer.h", + "lstmcell.h", ], ) + if_opencl_enabled(glob([ "opencl/*.h", "buffer_to_image.h", "image_to_buffer.h", + "lstmcell.h", ])), copts = [ "-Werror", diff --git a/mace/kernels/lstmcell.h b/mace/kernels/lstmcell.h new file mode 100644 index 00000000..46439fae --- /dev/null +++ b/mace/kernels/lstmcell.h @@ -0,0 +1,60 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MACE_KERNELS_LSTMCELL_H_ +#define MACE_KERNELS_LSTMCELL_H_ + +#include +#include +#include +#include + +#include "mace/core/future.h" +#include "mace/core/runtime/opencl/cl2_header.h" +#include "mace/core/tensor.h" + +#if defined(MACE_ENABLE_NEON) +#include +#endif + +namespace mace { +namespace kernels { + +template +struct LSTMCellFunctor; + +template +struct LSTMCellFunctor { + explicit LSTMCellFunctor(T forget_bias) : + forget_bias_(static_cast(forget_bias)) {} + MaceStatus operator()(const Tensor *input, + const Tensor *pre_output, + const Tensor *weight, + const Tensor *bias, + const Tensor *pre_cell, + Tensor *cell, + Tensor *output, + StatsFuture *future); + + T forget_bias_; + cl::Kernel kernel_; + uint32_t kwg_size_; + std::unique_ptr kernel_error_; + std::vector input_shape_; +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_LSTMCELL_H_ diff --git a/mace/kernels/opencl/cl/common.h b/mace/kernels/opencl/cl/common.h index 8408f1be..8f4aa37f 100644 --- a/mace/kernels/opencl/cl/common.h +++ b/mace/kernels/opencl/cl/common.h @@ -61,6 +61,11 @@ __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +inline float4 do_sigmoid(float4 in) { + // native_func not support half + return native_recip(1.0 + native_exp(-in)); +} + inline DATA_TYPE4 do_activation(DATA_TYPE4 in, #ifdef USE_PRELU DATA_TYPE4 prelu_alpha, @@ -80,7 +85,7 @@ inline DATA_TYPE4 do_activation(DATA_TYPE4 in, out = tanh(in); #endif #ifdef USE_SIGMOID - out = native_recip((DATA_TYPE)1 + native_exp(-in)); + out = do_sigmoid(in); #endif return out; } diff --git a/mace/kernels/opencl/cl/lstmcell.cl b/mace/kernels/opencl/cl/lstmcell.cl new file mode 100644 index 00000000..140132bd --- /dev/null +++ b/mace/kernels/opencl/cl/lstmcell.cl @@ -0,0 +1,157 @@ +#include + +__kernel void lstmcell(KERNEL_ERROR_PARAMS + GLOBAL_WORK_GROUP_SIZE_DIM2 + __read_only image2d_t input, + __read_only image2d_t pre_output, + __read_only image2d_t weight, + __read_only image2d_t bias, + __read_only image2d_t pre_cell, + __private const float forget_bias, + __private const int width, + __write_only image2d_t cell, + __write_only image2d_t output) { + const int w_blk_idx = get_global_id(0); + const int h_idx = get_global_id(1); + +#ifndef NON_UNIFORM_WORK_GROUP + if (w_blk_idx >= global_size_dim0 || h_idx >= global_size_dim1) return; +#endif + + // fc_res0 -> i + // fc_res1 -> j + // fc_res2 -> f + // fc_res3 -> o + DATA_TYPE4 fc_res0 = 0.0, fc_res1 = 0.0, fc_res2 = 0.0, fc_res3 = 0.0; + DATA_TYPE4 in, pre_h; + DATA_TYPE4 w0, w1, w2, w3; + // concat matmul + for (short i = 0; i < global_size_dim0; ++i) { + in = READ_IMAGET(input, SAMPLER, (int2)(i, h_idx)); + short k = 4 * i; + + w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); + + fc_res0 += in.x * w0; + fc_res1 += in.x * w1; + fc_res2 += in.x * w2; + fc_res3 += in.x * w3; + + k = 4 * i + 1; + if (k < width) { + w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); + + fc_res0 += in.y * w0; + fc_res1 += in.y * w1; + fc_res2 += in.y * 
w2; + fc_res3 += in.y * w3; + } + + k = 4 * i + 2; + if (k < width) { + w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); + + fc_res0 += in.z * w0; + fc_res1 += in.z * w1; + fc_res2 += in.z * w2; + fc_res3 += in.z * w3; + } + + k = 4 * i + 3; + if (k < width) { + w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); + + fc_res0 += in.w * w0; + fc_res1 += in.w * w1; + fc_res2 += in.w * w2; + fc_res3 += in.w * w3; + } + } + + for (short i = 0; i < global_size_dim0; ++i) { + pre_h = READ_IMAGET(pre_output, SAMPLER, (int2)(i, h_idx)); + short k = 4 * (i + global_size_dim0); + short k_limit = 4 * global_size_dim0 + width; + + w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); + + fc_res0 += pre_h.x * w0; + fc_res1 += pre_h.x * w1; + fc_res2 += pre_h.x * w2; + fc_res3 += pre_h.x * w3; + + k = 4 * (i + global_size_dim0) + 1; + if (k < k_limit) { + w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); + + fc_res0 += pre_h.y * w0; + fc_res1 += pre_h.y * w1; + fc_res2 += pre_h.y * w2; + fc_res3 += pre_h.y * w3; + } + + k = 4 * (i + global_size_dim0) + 2; + if (k < k_limit) { + w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); + + fc_res0 += pre_h.z * w0; + fc_res1 += pre_h.z * w1; + fc_res2 += pre_h.z * w2; + fc_res3 += pre_h.z * w3; + } + + k = 4 * (i + global_size_dim0) + 3; + if (k < k_limit) { + w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); + + fc_res0 += pre_h.w * w0; + fc_res1 += pre_h.w * w1; + fc_res2 += pre_h.w * w2; + fc_res3 += pre_h.w * w3; + } + } + + // bias + DATA_TYPE4 b0, b1, b2, b3; + b0 = READ_IMAGET(bias, SAMPLER, (int2)(w_blk_idx, 0)); + b1 = READ_IMAGET(bias, SAMPLER, (int2)(w_blk_idx + global_size_dim0, 0)); + b2 = READ_IMAGET(bias, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, 0)); + b3 = READ_IMAGET(bias, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, 0)); + fc_res0 += b0; + fc_res1 += b1; + fc_res2 += b2; + fc_res3 += b3; + + // gate + DATA_TYPE4 pre_c, c, h; + pre_c = READ_IMAGET(pre_cell, SAMPLER, (int2)(w_blk_idx, h_idx)); + c = do_sigmoid(fc_res0) * tanh(fc_res1) + do_sigmoid((fc_res2 + 
(float4)forget_bias)) * pre_c; + h = do_sigmoid(fc_res3) * tanh(c); + + WRITE_IMAGET(cell, (int2)(w_blk_idx, h_idx), c); + WRITE_IMAGET(output, (int2)(w_blk_idx, h_idx), h); +} diff --git a/mace/kernels/opencl/lstmcell.cc b/mace/kernels/opencl/lstmcell.cc new file mode 100644 index 00000000..4cfc98b2 --- /dev/null +++ b/mace/kernels/opencl/lstmcell.cc @@ -0,0 +1,100 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/kernels/lstmcell.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" +#include "mace/kernels/opencl/helper.h" +#include "mace/utils/tuner.h" +#include "mace/utils/utils.h" + +namespace mace { +namespace kernels { + +template +MaceStatus LSTMCellFunctor::operator()( + const Tensor *input, + const Tensor *pre_output, + const Tensor *weight, + const Tensor *bias, + const Tensor *pre_cell, + Tensor *cell, + Tensor *output, + StatsFuture *future) { + const index_t height = input->dim(0); + const index_t width = input->dim(1); + + auto runtime = OpenCLRuntime::Global(); + + if (kernel_.get() == nullptr) { + std::set built_options; + OUT_OF_RANGE_CONFIG(kernel_error_); + NON_UNIFORM_WG_CONFIG; + auto dt = DataTypeToEnum::value; + std::string kernel_name = MACE_OBFUSCATE_SYMBOL("lstmcell"); + built_options.emplace("-Dlstmcell=" + kernel_name); + built_options.emplace("-DDATA_TYPE=" + DtToUpCompatibleCLDt(dt)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpCompatibleCLCMDDt(dt)); + + MACE_RETURN_IF_ERROR(runtime->BuildKernel("lstmcell", kernel_name, + built_options, &kernel_)); + + kwg_size_ = + static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); + } + + const index_t width_blocks = RoundUpDiv4(width); + const uint32_t gws[2] = {static_cast(width_blocks), + static_cast(height)}; + + if (!IsVecEqual(input_shape_, input->shape())) { + std::vector output_shape_paded = {height, 1, 1, width}; + std::vector output_image_shape; + CalImage2DShape(output_shape_paded, BufferType::IN_OUT_CHANNEL, + &output_image_shape); + MACE_RETURN_IF_ERROR(output->ResizeImage(input->shape(), + output_image_shape)); + MACE_RETURN_IF_ERROR(cell->ResizeImage(input->shape(), output_image_shape)); + + uint32_t idx = 0; + OUT_OF_RANGE_SET_ARG; + SET_2D_GWS_ARGS(kernel_); + kernel_.setArg(idx++, *(input->opencl_image())); + kernel_.setArg(idx++, *(pre_output->opencl_image())); + kernel_.setArg(idx++, *(weight->opencl_image())); + kernel_.setArg(idx++, *(bias->opencl_image())); + kernel_.setArg(idx++, *(pre_cell->opencl_image())); + kernel_.setArg(idx++, static_cast(forget_bias_)); + kernel_.setArg(idx++, static_cast(width)); + kernel_.setArg(idx++, *(cell->opencl_image())); + kernel_.setArg(idx++, *(output->opencl_image())); + + input_shape_ = input->shape(); + } + + const std::vector lws = {kwg_size_ / 16, 16, 0}; + std::string tuning_key = + Concat("lstmcell_opencl_kernel", output->dim(0), output->dim(1)); + MACE_RETURN_IF_ERROR(TuningOrRun2DKernel(kernel_, tuning_key, + gws, lws, future)); + 
OUT_OF_RANGE_VALIDATION(kernel_error_);
+
+  return MACE_SUCCESS;
+}
+
+template struct LSTMCellFunctor<DeviceType::GPU, float>;
+
+template struct LSTMCellFunctor<DeviceType::GPU, half>;
+
+}  // namespace kernels
+}  // namespace mace
diff --git a/mace/kernels/unstack.h b/mace/kernels/unstack.h
new file mode 100644
index 00000000..08e3d798
--- /dev/null
+++ b/mace/kernels/unstack.h
@@ -0,0 +1,83 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_KERNELS_UNSTACK_H_
+#define MACE_KERNELS_UNSTACK_H_
+
+#include <cstring>
+#include <functional>
+#include <numeric>
+#include <vector>
+
+#include "mace/core/future.h"
+#include "mace/core/tensor.h"
+#include "mace/public/mace.h"
+
+namespace mace {
+namespace kernels {
+
+template <DeviceType D, typename T>
+struct UnstackFunctor {
+  explicit UnstackFunctor(int axis) : axis_(axis) {}
+
+  MaceStatus operator()(const Tensor *input,
+                        const std::vector<Tensor *> &outputs,
+                        StatsFuture *future) {
+    std::vector<index_t> input_shape = input->shape();
+    MACE_CHECK(axis_ >= -(input->dim_size()) && axis_ < input->dim_size(),
+               "axis out of bounds.");
+    if (axis_ < 0) {
+      axis_ += input->dim_size();
+    }
+    MACE_CHECK(outputs.size() == static_cast<size_t>(input_shape[axis_]),
+               "output size does not equal input_shape[axis]");
+
+    std::vector<index_t> output_shape = input_shape;
+    output_shape.erase(output_shape.begin() + axis_);
+
+    std::vector<T *> output_data(outputs.size(), nullptr);
+    for (index_t i = 0; i < input_shape[axis_]; ++i) {
+      MACE_RETURN_IF_ERROR(outputs[i]->Resize(output_shape));
+      output_data[i] = outputs[i]->mutable_data<T>();
+    }
+    const T *input_data = input->data<T>();
+
+    index_t high_dim_elem_size =
+        std::accumulate(input_shape.begin(), input_shape.begin() + axis_, 1,
+                        std::multiplies<index_t>());
+    index_t low_dim_elem_size =
+        std::accumulate(input_shape.begin() + axis_ + 1, input_shape.end(), 1,
+                        std::multiplies<index_t>());
+
+    for (index_t h = 0; h < high_dim_elem_size; ++h) {
+      index_t input_idx = h * input_shape[axis_] * low_dim_elem_size;
+      index_t output_idx = h * low_dim_elem_size;
+      for (index_t i = 0; i < input_shape[axis_]; ++i) {
+        memcpy(output_data[i] + output_idx, input_data + input_idx,
+               sizeof(T) * low_dim_elem_size);
+        input_idx += low_dim_elem_size;
+      }
+    }
+
+    SetFutureDefaultWaitFn(future);
+    return MACE_SUCCESS;
+  }
+
+  int axis_;
+};
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_UNSTACK_H_
diff --git a/mace/ops/BUILD b/mace/ops/BUILD
index 690f4a96..07aad1d2 100644
--- a/mace/ops/BUILD
+++ b/mace/ops/BUILD
@@ -38,11 +38,13 @@ cc_library(
                 "*_benchmark.cc",
                 "buffer_to_image.cc",
                 "image_to_buffer.cc",
+                "lstmcell.cc",
             ],
         ),
     ) + if_opencl_enabled(
         [
             "buffer_to_image.cc",
             "image_to_buffer.cc",
+            "lstmcell.cc",
         ],
     ),
     hdrs = glob(
diff --git a/mace/ops/lstmcell.cc b/mace/ops/lstmcell.cc
new file mode 100644
index 00000000..9926ad4b
--- /dev/null
+++ b/mace/ops/lstmcell.cc
@@ -0,0 +1,35 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/ops/lstmcell.h"
+
+namespace mace {
+namespace ops {
+
+void Register_LSTMCell(OperatorRegistryBase *op_registry) {
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("LSTMCell")
+                                          .Device(DeviceType::GPU)
+                                          .TypeConstraint<float>("T")
+                                          .Build(),
+                         LSTMCellOp<DeviceType::GPU, float>);
+
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("LSTMCell")
+                                          .Device(DeviceType::GPU)
+                                          .TypeConstraint<half>("T")
+                                          .Build(),
+                         LSTMCellOp<DeviceType::GPU, half>);
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/lstmcell.h b/mace/ops/lstmcell.h
new file mode 100644
index 00000000..a4032379
--- /dev/null
+++ b/mace/ops/lstmcell.h
@@ -0,0 +1,60 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_OPS_LSTMCELL_H_
+#define MACE_OPS_LSTMCELL_H_
+
+#include <string>
+
+#include "mace/core/operator.h"
+#include "mace/kernels/lstmcell.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, typename T>
+class LSTMCellOp : public Operator<D, T> {
+ public:
+  LSTMCellOp(const OperatorDef &op_def, Workspace *ws)
+      : Operator<D, T>(op_def, ws),
+        functor_(static_cast<T>(
+            OperatorBase::GetOptionalArg<float>("value", 0.0))) {}
+
+  MaceStatus Run(StatsFuture *future) override {
+    const Tensor *input = this->Input(INPUT);
+    const Tensor *pre_output = this->Input(PRE_OUTPUT);
+    const Tensor *weight = this->Input(WEIGHT);
+    const Tensor *bias = this->Input(BIAS);
+    const Tensor *pre_cell = this->Input(PRE_CELL);
+    Tensor *cell = this->Output(CELL);
+    Tensor *output = this->Output(OUTPUT);
+
+    MACE_CHECK(input->dim_size() == 2 && input->dim(1) % 4 == 0,
+               "LSTM step should be a multiple of 4");
+
+    return functor_(
+        input, pre_output, weight, bias, pre_cell, cell, output, future);
+  }
+
+ protected:
+  kernels::LSTMCellFunctor<D, T> functor_;
+
+  MACE_OP_INPUT_TAGS(INPUT, PRE_OUTPUT, WEIGHT, BIAS, PRE_CELL);
+  MACE_OP_OUTPUT_TAGS(CELL, OUTPUT);
+};
+
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_LSTMCELL_H_
diff --git a/mace/ops/lstmcell_benchmark.cc b/mace/ops/lstmcell_benchmark.cc
new file mode 100644
index 00000000..a465485e
--- /dev/null
+++ b/mace/ops/lstmcell_benchmark.cc
@@ -0,0 +1,105 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/operator.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +namespace { +template +void LSTMCell(int iters, int batch, int lstm_step) { + mace::testing::StopTiming(); + + OpsTestNet net; + + // Add input data + if (D == DeviceType::GPU) { + net.AddRandomInput("Input", {batch, lstm_step}); + net.AddRandomInput("PreOutput", {batch, lstm_step}); + net.AddRandomInput("Weight", {2 * lstm_step, 4 * lstm_step}); + net.AddRandomInput("Bias", {4 * lstm_step}); + net.AddRandomInput("PreCell", {batch, lstm_step}); + } else { + MACE_NOT_IMPLEMENTED; + } + + if (D == DeviceType::GPU) { + BufferToImage(&net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(&net, "PreOutput", "PreOutputImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(&net, "Weight", "WeightImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(&net, "Bias", "BiasImage", + kernels::BufferType::ARGUMENT); + BufferToImage(&net, "PreCell", "PreCellImage", + kernels::BufferType::IN_OUT_CHANNEL); + + OpDefBuilder("LSTMCell", "LSTMCellTest") + .Input("InputImage") + .Input("PreOutputImage") + .Input("WeightImage") + .Input("BiasImage") + .Input("PreCellImage") + .AddFloatArg("forget_add", 0.0f) + .Output("CellImage") + .Output("OutputImage") + .Finalize(net.NewOperatorDef()); + } else { + MACE_NOT_IMPLEMENTED; + } + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} +} // namespace + +#define MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, TYPE, DEVICE) \ + static void MACE_BM_LSTMCELL_##N##_##LSTM_STEP##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t macc = \ + static_cast(iters) * N * 2 * LSTM_STEP * 4 * LSTM_STEP; \ + const int64_t tot = static_cast(iters) * N * LSTM_STEP; \ + mace::testing::MaccProcessed(macc); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + LSTMCell(iters, N, LSTM_STEP); \ + } \ + MACE_BENCHMARK(MACE_BM_LSTMCELL_##N##_##LSTM_STEP##_##TYPE##_##DEVICE) + +#define MACE_BM_LSTMCELL(N, LSTM_STEP) \ + MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, float, GPU); \ + MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, half, GPU); + +MACE_BM_LSTMCELL(1, 200); +MACE_BM_LSTMCELL(20, 200); +MACE_BM_LSTMCELL(20, 320); +MACE_BM_LSTMCELL(32, 400); +MACE_BM_LSTMCELL(32, 640); +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/lstmcell_test.cc b/mace/ops/lstmcell_test.cc new file mode 100644 index 00000000..41866193 --- /dev/null +++ b/mace/ops/lstmcell_test.cc @@ -0,0 +1,223 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/operator.h" +#include "mace/kernels/eltwise.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class LSTMCellTest : public OpsTestBase {}; + +namespace { + +template +void LSTMCellCPU(OpsTestNet *net, + const std::string &input_name, + const std::string &pre_output_name, + const std::string &weight_name, + const std::string &bias_name, + const std::string &pre_cell_name, + const float &forget_add_name, + const std::string &cell_name, + const std::string &output_name) { + OpDefBuilder("Concat", "Concat") + .Input(input_name) + .Input(pre_output_name) + .AddIntArg("axis", 1) + .Output("ConcatOutput") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("MatMul", "MatMul") + .Input("ConcatOutput") + .Input(weight_name) + .Output("MatMulOutput") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("BiasAdd", "BiasAdd") + .Input("MatMulOutput") + .Input(bias_name) + .Output("BiasOutput") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Split", "FCSplit") + .Input("BiasOutput") + .AddIntArg("axis", 1) + .Output("SplitOutput0") + .Output("SplitOutput1") + .Output("SplitOutput2") + .Output("SplitOutput3") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Activation", "InputSigmoid") + .Input("SplitOutput0") + .AddStringArg("activation", "SIGMOID") + .Output("InputSigmoid") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Activation", "NewInputTanh") + .Input("SplitOutput1") + .AddStringArg("activation", "TANH") + .Output("NewInputTanh") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Eltwise", "RememberMul") + .Input("InputSigmoid") + .Input("NewInputTanh") + .AddIntArg("T", DataTypeToEnum::v()) + .AddIntArg("type", static_cast(kernels::EltwiseType::PROD)) + .Output("RememberMul") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Eltwise", "ForgetAdd") + .Input("SplitOutput2") + .AddFloatArg("value", forget_add_name) + .AddIntArg("T", DataTypeToEnum::v()) + .AddIntArg("type", static_cast(kernels::EltwiseType::SUM)) + .Output("ForgetAdd") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Activation", "ForgetSigmoid") + .Input("ForgetAdd") + .AddStringArg("activation", "SIGMOID") + .Output("ForgetSigmoid") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Eltwise", "ForgetMul") + .Input("ForgetSigmoid") + .Input(pre_cell_name) + .AddIntArg("T", DataTypeToEnum::v()) + .AddIntArg("type", static_cast(kernels::EltwiseType::PROD)) + .Output("ForgetMulPreCell") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Eltwise", "Cell") + .Input("RememberMul") + .Input("ForgetMulPreCell") + .AddIntArg("T", DataTypeToEnum::v()) + .AddIntArg("type", static_cast(kernels::EltwiseType::SUM)) + .Output(cell_name) + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Activation", "CellTanh") + .Input(cell_name) + .AddStringArg("activation", "TANH") + .Output("CellTanh") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Activation", "OutputSigmoid") + .Input("SplitOutput3") + .AddStringArg("activation", "SIGMOID") + .Output("OutputSigmoid") + 
.Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Eltwise", "FinalMul") + .Input("OutputSigmoid") + .Input("CellTanh") + .AddIntArg("T", DataTypeToEnum::v()) + .AddIntArg("type", static_cast(kernels::EltwiseType::PROD)) + .Output(output_name) + .Finalize(net->AddNewOperatorDef()); +} + +template +void TestLSTMCell(const uint32_t &batch, + const uint32_t &lstm_step, + const float &forget_add) { + // Construct graph + OpsTestNet net; + + net.AddRandomInput("Input", {batch, lstm_step}); + net.AddRandomInput("PreOutput", {batch, lstm_step}); + net.AddRandomInput("Weight", {2 * lstm_step, 4 * lstm_step}); + net.AddRandomInput("Bias", {4 * lstm_step}); + net.AddRandomInput("PreCell", {batch, lstm_step}); + + net.CopyData("Input", "InputCPU"); + net.CopyData("PreOutput", "PreOutputCPU"); + net.CopyData("Weight", "WeightCPU"); + net.CopyData("Bias", "BiasCPU"); + net.CopyData("PreCell", "PreCellCPU"); + + // Run on CPU + LSTMCellCPU(&net, "InputCPU", "PreOutputCPU", "WeightCPU", "BiasCPU", + "PreCellCPU", forget_add, "CellCPU", "OutputCPU"); + // Run + net.RunOp(DeviceType::CPU); + + // Run on GPU + BufferToImage(&net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(&net, "PreOutput", "PreOutputImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(&net, "Weight", "WeightImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(&net, "Bias", "BiasImage", + kernels::BufferType::ARGUMENT); + BufferToImage(&net, "PreCell", "PreCellImage", + kernels::BufferType::IN_OUT_CHANNEL); + + OpDefBuilder("LSTMCell", "LSTMCellTest") + .Input("InputImage") + .Input("PreOutputImage") + .Input("WeightImage") + .Input("BiasImage") + .Input("PreCellImage") + .AddFloatArg("forget_add", forget_add) + .Output("CellImage") + .Output("OutputImage") + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(D); + + ImageToBuffer(&net, "OutputImage", "Output", + kernels::BufferType::IN_OUT_CHANNEL); + ImageToBuffer(&net, "CellImage", "Cell", + kernels::BufferType::IN_OUT_CHANNEL); + + + Tensor expected_cell, expected_output; + expected_cell.Copy(*net.GetOutput("CellCPU")); + expected_output.Copy(*net.GetOutput("OutputCPU")); + + if (DataTypeToEnum::value == DT_HALF) { + ExpectTensorNear(expected_cell, *net.GetOutput("Cell"), 1e-3); + ExpectTensorNear(expected_output, *net.GetOutput("Output"), 1e-3); + } else { + ExpectTensorNear(expected_cell, *net.GetOutput("Cell"), 1e-5); + ExpectTensorNear(expected_output, *net.GetOutput("Output"), 1e-5); + } +} +} // namespace + +TEST_F(LSTMCellTest, OPENCLRandomHalf) { + TestLSTMCell(1, 4, 0.0f); + TestLSTMCell(2, 16, 0.0f); + TestLSTMCell(2, 200, 0.5f); + TestLSTMCell(20, 320, 0.5f); +} + +TEST_F(LSTMCellTest, OPENCLRandomFloat) { + TestLSTMCell(1, 4, 0.0f); + TestLSTMCell(2, 16, 0.0f); + TestLSTMCell(2, 200, 0.5f); + TestLSTMCell(20, 320, 0.5f); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/ops_register.cc b/mace/ops/ops_register.cc index c318eb44..c8d85822 100644 --- a/mace/ops/ops_register.cc +++ b/mace/ops/ops_register.cc @@ -53,6 +53,7 @@ extern void Register_Shape(OperatorRegistryBase *op_registry); extern void Register_Split(OperatorRegistryBase *op_registry); extern void Register_Softmax(OperatorRegistryBase *op_registry); extern void Register_Stack(OperatorRegistryBase *op_registry); +extern void Register_Unstack(OperatorRegistryBase *op_registry); extern void Register_StridedSlice(OperatorRegistryBase *op_registry); extern void Register_SpaceToBatchND(OperatorRegistryBase 
*op_registry); extern void Register_SpaceToDepth(OperatorRegistryBase *op_registry); @@ -64,6 +65,7 @@ extern void Register_WinogradTransform(OperatorRegistryBase *op_registry); #ifdef MACE_ENABLE_OPENCL extern void Register_BufferToImage(OperatorRegistryBase *op_registry); extern void Register_ImageToBuffer(OperatorRegistryBase *op_registry); +extern void Register_LSTMCell(OperatorRegistryBase *op_registry); #endif // MACE_ENABLE_OPENCL } // namespace ops @@ -105,6 +107,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() { ops::Register_Split(this); ops::Register_Softmax(this); ops::Register_Stack(this); + ops::Register_Unstack(this); ops::Register_StridedSlice(this); ops::Register_SpaceToBatchND(this); ops::Register_SpaceToDepth(this); @@ -116,6 +119,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() { #ifdef MACE_ENABLE_OPENCL ops::Register_BufferToImage(this); ops::Register_ImageToBuffer(this); + ops::Register_LSTMCell(this); #endif // MACE_ENABLE_OPENCL } diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index 568cbb8f..2dc29241 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -201,6 +201,21 @@ class OpsTestNet { } } + template + void CopyData(const std::string &src_name, + const std::string &dst_name) { + Tensor *input = ws_.GetTensor(src_name); + Tensor *output = ws_.CreateTensor(dst_name, GetDeviceAllocator(D), + DataTypeToEnum::v()); + + const std::vector input_shape = input->shape(); + output->Resize(input_shape); + + Tensor::MappingGuard input_guard(input); + const T *input_data = input->data(); + output->CopyBytes(input->raw_data(), input->size() * input->SizeOfType()); + } + template void TransformDataFormat(const std::string &src_name, const DataFormat src_format, diff --git a/mace/ops/unstack.cc b/mace/ops/unstack.cc new file mode 100644 index 00000000..7b1c815b --- /dev/null +++ b/mace/ops/unstack.cc @@ -0,0 +1,34 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/unstack.h" + +namespace mace { +namespace ops { + +void Register_Unstack(OperatorRegistryBase *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Unstack") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + UnstackOp); + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Unstack") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + UnstackOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/unstack.h b/mace/ops/unstack.h new file mode 100644 index 00000000..1f743bd5 --- /dev/null +++ b/mace/ops/unstack.h @@ -0,0 +1,49 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_OPS_UNSTACK_H_
+#define MACE_OPS_UNSTACK_H_
+
+#include <vector>
+
+#include "mace/core/operator.h"
+#include "mace/kernels/unstack.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, typename T>
+class UnstackOp : public Operator<D, T> {
+ public:
+  UnstackOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator<D, T>(operator_def, ws),
+        functor_(OperatorBase::GetOptionalArg<int>("axis", 0)) {}
+
+  MaceStatus Run(StatsFuture *future) override {
+    const Tensor *input = this->Input(INPUT);
+    const std::vector<Tensor *> outputs = this->Outputs();
+    return functor_(input, outputs, future);
+  }
+
+ private:
+  kernels::UnstackFunctor<D, T> functor_;
+
+ protected:
+  MACE_OP_INPUT_TAGS(INPUT);
+};
+
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_UNSTACK_H_
diff --git a/mace/ops/unstack_test.cc b/mace/ops/unstack_test.cc
new file mode 100644
index 00000000..674ec0ae
--- /dev/null
+++ b/mace/ops/unstack_test.cc
@@ -0,0 +1,77 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class UnstackOpTest : public OpsTestBase {}; + +namespace { + +void TestUnstack(const std::vector &input_shape, + const std::vector &input, + int axis, + const std::vector &output_shape, + const std::vector> &outputs) { + OpsTestNet net; + net.AddInputFromArray("Input", input_shape, input); + + auto op_builder = OpDefBuilder("Unstack", "UnstackOpTest") + .Input("Input") + .AddIntArg("axis", axis); + + for (size_t i = 0; i < outputs.size(); ++i) { + op_builder.Output(MakeString("Output", i)); + } + op_builder.Finalize(net.NewOperatorDef()); + + net.RunOp(); + + for (size_t i = 0; i < outputs.size(); ++i) { + LOG(INFO) << MakeString("Output", i); + net.AddInputFromArray("ExpectedOutput", output_shape, + outputs[i]); + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput(MakeString("Output", i).c_str())); + } +} + +} // namespace + +TEST_F(UnstackOpTest, TestUnstackScalar) { + TestUnstack({3}, {1, 2, 3}, 0, {}, {{1}, {2}, {3}}); +} + +TEST_F(UnstackOpTest, TestUnstackVector) { + TestUnstack({3, 2}, {1, 4, 2, 5, 3, 6}, 0, {2}, {{1, 4}, {2, 5}, {3, 6}}); + TestUnstack({3, 2}, {1, 4, 2, 5, 3, 6}, -2, {2}, {{1, 4}, {2, 5}, {3, 6}}); + TestUnstack({2, 3}, {1, 2, 3, 4, 5, 6}, 1, {2}, {{1, 4}, {2, 5}, {3, 6}}); +} + +TEST_F(UnstackOpTest, TestUnstackHighRank) { + TestUnstack({2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, -3, {2, 3}, + {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}}); + TestUnstack({2, 2, 3}, {1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}, 1, {2, 3}, + {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}}); + TestUnstack({2, 3, 2}, {1, 7, 2, 8, 3, 9, 4, 10, 5, 11, 6, 12}, 2, {2, 3}, + {{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/python/tools/convert_util.py b/mace/python/tools/convert_util.py index 31590913..97cf7f98 100644 --- a/mace/python/tools/convert_util.py +++ b/mace/python/tools/convert_util.py @@ -44,18 +44,26 @@ def calculate_image_shape(buffer_type, shape, winograd_blk_size=0): image_shape[0] = shape[1] image_shape[1] = shape[2] * shape[3] * roundup_div4(shape[0]) elif buffer_type == OpenCLBufferType.IN_OUT_CHANNEL: - mace_check(len(shape) == 4, "Conv2D input/output buffer should be 4D") - image_shape[0] = roundup_div4(shape[3]) * shape[2] - image_shape[1] = shape[0] * shape[1] + mace_check(len(shape) == 2 or len(shape) == 4, + "input/output buffer should be 2D|4D") + if len(shape) == 4: + image_shape[0] = roundup_div4(shape[3]) * shape[2] + image_shape[1] = shape[0] * shape[1] + elif len(shape) == 2: + image_shape[0] = roundup_div4(shape[1]) + image_shape[1] = shape[0] elif buffer_type == OpenCLBufferType.ARGUMENT: mace_check(len(shape) == 1, "Argument buffer should be 1D not " + str(shape)) image_shape[0] = roundup_div4(shape[0]) image_shape[1] = 1 elif buffer_type == OpenCLBufferType.IN_OUT_HEIGHT: - mace_check(len(shape) == 4, "Input/output buffer should be 4D") - image_shape[0] = shape[2] * shape[3] - image_shape[1] = shape[0] * roundup_div4(shape[1]) + if len(shape) == 4: + image_shape[0] = shape[2] * shape[3] + image_shape[1] = shape[0] * roundup_div4(shape[1]) + elif len(shape) == 2: + image_shape[0] = shape[0] + image_shape[1] = roundup_div4(shape[1]) elif buffer_type == OpenCLBufferType.IN_OUT_WIDTH: mace_check(len(shape) == 4, "Input/output buffer should be 4D") image_shape[0] = roundup_div4(shape[2]) * shape[3] diff --git 
a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 99fac06f..8ceb3ff1 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -93,6 +93,7 @@ MaceSupportedOps = [ 'Gather', 'Identity', 'LocalResponseNorm', + 'LSTMCell', 'MatMul', 'Pad', 'Pooling', @@ -107,6 +108,7 @@ MaceSupportedOps = [ 'Shape', 'Squeeze', 'Stack', + 'Unstack', 'StridedSlice', 'Softmax', 'SpaceToBatchND', @@ -198,6 +200,9 @@ class TransformerRule(Enum): QUANTIZE_NODES = 23 ADD_QUANTIZE_TENSOR_RANGE = 24 QUANTIZE_WEIGHTS = 25 + TRANSPOSE_MATMUL_WEIGHT = 26 + TRANSFORM_LSTMCELL_ZEROSTATE = 27 + TRANSFORM_BASIC_LSTMCELL = 28 class ConverterInterface(object): @@ -336,6 +341,8 @@ class ConverterOption(object): # Model structure related transformation TransformerRule.REMOVE_IDENTITY_OP, TransformerRule.TRANSFORM_GLOBAL_POOLING, + TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE, + TransformerRule.TRANSFORM_BASIC_LSTMCELL, TransformerRule.FOLD_RESHAPE, TransformerRule.TRANSFORM_MATMUL_TO_FC, TransformerRule.FOLD_BATCHNORM, diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 732abdce..17957e75 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -96,6 +96,8 @@ TFSupportedOps = [ 'Slice', 'Stack', 'Pack', + 'Unstack', + 'Unpack', 'Cast', 'ArgMax', 'Split', @@ -196,6 +198,8 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.Slice.name: self.convert_slice, TFOpType.Pack.name: self.convert_stack, TFOpType.Stack.name: self.convert_stack, + TFOpType.Unpack.name: self.convert_unstack, + TFOpType.Unstack.name: self.convert_unstack, TFOpType.Cast.name: self.convert_cast, TFOpType.ArgMax.name: self.convert_argmax, TFOpType.Split.name: self.convert_split, @@ -774,6 +778,17 @@ class TensorflowConverter(base_converter.ConverterInterface): except ValueError: axis_arg.i = 0 + def convert_unstack(self, tf_op): + op = self.convert_general_op(tf_op) + op.type = MaceOp.Unstack.name + + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + try: + axis_arg.i = tf_op.get_attr(MaceKeyword.mace_axis_str) + except ValueError: + axis_arg.i = 0 + def convert_cast(self, tf_op): op = self.convert_general_op(tf_op) op.type = MaceOp.Cast.name diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index dbb9d605..2fc019ac 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -15,6 +15,7 @@ import enum import numpy as np +import re from mace.proto import mace_pb2 from mace.python.tools.converter_tool import base_converter @@ -49,6 +50,10 @@ class Transformer(base_converter.ConverterInterface): TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op, TransformerRule.TRANSFORM_GLOBAL_POOLING: self.transform_global_pooling, + TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE: + self.transform_lstmcell_zerostate, + TransformerRule.TRANSFORM_BASIC_LSTMCELL: + self.transform_basic_lstmcell, TransformerRule.FOLD_RESHAPE: self.fold_reshape, TransformerRule.TRANSFORM_MATMUL_TO_FC: self.transform_matmul_to_fc, @@ -332,6 +337,154 @@ class Transformer(base_converter.ConverterInterface): return False + def transform_lstmcell_zerostate(self): + net = self._model + + zero_state_pattern = \ + 
re.compile(r'^.*BasicLSTMCellZeroState_?[0-9]*/[a-zA-Z]+_?[0-9]*') # noqa + for op in net.op: + if op.type == MaceOp.Fill.name and \ + zero_state_pattern.match(op.name): + print("Transform lstm zerostate") + concat_op = self._producer[op.input[0]] + consumer_op = self._consumers[op.output[0]][0] + + dims = [self._consts[concat_op.input[0]].int32_data[0], + self._consts[concat_op.input[1]].int32_data[0]] + tensor_def = net.tensors.add() + tensor_def.name = op.output[0].replace('/zeros', '/init_const') + tensor_def.dims.extend(dims) + tensor_def.data_type = self._consts[op.input[1]].data_type + tensor_def.float_data.extend( + [self._consts[op.input[1]].float_data[0]] * + (dims[0] * dims[1])) + + for i in range(len(consumer_op.input)): + if zero_state_pattern.match(consumer_op.input[i][:-2]): + consumer_op.input[i] = tensor_def.name + + net.tensors.remove(self._consts[op.input[1]]) + net.tensors.remove(self._consts[concat_op.input[0]]) + net.tensors.remove(self._consts[concat_op.input[1]]) + + net.op.remove(concat_op) + net.op.remove(op) + + return True + + def transform_basic_lstmcell(self): + if self._option.device != DeviceType.GPU.value: + return False + + net = self._model + basic_lstm_concat_pattern = \ + re.compile(r'^.*basic_lstm_cell_?[0-9]*/concat_?[0-9]*') + for op in net.op: + if op.type == MaceOp.Concat.name and \ + basic_lstm_concat_pattern.match(op.name): + print("Transform basic lstmcell") + ops_to_delete = [] + ops_to_delete.extend([op]) + + op_def = net.op.add() + op_def.name = op.name.replace('/concat', '/folded_lstmcell') + op_def.type = MaceOp.LSTMCell.name + op_def.arg.extend(op.arg[:-1]) + + # Concat pre output and cur input + # extend concat inputs + op_def.input.extend([op_input for op_input in op.input]) + + # lstm MatMul in FC of [pre_output, cur_input] + matmul_op = self._consumers[op.output[0]][0] + ops_to_delete.extend([matmul_op]) + # extend MatMul weight input + op_def.input.extend([matmul_op.input[1]]) + + # lstm BiasAdd in FC of [pre_output, cur_input] + biasadd_op = self._consumers[matmul_op.output[0]][0] + ops_to_delete.extend([biasadd_op]) + # extend BiasAdd bias input + op_def.input.extend([biasadd_op.input[1]]) + + # Split FC output into i, j, f, o + # i = input_gate, j = new_input, f = forget_gate, o = output_gate # noqa + split_op = self._consumers[biasadd_op.output[0]][0] + ops_to_delete.extend([split_op]) + + # input gate activation + input_gate_op = self._consumers[split_op.output[0]][0] + ops_to_delete.extend([input_gate_op]) + # new input gate + new_input_tanh_op = self._consumers[split_op.output[1]][0] + ops_to_delete.extend([new_input_tanh_op]) + # forget gate add + forget_add_op = self._consumers[split_op.output[2]][0] + ops_to_delete.extend([forget_add_op]) + # output gate activation + output_gate_op = self._consumers[split_op.output[3]][0] + ops_to_delete.extend([output_gate_op]) + + # extend forget add + mace_check(len(forget_add_op.input) == 1, + 'Wrong LSTM format in forget gate inputs') + for arg in forget_add_op.arg: + if arg.name == MaceKeyword.mace_scalar_input_str: + op_def.arg.extend([arg]) + + # state remember + remember_mul_op = self._consumers[input_gate_op.output[0]][0] + ops_to_delete.extend([remember_mul_op]) + mace_check(remember_mul_op.name == self._consumers[ + new_input_tanh_op.output[0]][0].name, + 'Wrong LSTM format in input sig & input tanh mul') + + # forget gate activation + forget_gate_op = self._consumers[forget_add_op.output[0]][0] + ops_to_delete.extend([forget_gate_op]) + + # Mul `forget` & `pre cell state` + 
forget_mul_op = self._consumers[forget_gate_op.output[0]][0] + ops_to_delete.extend([forget_mul_op]) + + # extend pre cell state input + op_def.input.extend([forget_mul_op.input[0]]) + + # get cur cell state + # Add `forget gate output` & `remember mul output` + remember_forget_add_op = \ + self._consumers[remember_mul_op.output[0]][0] + ops_to_delete.extend([remember_forget_add_op]) + mace_check(remember_forget_add_op.name == + self._consumers[forget_mul_op.output[0]][0].name, + 'Wrong LSTM format in add forget gate & remember mul') # noqa + op_def.output.extend([remember_forget_add_op.output[0]]) + op_def.output_shape.extend(remember_forget_add_op.output_shape) + + # cell state output tanh + for consumer in \ + self._consumers[remember_forget_add_op.output[0]]: + if consumer.type == MaceOp.Activation.name and \ + consumer.name.find('basic_lstm_cell') > 0: + cell_tanh_op = consumer + ops_to_delete.extend([cell_tanh_op]) + + # final mul, get output + final_mul_op = self._consumers[cell_tanh_op.output[0]][0] + ops_to_delete.extend([final_mul_op]) + mace_check(final_mul_op.name == + self._consumers[output_gate_op.output[0]][0].name, + 'Wrong LSTM format in final mul') + op_def.output.extend([final_mul_op.output[0]]) + op_def.output_shape.extend(final_mul_op.output_shape) + + for op_to_del in ops_to_delete: + net.op.remove(op_to_del) + + return True + + return False + def fold_conv_and_bn(self): net = self._model for op in net.op: @@ -1156,6 +1309,15 @@ class Transformer(base_converter.ConverterInterface): if ConverterUtil.get_arg(op, MaceKeyword.mace_activation_type_str).s == ActivationType.PRELU.name: # noqa self.buffer_to_image(op, 1, OpenCLBufferType.ARGUMENT) + elif op.type == MaceOp.LSTMCell.name: + if op.input[1] in self._consts: + self.buffer_to_image(op, 1, + OpenCLBufferType.IN_OUT_CHANNEL) + self.buffer_to_image(op, 2, OpenCLBufferType.IN_OUT_CHANNEL) + self.buffer_to_image(op, 3, OpenCLBufferType.ARGUMENT) + if op.input[4] in self._consts: + self.buffer_to_image(op, 4, + OpenCLBufferType.IN_OUT_CHANNEL) # Add OpenCL max image size arg = net.arg.add() diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py index c0f1ddd0..4864f066 100644 --- a/mace/python/tools/memory_optimizer.py +++ b/mace/python/tools/memory_optimizer.py @@ -240,7 +240,7 @@ class GPUMemoryOptimizer(MemoryOptimizer): op_type) else: if len(output_shape) == 2: # only support fc/softmax - buffer_shape = [output_shape[0], 1, 1, output_shape[1]] + buffer_shape = [output_shape[0], output_shape[1]] elif len(output_shape) == 4: buffer_shape = output_shape else: diff --git a/repository/opencl-kernel/opencl_kernel_configure.bzl b/repository/opencl-kernel/opencl_kernel_configure.bzl index 0d1b9cf0..dfb5da15 100644 --- a/repository/opencl-kernel/opencl_kernel_configure.bzl +++ b/repository/opencl-kernel/opencl_kernel_configure.bzl @@ -37,6 +37,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx): unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/depthwise_conv2d.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/eltwise.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/fully_connected.cl")) + unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/lstmcell.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/matmul.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pad.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pooling.cl")) -- GitLab