diff --git a/mace/kernels/opencl/lstmcell.cc b/mace/kernels/opencl/lstmcell.cc index 4cfc98b2824ec726ee686bb7335e7dcd47baa98c..ffc185d0dc84b2019c473827e8d02edc141e1482 100644 --- a/mace/kernels/opencl/lstmcell.cc +++ b/mace/kernels/opencl/lstmcell.cc @@ -31,8 +31,12 @@ MaceStatus LSTMCellFunctor::operator()( Tensor *cell, Tensor *output, StatsFuture *future) { + MACE_CHECK(input->dim_size() == 2 && input->dim(1) % 4 == 0, + "LSTM step should be a multiple of 4"); + const index_t height = input->dim(0); const index_t width = input->dim(1); + const index_t width_blocks = width / 4; auto runtime = OpenCLRuntime::Global(); @@ -53,14 +57,13 @@ MaceStatus LSTMCellFunctor::operator()( static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); } - const index_t width_blocks = RoundUpDiv4(width); const uint32_t gws[2] = {static_cast(width_blocks), static_cast(height)}; if (!IsVecEqual(input_shape_, input->shape())) { - std::vector output_shape_paded = {height, 1, 1, width}; + std::vector output_shape_padded = {height, 1, 1, width}; std::vector output_image_shape; - CalImage2DShape(output_shape_paded, BufferType::IN_OUT_CHANNEL, + CalImage2DShape(output_shape_padded, BufferType::IN_OUT_CHANNEL, &output_image_shape); MACE_RETURN_IF_ERROR(output->ResizeImage(input->shape(), output_image_shape)); diff --git a/mace/ops/lstmcell.h b/mace/ops/lstmcell.h index c632cd3dee39c20b9e1cc4420f679c01bb1c0e63..300794f2341261a0ea13d1be0dffc48a3a6e1a78 100644 --- a/mace/ops/lstmcell.h +++ b/mace/ops/lstmcell.h @@ -40,9 +40,6 @@ class LSTMCellOp : public Operator { Tensor *cell = this->Output(CELL); Tensor *output = this->Output(OUTPUT); - MACE_CHECK(input->dim_size() == 2 && input->dim(1) % 4 == 0, - "LSTM step should be a multiple of 4"); - return functor_( input, pre_output, weight, bias, pre_cell, cell, output, future); }; diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 8ceb3ff12739b0294a650df0e1e53494a24eb1c6..91dfe6c6ca045df20ea13bb103258529f21a7956 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -200,9 +200,8 @@ class TransformerRule(Enum): QUANTIZE_NODES = 23 ADD_QUANTIZE_TENSOR_RANGE = 24 QUANTIZE_WEIGHTS = 25 - TRANSPOSE_MATMUL_WEIGHT = 26 - TRANSFORM_LSTMCELL_ZEROSTATE = 27 - TRANSFORM_BASIC_LSTMCELL = 28 + TRANSFORM_LSTMCELL_ZEROSTATE = 26 + TRANSFORM_BASIC_LSTMCELL = 27 class ConverterInterface(object):