diff --git a/mace/kernels/opencl/cl/lstmcell.cl b/mace/kernels/opencl/cl/lstmcell.cl index 140132bd928762a6f2554fecf5832fb2dbaaf5be..c020eb56c7c296cc1a8e0a62cb098924f6223cff 100644 --- a/mace/kernels/opencl/cl/lstmcell.cl +++ b/mace/kernels/opencl/cl/lstmcell.cl @@ -9,6 +9,8 @@ __kernel void lstmcell(KERNEL_ERROR_PARAMS __read_only image2d_t pre_cell, __private const float forget_bias, __private const int width, + __private const int hidden_units, + __private const int in_w_blk, __write_only image2d_t cell, __write_only image2d_t output) { const int w_blk_idx = get_global_id(0); @@ -25,114 +27,97 @@ __kernel void lstmcell(KERNEL_ERROR_PARAMS DATA_TYPE4 fc_res0 = 0.0, fc_res1 = 0.0, fc_res2 = 0.0, fc_res3 = 0.0; DATA_TYPE4 in, pre_h; DATA_TYPE4 w0, w1, w2, w3; + int k_offset; // concat matmul - for (short i = 0; i < global_size_dim0; ++i) { + for (short i = 0; i < in_w_blk; ++i) { in = READ_IMAGET(input, SAMPLER, (int2)(i, h_idx)); - short k = 4 * i; + int k = i << 2; w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); - fc_res0 += in.x * w0; fc_res1 += in.x * w1; fc_res2 += in.x * w2; fc_res3 += in.x * w3; - k = 4 * i + 1; - if (k < width) { - w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); - w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); - w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); - w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); - - fc_res0 += in.y * w0; - fc_res1 += in.y * w1; - fc_res2 += in.y * w2; - fc_res3 += in.y * w3; - } - - k = 4 * i + 2; - if (k < width) { - w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); - w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); - w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); - w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); - - fc_res0 += in.z * w0; - fc_res1 += in.z * w1; - fc_res2 += in.z * w2; - fc_res3 += in.z * w3; - } - - k = 4 * i + 3; - if (k < width) { - w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); - w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); - w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); - w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); - - fc_res0 += in.w * w0; - fc_res1 += in.w * w1; - fc_res2 += in.w * w2; - fc_res3 += in.w * w3; - } + k += 1; + k_offset = select(-1, k, k < width); + w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k_offset)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k_offset)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k_offset)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k_offset)); + fc_res0 += in.y * w0; + fc_res1 += in.y * w1; + fc_res2 += in.y * w2; + fc_res3 += in.y * w3; + + k += 1; + k_offset = select(-1, k, k < width); + w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k_offset)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k_offset)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k_offset)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k_offset)); + fc_res0 += in.z * w0; + fc_res1 += in.z * w1; + fc_res2 += in.z * w2; + fc_res3 += in.z * w3; + + k += 1; + k_offset = select(-1, k, k < width); + w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k_offset)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k_offset)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k_offset)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k_offset)); + fc_res0 += in.w * w0; + fc_res1 += in.w * w1; + fc_res2 += in.w * w2; + fc_res3 += in.w * w3; } for (short i = 0; i < global_size_dim0; ++i) { pre_h = READ_IMAGET(pre_output, SAMPLER, (int2)(i, h_idx)); - short k = 4 * (i + global_size_dim0); - short k_limit = 4 * global_size_dim0 + width; + int k = (i << 2) + width; w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); - fc_res0 += pre_h.x * w0; fc_res1 += pre_h.x * w1; fc_res2 += pre_h.x * w2; fc_res3 += pre_h.x * w3; - k = 4 * (i + global_size_dim0) + 1; - if (k < k_limit) { - w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); - w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); - w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); - w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); - - fc_res0 += pre_h.y * w0; - fc_res1 += pre_h.y * w1; - fc_res2 += pre_h.y * w2; - fc_res3 += pre_h.y * w3; - } - - k = 4 * (i + global_size_dim0) + 2; - if (k < k_limit) { - w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); - w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); - w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); - w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); - - fc_res0 += pre_h.z * w0; - fc_res1 += pre_h.z * w1; - fc_res2 += pre_h.z * w2; - fc_res3 += pre_h.z * w3; - } + k += 1; + w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); + fc_res0 += pre_h.y * w0; + fc_res1 += pre_h.y * w1; + fc_res2 += pre_h.y * w2; + fc_res3 += pre_h.y * w3; - k = 4 * (i + global_size_dim0) + 3; - if (k < k_limit) { - w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); - w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); - w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); - w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); + k += 1; + w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); + fc_res0 += pre_h.z * w0; + fc_res1 += pre_h.z * w1; + fc_res2 += pre_h.z * w2; + fc_res3 += pre_h.z * w3; - fc_res0 += pre_h.w * w0; - fc_res1 += pre_h.w * w1; - fc_res2 += pre_h.w * w2; - fc_res3 += pre_h.w * w3; - } + k += 1; + w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k)); + w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k)); + w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k)); + w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k)); + fc_res0 += pre_h.w * w0; + fc_res1 += pre_h.w * w1; + fc_res2 += pre_h.w * w2; + fc_res3 += pre_h.w * w3; } // bias diff --git a/mace/kernels/opencl/lstmcell.cc b/mace/kernels/opencl/lstmcell.cc index 6704c0b457d876c28a590860e6ad866ff24228ad..df351539834d6e8ea1849d6b512c7cd63e9297e9 100644 --- a/mace/kernels/opencl/lstmcell.cc +++ b/mace/kernels/opencl/lstmcell.cc @@ -31,12 +31,13 @@ MaceStatus LSTMCellFunctor::operator()( Tensor *cell, Tensor *output, StatsFuture *future) { - MACE_CHECK(input->dim_size() == 2 && input->dim(1) % 4 == 0, - "LSTM step should be a multiple of 4"); + MACE_CHECK(pre_output->dim_size() == 2 && pre_output->dim(1) % 4 == 0, + "LSTM hidden units should be a multiple of 4"); const index_t height = input->dim(0); const index_t width = input->dim(1); - const index_t width_blocks = width / 4; + const index_t hidden_units = pre_output->dim(1); + const index_t w_blocks = hidden_units >> 2; auto runtime = context_->device()->opencl_runtime(); @@ -57,17 +58,18 @@ MaceStatus LSTMCellFunctor::operator()( static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); } - const uint32_t gws[2] = {static_cast(width_blocks), + const uint32_t gws[2] = {static_cast(w_blocks), static_cast(height)}; if (!IsVecEqual(input_shape_, input->shape())) { - std::vector output_shape_padded = {height, 1, 1, width}; + std::vector output_shape_padded = {height, 1, 1, hidden_units}; std::vector output_image_shape; CalImage2DShape(output_shape_padded, BufferType::IN_OUT_CHANNEL, &output_image_shape); - MACE_RETURN_IF_ERROR(output->ResizeImage(input->shape(), + MACE_RETURN_IF_ERROR(output->ResizeImage(pre_output->shape(), output_image_shape)); - MACE_RETURN_IF_ERROR(cell->ResizeImage(input->shape(), output_image_shape)); + MACE_RETURN_IF_ERROR(cell->ResizeImage(pre_cell->shape(), + output_image_shape)); uint32_t idx = 0; OUT_OF_RANGE_SET_ARG; @@ -79,6 +81,8 @@ MaceStatus LSTMCellFunctor::operator()( kernel_.setArg(idx++, *(pre_cell->opencl_image())); kernel_.setArg(idx++, static_cast(forget_bias_)); kernel_.setArg(idx++, static_cast(width)); + kernel_.setArg(idx++, static_cast(hidden_units)); + kernel_.setArg(idx++, static_cast(RoundUpDiv4(width))); kernel_.setArg(idx++, *(cell->opencl_image())); kernel_.setArg(idx++, *(output->opencl_image())); diff --git a/mace/ops/BUILD b/mace/ops/BUILD index 312bdc90babe7d04574a6823455893085771212e..342dbc0d4c6b8ce6728b3d276e61bdec00f6a134 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -20,9 +20,9 @@ load( cc_library( name = "test", testonly = 1, - hdrs = [ - "ops_test_util.h", - ], + hdrs = glob([ + "*_test_util.h", + ]), srcs = [ "ops_test_util.cc", ], @@ -67,7 +67,7 @@ cc_library( ), hdrs = glob( ["*.h"], - exclude = ["ops_test_util.h"], + exclude = glob(["*_test_util.h"]), ), copts = [ "-Werror", diff --git a/mace/ops/lstmcell_benchmark.cc b/mace/ops/lstmcell_benchmark.cc index a465485e8d4f09a1e9895e7e401adea35c20c9cf..6ab6baa1b1b0aafdbe8540d435f305f01d433971 100644 --- a/mace/ops/lstmcell_benchmark.cc +++ b/mace/ops/lstmcell_benchmark.cc @@ -15,6 +15,7 @@ #include "mace/core/operator.h" #include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/testing/test_benchmark.h" +#include "mace/ops/lstmcell_test_util.h" #include "mace/ops/ops_test_util.h" namespace mace { @@ -23,23 +24,31 @@ namespace test { namespace { template -void LSTMCell(int iters, int batch, int lstm_step) { +void LSTMCell(int iters, int batch, int input_size, int hidden_units) { mace::testing::StopTiming(); OpsTestNet net; // Add input data - if (D == DeviceType::GPU) { - net.AddRandomInput("Input", {batch, lstm_step}); - net.AddRandomInput("PreOutput", {batch, lstm_step}); - net.AddRandomInput("Weight", {2 * lstm_step, 4 * lstm_step}); - net.AddRandomInput("Bias", {4 * lstm_step}); - net.AddRandomInput("PreCell", {batch, lstm_step}); - } else { - MACE_NOT_IMPLEMENTED; - } + net.AddRandomInput("Input", {batch, input_size}); + net.AddRandomInput("PreOutput", {batch, hidden_units}); + net.AddRandomInput("Weight", {input_size + hidden_units, + 4 * hidden_units}); + net.AddRandomInput("Bias", {4 * hidden_units}); + net.AddRandomInput("PreCell", {batch, hidden_units}); - if (D == DeviceType::GPU) { + const float &forget_add = 0.0f; + + if (D == DeviceType::CPU) { + net.CopyData("Input", "InputCPU"); + net.CopyData("PreOutput", "PreOutputCPU"); + net.CopyData("Weight", "WeightCPU"); + net.CopyData("Bias", "BiasCPU"); + net.CopyData("PreCell", "PreCellCPU"); + + LSTMCellCPU(&net, "InputCPU", "PreOutputCPU", "WeightCPU", "BiasCPU", + "PreCellCPU", forget_add, "CellCPU", "OutputCPU"); + } else if (D == DeviceType::GPU) { BufferToImage(&net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); BufferToImage(&net, "PreOutput", "PreOutputImage", @@ -49,7 +58,7 @@ void LSTMCell(int iters, int batch, int lstm_step) { BufferToImage(&net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); BufferToImage(&net, "PreCell", "PreCellImage", - kernels::BufferType::IN_OUT_CHANNEL); + kernels::BufferType::IN_OUT_CHANNEL); OpDefBuilder("LSTMCell", "LSTMCellTest") .Input("InputImage") @@ -57,7 +66,7 @@ void LSTMCell(int iters, int batch, int lstm_step) { .Input("WeightImage") .Input("BiasImage") .Input("PreCellImage") - .AddFloatArg("forget_add", 0.0f) + .AddFloatArg("scalar_input", forget_add) .Output("CellImage") .Output("OutputImage") .Finalize(net.NewOperatorDef()); @@ -79,27 +88,30 @@ void LSTMCell(int iters, int batch, int lstm_step) { } } // namespace -#define MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, TYPE, DEVICE) \ - static void MACE_BM_LSTMCELL_##N##_##LSTM_STEP##_##TYPE##_##DEVICE( \ - int iters) { \ - const int64_t macc = \ - static_cast(iters) * N * 2 * LSTM_STEP * 4 * LSTM_STEP; \ - const int64_t tot = static_cast(iters) * N * LSTM_STEP; \ - mace::testing::MaccProcessed(macc); \ - mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ - LSTMCell(iters, N, LSTM_STEP); \ - } \ - MACE_BENCHMARK(MACE_BM_LSTMCELL_##N##_##LSTM_STEP##_##TYPE##_##DEVICE) +#define MACE_BM_LSTMCELL_MACRO(N, INPUT_SIZE, HIDDEN_UNITS, TYPE, DEVICE) \ + static void \ + MACE_BM_LSTMCELL_##N##_##INPUT_SIZE##_##HIDDEN_UNITS##_##TYPE##_##DEVICE(\ + int iters) { \ + const int64_t macc = \ + static_cast( \ + iters) * N * (INPUT_SIZE + HIDDEN_UNITS) * 4 * HIDDEN_UNITS; \ + const int64_t tot = static_cast(iters) * N * INPUT_SIZE; \ + mace::testing::MaccProcessed(macc); \ + mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \ + LSTMCell(iters, N, INPUT_SIZE, HIDDEN_UNITS); \ + } \ + MACE_BENCHMARK( \ + MACE_BM_LSTMCELL_##N##_##INPUT_SIZE##_##HIDDEN_UNITS##_##TYPE##_##DEVICE) -#define MACE_BM_LSTMCELL(N, LSTM_STEP) \ - MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, float, GPU); \ - MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, half, GPU); +#define MACE_BM_LSTMCELL(N, INPUT_SIZE, HIDDEN_UNITS) \ + MACE_BM_LSTMCELL_MACRO(N, INPUT_SIZE, HIDDEN_UNITS, float, CPU); \ + MACE_BM_LSTMCELL_MACRO(N, INPUT_SIZE, HIDDEN_UNITS, float, GPU); \ + MACE_BM_LSTMCELL_MACRO(N, INPUT_SIZE, HIDDEN_UNITS, half, GPU); -MACE_BM_LSTMCELL(1, 200); -MACE_BM_LSTMCELL(20, 200); -MACE_BM_LSTMCELL(20, 320); -MACE_BM_LSTMCELL(32, 400); -MACE_BM_LSTMCELL(32, 640); +MACE_BM_LSTMCELL(1, 64, 256); +MACE_BM_LSTMCELL(30, 64, 256); +MACE_BM_LSTMCELL(50, 64, 256); +MACE_BM_LSTMCELL(80, 64, 256); } // namespace test } // namespace ops } // namespace mace diff --git a/mace/ops/lstmcell_test.cc b/mace/ops/lstmcell_test.cc index 109096c5ffec45f6d022108017c55b5c4f609799..1cfaad0179f24ce62ea99f1ea0d3a067711e0f38 100644 --- a/mace/ops/lstmcell_test.cc +++ b/mace/ops/lstmcell_test.cc @@ -14,6 +14,7 @@ #include "mace/core/operator.h" #include "mace/kernels/eltwise.h" +#include "mace/ops/lstmcell_test_util.h" #include "mace/ops/ops_test_util.h" namespace mace { @@ -23,128 +24,20 @@ namespace test { class LSTMCellTest : public OpsTestBase {}; namespace { - -template -void LSTMCellCPU(OpsTestNet *net, - const std::string &input_name, - const std::string &pre_output_name, - const std::string &weight_name, - const std::string &bias_name, - const std::string &pre_cell_name, - const float &forget_add_name, - const std::string &cell_name, - const std::string &output_name) { - OpDefBuilder("Concat", "Concat") - .Input(input_name) - .Input(pre_output_name) - .AddIntArg("axis", 1) - .Output("ConcatOutput") - .Finalize(net->AddNewOperatorDef()); - - OpDefBuilder("MatMul", "MatMul") - .Input("ConcatOutput") - .Input(weight_name) - .Output("MatMulOutput") - .Finalize(net->AddNewOperatorDef()); - - OpDefBuilder("BiasAdd", "BiasAdd") - .Input("MatMulOutput") - .Input(bias_name) - .Output("BiasOutput") - .Finalize(net->AddNewOperatorDef()); - - OpDefBuilder("Split", "FCSplit") - .Input("BiasOutput") - .AddIntArg("axis", 1) - .Output("SplitOutput0") - .Output("SplitOutput1") - .Output("SplitOutput2") - .Output("SplitOutput3") - .Finalize(net->AddNewOperatorDef()); - - OpDefBuilder("Activation", "InputSigmoid") - .Input("SplitOutput0") - .AddStringArg("activation", "SIGMOID") - .Output("InputSigmoid") - .Finalize(net->AddNewOperatorDef()); - - OpDefBuilder("Activation", "NewInputTanh") - .Input("SplitOutput1") - .AddStringArg("activation", "TANH") - .Output("NewInputTanh") - .Finalize(net->AddNewOperatorDef()); - - OpDefBuilder("Eltwise", "RememberMul") - .Input("InputSigmoid") - .Input("NewInputTanh") - .AddIntArg("T", DataTypeToEnum::v()) - .AddIntArg("type", static_cast(kernels::EltwiseType::PROD)) - .Output("RememberMul") - .Finalize(net->AddNewOperatorDef()); - - OpDefBuilder("Eltwise", "ForgetAdd") - .Input("SplitOutput2") - .AddFloatArg("scalar_input", forget_add_name) - .AddIntArg("T", DataTypeToEnum::v()) - .AddIntArg("type", static_cast(kernels::EltwiseType::SUM)) - .Output("ForgetAdd") - .Finalize(net->AddNewOperatorDef()); - - OpDefBuilder("Activation", "ForgetSigmoid") - .Input("ForgetAdd") - .AddStringArg("activation", "SIGMOID") - .Output("ForgetSigmoid") - .Finalize(net->AddNewOperatorDef()); - - OpDefBuilder("Eltwise", "ForgetMul") - .Input("ForgetSigmoid") - .Input(pre_cell_name) - .AddIntArg("T", DataTypeToEnum::v()) - .AddIntArg("type", static_cast(kernels::EltwiseType::PROD)) - .Output("ForgetMulPreCell") - .Finalize(net->AddNewOperatorDef()); - - OpDefBuilder("Eltwise", "Cell") - .Input("RememberMul") - .Input("ForgetMulPreCell") - .AddIntArg("T", DataTypeToEnum::v()) - .AddIntArg("type", static_cast(kernels::EltwiseType::SUM)) - .Output(cell_name) - .Finalize(net->AddNewOperatorDef()); - - OpDefBuilder("Activation", "CellTanh") - .Input(cell_name) - .AddStringArg("activation", "TANH") - .Output("CellTanh") - .Finalize(net->AddNewOperatorDef()); - - OpDefBuilder("Activation", "OutputSigmoid") - .Input("SplitOutput3") - .AddStringArg("activation", "SIGMOID") - .Output("OutputSigmoid") - .Finalize(net->AddNewOperatorDef()); - - OpDefBuilder("Eltwise", "FinalMul") - .Input("OutputSigmoid") - .Input("CellTanh") - .AddIntArg("T", DataTypeToEnum::v()) - .AddIntArg("type", static_cast(kernels::EltwiseType::PROD)) - .Output(output_name) - .Finalize(net->AddNewOperatorDef()); -} - template void TestLSTMCell(const uint32_t &batch, - const uint32_t &lstm_step, + const uint32_t &input_size, + const uint32_t &hidden_units, const float &forget_add) { // Construct graph OpsTestNet net; - net.AddRandomInput("Input", {batch, lstm_step}); - net.AddRandomInput("PreOutput", {batch, lstm_step}); - net.AddRandomInput("Weight", {2 * lstm_step, 4 * lstm_step}); - net.AddRandomInput("Bias", {4 * lstm_step}); - net.AddRandomInput("PreCell", {batch, lstm_step}); + net.AddRandomInput("Input", {batch, input_size}); + net.AddRandomInput("PreOutput", {batch, hidden_units}); + net.AddRandomInput("Weight", {input_size + hidden_units, + 4 * hidden_units}); + net.AddRandomInput("Bias", {4 * hidden_units}); + net.AddRandomInput("PreCell", {batch, hidden_units}); net.CopyData("Input", "InputCPU"); net.CopyData("PreOutput", "PreOutputCPU"); @@ -205,17 +98,17 @@ void TestLSTMCell(const uint32_t &batch, } // namespace TEST_F(LSTMCellTest, OPENCLRandomHalf) { - TestLSTMCell(1, 4, 0.0f); - TestLSTMCell(2, 16, 0.0f); - TestLSTMCell(2, 200, 0.5f); - TestLSTMCell(20, 320, 0.5f); + TestLSTMCell(1, 3, 8, 0.0f); + TestLSTMCell(2, 16, 24, 0.0f); + TestLSTMCell(2, 200, 280, 0.5f); + TestLSTMCell(20, 320, 512, 0.5f); } TEST_F(LSTMCellTest, OPENCLRandomFloat) { - TestLSTMCell(1, 4, 0.0f); - TestLSTMCell(2, 16, 0.0f); - TestLSTMCell(2, 200, 0.5f); - TestLSTMCell(20, 320, 0.5f); + TestLSTMCell(1, 3, 8, 0.0f); + TestLSTMCell(2, 16, 24, 0.0f); + TestLSTMCell(2, 200, 280, 0.5f); + TestLSTMCell(20, 320, 512, 0.5f); } } // namespace test diff --git a/mace/ops/lstmcell_test_util.h b/mace/ops/lstmcell_test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..06d711516a903ff0119eb89a6b1c92ad6a03d030 --- /dev/null +++ b/mace/ops/lstmcell_test_util.h @@ -0,0 +1,141 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_LSTMCELL_TEST_UTIL_H_ +#define MACE_OPS_LSTMCELL_TEST_UTIL_H_ + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/eltwise.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +template +void LSTMCellCPU(OpsTestNet *net, + const std::string &input_name, + const std::string &pre_output_name, + const std::string &weight_name, + const std::string &bias_name, + const std::string &pre_cell_name, + const float &forget_add_name, + const std::string &cell_name, + const std::string &output_name) { + OpDefBuilder("Concat", "Concat") + .Input(input_name) + .Input(pre_output_name) + .AddIntArg("axis", 1) + .Output("ConcatOutput") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("MatMul", "MatMul") + .Input("ConcatOutput") + .Input(weight_name) + .Output("MatMulOutput") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("BiasAdd", "BiasAdd") + .Input("MatMulOutput") + .Input(bias_name) + .Output("BiasOutput") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Split", "FCSplit") + .Input("BiasOutput") + .AddIntArg("axis", 1) + .Output("SplitOutput0") + .Output("SplitOutput1") + .Output("SplitOutput2") + .Output("SplitOutput3") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Activation", "InputSigmoid") + .Input("SplitOutput0") + .AddStringArg("activation", "SIGMOID") + .Output("InputSigmoid") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Activation", "NewInputTanh") + .Input("SplitOutput1") + .AddStringArg("activation", "TANH") + .Output("NewInputTanh") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Eltwise", "RememberMul") + .Input("InputSigmoid") + .Input("NewInputTanh") + .AddIntArg("T", DataTypeToEnum::v()) + .AddIntArg("type", static_cast(kernels::EltwiseType::PROD)) + .Output("RememberMul") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Eltwise", "ForgetAdd") + .Input("SplitOutput2") + .AddFloatArg("scalar_input", forget_add_name) + .AddIntArg("T", DataTypeToEnum::v()) + .AddIntArg("type", static_cast(kernels::EltwiseType::SUM)) + .Output("ForgetAdd") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Activation", "ForgetSigmoid") + .Input("ForgetAdd") + .AddStringArg("activation", "SIGMOID") + .Output("ForgetSigmoid") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Eltwise", "ForgetMul") + .Input("ForgetSigmoid") + .Input(pre_cell_name) + .AddIntArg("T", DataTypeToEnum::v()) + .AddIntArg("type", static_cast(kernels::EltwiseType::PROD)) + .Output("ForgetMulPreCell") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Eltwise", "Cell") + .Input("RememberMul") + .Input("ForgetMulPreCell") + .AddIntArg("T", DataTypeToEnum::v()) + .AddIntArg("type", static_cast(kernels::EltwiseType::SUM)) + .Output(cell_name) + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Activation", "CellTanh") + .Input(cell_name) + .AddStringArg("activation", "TANH") + .Output("CellTanh") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Activation", "OutputSigmoid") + .Input("SplitOutput3") + .AddStringArg("activation", "SIGMOID") + .Output("OutputSigmoid") + .Finalize(net->AddNewOperatorDef()); + + OpDefBuilder("Eltwise", "FinalMul") + .Input("OutputSigmoid") + .Input("CellTanh") + .AddIntArg("T", DataTypeToEnum::v()) + .AddIntArg("type", static_cast(kernels::EltwiseType::PROD)) + .Output(output_name) + .Finalize(net->AddNewOperatorDef()); +} + +} // namespace test +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_LSTMCELL_TEST_UTIL_H_