提交 af264bfb 编写于 作者: 李寅

Merge branch 'support_diff_in_size_and_hidden_units' into 'master'

support differ input_size and hidden_units, add lstmcell cpu benchmark

See merge request !788
......@@ -9,6 +9,8 @@ __kernel void lstmcell(KERNEL_ERROR_PARAMS
__read_only image2d_t pre_cell,
__private const float forget_bias,
__private const int width,
__private const int hidden_units,
__private const int in_w_blk,
__write_only image2d_t cell,
__write_only image2d_t output) {
const int w_blk_idx = get_global_id(0);
......@@ -25,114 +27,97 @@ __kernel void lstmcell(KERNEL_ERROR_PARAMS
DATA_TYPE4 fc_res0 = 0.0, fc_res1 = 0.0, fc_res2 = 0.0, fc_res3 = 0.0;
DATA_TYPE4 in, pre_h;
DATA_TYPE4 w0, w1, w2, w3;
int k_offset;
// concat matmul
for (short i = 0; i < global_size_dim0; ++i) {
for (short i = 0; i < in_w_blk; ++i) {
in = READ_IMAGET(input, SAMPLER, (int2)(i, h_idx));
short k = 4 * i;
int k = i << 2;
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += in.x * w0;
fc_res1 += in.x * w1;
fc_res2 += in.x * w2;
fc_res3 += in.x * w3;
k = 4 * i + 1;
if (k < width) {
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += in.y * w0;
fc_res1 += in.y * w1;
fc_res2 += in.y * w2;
fc_res3 += in.y * w3;
}
k = 4 * i + 2;
if (k < width) {
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += in.z * w0;
fc_res1 += in.z * w1;
fc_res2 += in.z * w2;
fc_res3 += in.z * w3;
}
k = 4 * i + 3;
if (k < width) {
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += in.w * w0;
fc_res1 += in.w * w1;
fc_res2 += in.w * w2;
fc_res3 += in.w * w3;
}
k += 1;
k_offset = select(-1, k, k < width);
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k_offset));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k_offset));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k_offset));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k_offset));
fc_res0 += in.y * w0;
fc_res1 += in.y * w1;
fc_res2 += in.y * w2;
fc_res3 += in.y * w3;
k += 1;
k_offset = select(-1, k, k < width);
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k_offset));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k_offset));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k_offset));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k_offset));
fc_res0 += in.z * w0;
fc_res1 += in.z * w1;
fc_res2 += in.z * w2;
fc_res3 += in.z * w3;
k += 1;
k_offset = select(-1, k, k < width);
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k_offset));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k_offset));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k_offset));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k_offset));
fc_res0 += in.w * w0;
fc_res1 += in.w * w1;
fc_res2 += in.w * w2;
fc_res3 += in.w * w3;
}
for (short i = 0; i < global_size_dim0; ++i) {
pre_h = READ_IMAGET(pre_output, SAMPLER, (int2)(i, h_idx));
short k = 4 * (i + global_size_dim0);
short k_limit = 4 * global_size_dim0 + width;
int k = (i << 2) + width;
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += pre_h.x * w0;
fc_res1 += pre_h.x * w1;
fc_res2 += pre_h.x * w2;
fc_res3 += pre_h.x * w3;
k = 4 * (i + global_size_dim0) + 1;
if (k < k_limit) {
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += pre_h.y * w0;
fc_res1 += pre_h.y * w1;
fc_res2 += pre_h.y * w2;
fc_res3 += pre_h.y * w3;
}
k = 4 * (i + global_size_dim0) + 2;
if (k < k_limit) {
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += pre_h.z * w0;
fc_res1 += pre_h.z * w1;
fc_res2 += pre_h.z * w2;
fc_res3 += pre_h.z * w3;
}
k += 1;
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += pre_h.y * w0;
fc_res1 += pre_h.y * w1;
fc_res2 += pre_h.y * w2;
fc_res3 += pre_h.y * w3;
k = 4 * (i + global_size_dim0) + 3;
if (k < k_limit) {
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
k += 1;
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += pre_h.z * w0;
fc_res1 += pre_h.z * w1;
fc_res2 += pre_h.z * w2;
fc_res3 += pre_h.z * w3;
fc_res0 += pre_h.w * w0;
fc_res1 += pre_h.w * w1;
fc_res2 += pre_h.w * w2;
fc_res3 += pre_h.w * w3;
}
k += 1;
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += pre_h.w * w0;
fc_res1 += pre_h.w * w1;
fc_res2 += pre_h.w * w2;
fc_res3 += pre_h.w * w3;
}
// bias
......
......@@ -31,12 +31,13 @@ MaceStatus LSTMCellFunctor<DeviceType::GPU, T>::operator()(
Tensor *cell,
Tensor *output,
StatsFuture *future) {
MACE_CHECK(input->dim_size() == 2 && input->dim(1) % 4 == 0,
"LSTM step should be a multiple of 4");
MACE_CHECK(pre_output->dim_size() == 2 && pre_output->dim(1) % 4 == 0,
"LSTM hidden units should be a multiple of 4");
const index_t height = input->dim(0);
const index_t width = input->dim(1);
const index_t width_blocks = width / 4;
const index_t hidden_units = pre_output->dim(1);
const index_t w_blocks = hidden_units >> 2;
auto runtime = context_->device()->opencl_runtime();
......@@ -57,17 +58,18 @@ MaceStatus LSTMCellFunctor<DeviceType::GPU, T>::operator()(
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
const uint32_t gws[2] = {static_cast<uint32_t>(width_blocks),
const uint32_t gws[2] = {static_cast<uint32_t>(w_blocks),
static_cast<uint32_t>(height)};
if (!IsVecEqual(input_shape_, input->shape())) {
std::vector<index_t> output_shape_padded = {height, 1, 1, width};
std::vector<index_t> output_shape_padded = {height, 1, 1, hidden_units};
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape_padded, BufferType::IN_OUT_CHANNEL,
&output_image_shape);
MACE_RETURN_IF_ERROR(output->ResizeImage(input->shape(),
MACE_RETURN_IF_ERROR(output->ResizeImage(pre_output->shape(),
output_image_shape));
MACE_RETURN_IF_ERROR(cell->ResizeImage(input->shape(), output_image_shape));
MACE_RETURN_IF_ERROR(cell->ResizeImage(pre_cell->shape(),
output_image_shape));
uint32_t idx = 0;
OUT_OF_RANGE_SET_ARG;
......@@ -79,6 +81,8 @@ MaceStatus LSTMCellFunctor<DeviceType::GPU, T>::operator()(
kernel_.setArg(idx++, *(pre_cell->opencl_image()));
kernel_.setArg(idx++, static_cast<float>(forget_bias_));
kernel_.setArg(idx++, static_cast<int32_t>(width));
kernel_.setArg(idx++, static_cast<int32_t>(hidden_units));
kernel_.setArg(idx++, static_cast<int32_t>(RoundUpDiv4(width)));
kernel_.setArg(idx++, *(cell->opencl_image()));
kernel_.setArg(idx++, *(output->opencl_image()));
......
......@@ -20,9 +20,9 @@ load(
cc_library(
name = "test",
testonly = 1,
hdrs = [
"ops_test_util.h",
],
hdrs = glob([
"*_test_util.h",
]),
srcs = [
"ops_test_util.cc",
],
......@@ -67,7 +67,7 @@ cc_library(
),
hdrs = glob(
["*.h"],
exclude = ["ops_test_util.h"],
exclude = glob(["*_test_util.h"]),
),
copts = [
"-Werror",
......
......@@ -15,6 +15,7 @@
#include "mace/core/operator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/lstmcell_test_util.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
......@@ -23,23 +24,31 @@ namespace test {
namespace {
template <DeviceType D, typename T>
void LSTMCell(int iters, int batch, int lstm_step) {
void LSTMCell(int iters, int batch, int input_size, int hidden_units) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
if (D == DeviceType::GPU) {
net.AddRandomInput<D, T>("Input", {batch, lstm_step});
net.AddRandomInput<D, T>("PreOutput", {batch, lstm_step});
net.AddRandomInput<D, T>("Weight", {2 * lstm_step, 4 * lstm_step});
net.AddRandomInput<D, T>("Bias", {4 * lstm_step});
net.AddRandomInput<D, T>("PreCell", {batch, lstm_step});
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddRandomInput<D, float>("Input", {batch, input_size});
net.AddRandomInput<D, float>("PreOutput", {batch, hidden_units});
net.AddRandomInput<D, float>("Weight", {input_size + hidden_units,
4 * hidden_units});
net.AddRandomInput<D, float>("Bias", {4 * hidden_units});
net.AddRandomInput<D, float>("PreCell", {batch, hidden_units});
if (D == DeviceType::GPU) {
const float &forget_add = 0.0f;
if (D == DeviceType::CPU) {
net.CopyData<DeviceType::CPU, float>("Input", "InputCPU");
net.CopyData<DeviceType::CPU, float>("PreOutput", "PreOutputCPU");
net.CopyData<DeviceType::CPU, float>("Weight", "WeightCPU");
net.CopyData<DeviceType::CPU, float>("Bias", "BiasCPU");
net.CopyData<DeviceType::CPU, float>("PreCell", "PreCellCPU");
LSTMCellCPU<float>(&net, "InputCPU", "PreOutputCPU", "WeightCPU", "BiasCPU",
"PreCellCPU", forget_add, "CellCPU", "OutputCPU");
} else if (D == DeviceType::GPU) {
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "PreOutput", "PreOutputImage",
......@@ -49,7 +58,7 @@ void LSTMCell(int iters, int batch, int lstm_step) {
BufferToImage<D, T>(&net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
BufferToImage<D, T>(&net, "PreCell", "PreCellImage",
kernels::BufferType::IN_OUT_CHANNEL);
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("LSTMCell", "LSTMCellTest")
.Input("InputImage")
......@@ -57,7 +66,7 @@ void LSTMCell(int iters, int batch, int lstm_step) {
.Input("WeightImage")
.Input("BiasImage")
.Input("PreCellImage")
.AddFloatArg("forget_add", 0.0f)
.AddFloatArg("scalar_input", forget_add)
.Output("CellImage")
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
......@@ -79,27 +88,30 @@ void LSTMCell(int iters, int batch, int lstm_step) {
}
} // namespace
#define MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, TYPE, DEVICE) \
static void MACE_BM_LSTMCELL_##N##_##LSTM_STEP##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t macc = \
static_cast<int64_t>(iters) * N * 2 * LSTM_STEP * 4 * LSTM_STEP; \
const int64_t tot = static_cast<int64_t>(iters) * N * LSTM_STEP; \
mace::testing::MaccProcessed(macc); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
LSTMCell<DEVICE, TYPE>(iters, N, LSTM_STEP); \
} \
MACE_BENCHMARK(MACE_BM_LSTMCELL_##N##_##LSTM_STEP##_##TYPE##_##DEVICE)
#define MACE_BM_LSTMCELL_MACRO(N, INPUT_SIZE, HIDDEN_UNITS, TYPE, DEVICE) \
static void \
MACE_BM_LSTMCELL_##N##_##INPUT_SIZE##_##HIDDEN_UNITS##_##TYPE##_##DEVICE(\
int iters) { \
const int64_t macc = \
static_cast<int64_t>( \
iters) * N * (INPUT_SIZE + HIDDEN_UNITS) * 4 * HIDDEN_UNITS; \
const int64_t tot = static_cast<int64_t>(iters) * N * INPUT_SIZE; \
mace::testing::MaccProcessed(macc); \
mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \
LSTMCell<DEVICE, TYPE>(iters, N, INPUT_SIZE, HIDDEN_UNITS); \
} \
MACE_BENCHMARK( \
MACE_BM_LSTMCELL_##N##_##INPUT_SIZE##_##HIDDEN_UNITS##_##TYPE##_##DEVICE)
#define MACE_BM_LSTMCELL(N, LSTM_STEP) \
MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, float, GPU); \
MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, half, GPU);
#define MACE_BM_LSTMCELL(N, INPUT_SIZE, HIDDEN_UNITS) \
MACE_BM_LSTMCELL_MACRO(N, INPUT_SIZE, HIDDEN_UNITS, float, CPU); \
MACE_BM_LSTMCELL_MACRO(N, INPUT_SIZE, HIDDEN_UNITS, float, GPU); \
MACE_BM_LSTMCELL_MACRO(N, INPUT_SIZE, HIDDEN_UNITS, half, GPU);
MACE_BM_LSTMCELL(1, 200);
MACE_BM_LSTMCELL(20, 200);
MACE_BM_LSTMCELL(20, 320);
MACE_BM_LSTMCELL(32, 400);
MACE_BM_LSTMCELL(32, 640);
MACE_BM_LSTMCELL(1, 64, 256);
MACE_BM_LSTMCELL(30, 64, 256);
MACE_BM_LSTMCELL(50, 64, 256);
MACE_BM_LSTMCELL(80, 64, 256);
} // namespace test
} // namespace ops
} // namespace mace
......@@ -14,6 +14,7 @@
#include "mace/core/operator.h"
#include "mace/kernels/eltwise.h"
#include "mace/ops/lstmcell_test_util.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
......@@ -23,128 +24,20 @@ namespace test {
class LSTMCellTest : public OpsTestBase {};
namespace {
template <typename T>
void LSTMCellCPU(OpsTestNet *net,
const std::string &input_name,
const std::string &pre_output_name,
const std::string &weight_name,
const std::string &bias_name,
const std::string &pre_cell_name,
const float &forget_add_name,
const std::string &cell_name,
const std::string &output_name) {
OpDefBuilder("Concat", "Concat")
.Input(input_name)
.Input(pre_output_name)
.AddIntArg("axis", 1)
.Output("ConcatOutput")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("MatMul", "MatMul")
.Input("ConcatOutput")
.Input(weight_name)
.Output("MatMulOutput")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("BiasAdd", "BiasAdd")
.Input("MatMulOutput")
.Input(bias_name)
.Output("BiasOutput")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Split", "FCSplit")
.Input("BiasOutput")
.AddIntArg("axis", 1)
.Output("SplitOutput0")
.Output("SplitOutput1")
.Output("SplitOutput2")
.Output("SplitOutput3")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "InputSigmoid")
.Input("SplitOutput0")
.AddStringArg("activation", "SIGMOID")
.Output("InputSigmoid")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "NewInputTanh")
.Input("SplitOutput1")
.AddStringArg("activation", "TANH")
.Output("NewInputTanh")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "RememberMul")
.Input("InputSigmoid")
.Input("NewInputTanh")
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::PROD))
.Output("RememberMul")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "ForgetAdd")
.Input("SplitOutput2")
.AddFloatArg("scalar_input", forget_add_name)
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::SUM))
.Output("ForgetAdd")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "ForgetSigmoid")
.Input("ForgetAdd")
.AddStringArg("activation", "SIGMOID")
.Output("ForgetSigmoid")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "ForgetMul")
.Input("ForgetSigmoid")
.Input(pre_cell_name)
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::PROD))
.Output("ForgetMulPreCell")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "Cell")
.Input("RememberMul")
.Input("ForgetMulPreCell")
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::SUM))
.Output(cell_name)
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "CellTanh")
.Input(cell_name)
.AddStringArg("activation", "TANH")
.Output("CellTanh")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "OutputSigmoid")
.Input("SplitOutput3")
.AddStringArg("activation", "SIGMOID")
.Output("OutputSigmoid")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "FinalMul")
.Input("OutputSigmoid")
.Input("CellTanh")
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::PROD))
.Output(output_name)
.Finalize(net->AddNewOperatorDef());
}
template <DeviceType D, typename T>
void TestLSTMCell(const uint32_t &batch,
const uint32_t &lstm_step,
const uint32_t &input_size,
const uint32_t &hidden_units,
const float &forget_add) {
// Construct graph
OpsTestNet net;
net.AddRandomInput<D, float>("Input", {batch, lstm_step});
net.AddRandomInput<D, float>("PreOutput", {batch, lstm_step});
net.AddRandomInput<D, float>("Weight", {2 * lstm_step, 4 * lstm_step});
net.AddRandomInput<D, float>("Bias", {4 * lstm_step});
net.AddRandomInput<D, float>("PreCell", {batch, lstm_step});
net.AddRandomInput<D, float>("Input", {batch, input_size});
net.AddRandomInput<D, float>("PreOutput", {batch, hidden_units});
net.AddRandomInput<D, float>("Weight", {input_size + hidden_units,
4 * hidden_units});
net.AddRandomInput<D, float>("Bias", {4 * hidden_units});
net.AddRandomInput<D, float>("PreCell", {batch, hidden_units});
net.CopyData<DeviceType::CPU, float>("Input", "InputCPU");
net.CopyData<DeviceType::CPU, float>("PreOutput", "PreOutputCPU");
......@@ -205,17 +98,17 @@ void TestLSTMCell(const uint32_t &batch,
} // namespace
TEST_F(LSTMCellTest, OPENCLRandomHalf) {
TestLSTMCell<GPU, half>(1, 4, 0.0f);
TestLSTMCell<GPU, half>(2, 16, 0.0f);
TestLSTMCell<GPU, half>(2, 200, 0.5f);
TestLSTMCell<GPU, half>(20, 320, 0.5f);
TestLSTMCell<GPU, half>(1, 3, 8, 0.0f);
TestLSTMCell<GPU, half>(2, 16, 24, 0.0f);
TestLSTMCell<GPU, half>(2, 200, 280, 0.5f);
TestLSTMCell<GPU, half>(20, 320, 512, 0.5f);
}
TEST_F(LSTMCellTest, OPENCLRandomFloat) {
TestLSTMCell<GPU, float>(1, 4, 0.0f);
TestLSTMCell<GPU, float>(2, 16, 0.0f);
TestLSTMCell<GPU, float>(2, 200, 0.5f);
TestLSTMCell<GPU, float>(20, 320, 0.5f);
TestLSTMCell<GPU, float>(1, 3, 8, 0.0f);
TestLSTMCell<GPU, float>(2, 16, 24, 0.0f);
TestLSTMCell<GPU, float>(2, 200, 280, 0.5f);
TestLSTMCell<GPU, float>(20, 320, 512, 0.5f);
}
} // namespace test
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_LSTMCELL_TEST_UTIL_H_
#define MACE_OPS_LSTMCELL_TEST_UTIL_H_
#include <string>
#include "mace/core/operator.h"
#include "mace/kernels/eltwise.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
template <typename T>
void LSTMCellCPU(OpsTestNet *net,
const std::string &input_name,
const std::string &pre_output_name,
const std::string &weight_name,
const std::string &bias_name,
const std::string &pre_cell_name,
const float &forget_add_name,
const std::string &cell_name,
const std::string &output_name) {
OpDefBuilder("Concat", "Concat")
.Input(input_name)
.Input(pre_output_name)
.AddIntArg("axis", 1)
.Output("ConcatOutput")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("MatMul", "MatMul")
.Input("ConcatOutput")
.Input(weight_name)
.Output("MatMulOutput")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("BiasAdd", "BiasAdd")
.Input("MatMulOutput")
.Input(bias_name)
.Output("BiasOutput")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Split", "FCSplit")
.Input("BiasOutput")
.AddIntArg("axis", 1)
.Output("SplitOutput0")
.Output("SplitOutput1")
.Output("SplitOutput2")
.Output("SplitOutput3")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "InputSigmoid")
.Input("SplitOutput0")
.AddStringArg("activation", "SIGMOID")
.Output("InputSigmoid")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "NewInputTanh")
.Input("SplitOutput1")
.AddStringArg("activation", "TANH")
.Output("NewInputTanh")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "RememberMul")
.Input("InputSigmoid")
.Input("NewInputTanh")
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::PROD))
.Output("RememberMul")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "ForgetAdd")
.Input("SplitOutput2")
.AddFloatArg("scalar_input", forget_add_name)
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::SUM))
.Output("ForgetAdd")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "ForgetSigmoid")
.Input("ForgetAdd")
.AddStringArg("activation", "SIGMOID")
.Output("ForgetSigmoid")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "ForgetMul")
.Input("ForgetSigmoid")
.Input(pre_cell_name)
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::PROD))
.Output("ForgetMulPreCell")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "Cell")
.Input("RememberMul")
.Input("ForgetMulPreCell")
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::SUM))
.Output(cell_name)
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "CellTanh")
.Input(cell_name)
.AddStringArg("activation", "TANH")
.Output("CellTanh")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "OutputSigmoid")
.Input("SplitOutput3")
.AddStringArg("activation", "SIGMOID")
.Output("OutputSigmoid")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "FinalMul")
.Input("OutputSigmoid")
.Input("CellTanh")
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::PROD))
.Output(output_name)
.Finalize(net->AddNewOperatorDef());
}
} // namespace test
} // namespace ops
} // namespace mace
#endif // MACE_OPS_LSTMCELL_TEST_UTIL_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册