提交 63135b06 作者: yejianwu

fix typo, clean redundant enum in transformer

上级 6439dce1
...@@ -31,8 +31,12 @@ MaceStatus LSTMCellFunctor<DeviceType::GPU, T>::operator()( ...@@ -31,8 +31,12 @@ MaceStatus LSTMCellFunctor<DeviceType::GPU, T>::operator()(
Tensor *cell, Tensor *cell,
Tensor *output, Tensor *output,
StatsFuture *future) { StatsFuture *future) {
MACE_CHECK(input->dim_size() == 2 && input->dim(1) % 4 == 0,
"LSTM step should be a multiple of 4");
const index_t height = input->dim(0); const index_t height = input->dim(0);
const index_t width = input->dim(1); const index_t width = input->dim(1);
const index_t width_blocks = width / 4;
auto runtime = OpenCLRuntime::Global(); auto runtime = OpenCLRuntime::Global();
...@@ -53,14 +57,13 @@ MaceStatus LSTMCellFunctor<DeviceType::GPU, T>::operator()( ...@@ -53,14 +57,13 @@ MaceStatus LSTMCellFunctor<DeviceType::GPU, T>::operator()(
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_)); static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
} }
const index_t width_blocks = RoundUpDiv4(width);
const uint32_t gws[2] = {static_cast<uint32_t>(width_blocks), const uint32_t gws[2] = {static_cast<uint32_t>(width_blocks),
static_cast<uint32_t>(height)}; static_cast<uint32_t>(height)};
if (!IsVecEqual(input_shape_, input->shape())) { if (!IsVecEqual(input_shape_, input->shape())) {
std::vector<index_t> output_shape_paded = {height, 1, 1, width}; std::vector<index_t> output_shape_padded = {height, 1, 1, width};
std::vector<size_t> output_image_shape; std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape_paded, BufferType::IN_OUT_CHANNEL, CalImage2DShape(output_shape_padded, BufferType::IN_OUT_CHANNEL,
&output_image_shape); &output_image_shape);
MACE_RETURN_IF_ERROR(output->ResizeImage(input->shape(), MACE_RETURN_IF_ERROR(output->ResizeImage(input->shape(),
output_image_shape)); output_image_shape));
......
...@@ -40,9 +40,6 @@ class LSTMCellOp : public Operator<D, T> { ...@@ -40,9 +40,6 @@ class LSTMCellOp : public Operator<D, T> {
Tensor *cell = this->Output(CELL); Tensor *cell = this->Output(CELL);
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 2 && input->dim(1) % 4 == 0,
"LSTM step should be a multiple of 4");
return functor_( return functor_(
input, pre_output, weight, bias, pre_cell, cell, output, future); input, pre_output, weight, bias, pre_cell, cell, output, future);
}; };
......
...@@ -200,9 +200,8 @@ class TransformerRule(Enum): ...@@ -200,9 +200,8 @@ class TransformerRule(Enum):
QUANTIZE_NODES = 23 QUANTIZE_NODES = 23
ADD_QUANTIZE_TENSOR_RANGE = 24 ADD_QUANTIZE_TENSOR_RANGE = 24
QUANTIZE_WEIGHTS = 25 QUANTIZE_WEIGHTS = 25
TRANSPOSE_MATMUL_WEIGHT = 26 TRANSFORM_LSTMCELL_ZEROSTATE = 26
TRANSFORM_LSTMCELL_ZEROSTATE = 27 TRANSFORM_BASIC_LSTMCELL = 27
TRANSFORM_BASIC_LSTMCELL = 28
class ConverterInterface(object): class ConverterInterface(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册