提交 63135b06 编写于 作者: Y yejianwu

fix typo, clean redundant enum in transformer

上级 6439dce1
......@@ -31,8 +31,12 @@ MaceStatus LSTMCellFunctor<DeviceType::GPU, T>::operator()(
Tensor *cell,
Tensor *output,
StatsFuture *future) {
MACE_CHECK(input->dim_size() == 2 && input->dim(1) % 4 == 0,
"LSTM step should be a multiple of 4");
const index_t height = input->dim(0);
const index_t width = input->dim(1);
const index_t width_blocks = width / 4;
auto runtime = OpenCLRuntime::Global();
......@@ -53,14 +57,13 @@ MaceStatus LSTMCellFunctor<DeviceType::GPU, T>::operator()(
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
const index_t width_blocks = RoundUpDiv4(width);
const uint32_t gws[2] = {static_cast<uint32_t>(width_blocks),
static_cast<uint32_t>(height)};
if (!IsVecEqual(input_shape_, input->shape())) {
std::vector<index_t> output_shape_paded = {height, 1, 1, width};
std::vector<index_t> output_shape_padded = {height, 1, 1, width};
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape_paded, BufferType::IN_OUT_CHANNEL,
CalImage2DShape(output_shape_padded, BufferType::IN_OUT_CHANNEL,
&output_image_shape);
MACE_RETURN_IF_ERROR(output->ResizeImage(input->shape(),
output_image_shape));
......
......@@ -40,9 +40,6 @@ class LSTMCellOp : public Operator<D, T> {
Tensor *cell = this->Output(CELL);
Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 2 && input->dim(1) % 4 == 0,
"LSTM step should be a multiple of 4");
return functor_(
input, pre_output, weight, bias, pre_cell, cell, output, future);
};
......
......@@ -200,9 +200,8 @@ class TransformerRule(Enum):
QUANTIZE_NODES = 23
ADD_QUANTIZE_TENSOR_RANGE = 24
QUANTIZE_WEIGHTS = 25
TRANSPOSE_MATMUL_WEIGHT = 26
TRANSFORM_LSTMCELL_ZEROSTATE = 27
TRANSFORM_BASIC_LSTMCELL = 28
TRANSFORM_LSTMCELL_ZEROSTATE = 26
TRANSFORM_BASIC_LSTMCELL = 27
class ConverterInterface(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册