提交 622173e7 编写于 作者: Y yejianwu

add lstm gpu, unstack cpu

上级 f5829926
...@@ -45,11 +45,13 @@ cc_library( ...@@ -45,11 +45,13 @@ cc_library(
exclude = [ exclude = [
"buffer_to_image.h", "buffer_to_image.h",
"image_to_buffer.h", "image_to_buffer.h",
"lstmcell.h",
], ],
) + if_opencl_enabled(glob([ ) + if_opencl_enabled(glob([
"opencl/*.h", "opencl/*.h",
"buffer_to_image.h", "buffer_to_image.h",
"image_to_buffer.h", "image_to_buffer.h",
"lstmcell.h",
])), ])),
copts = [ copts = [
"-Werror", "-Werror",
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_LSTMCELL_H_
#define MACE_KERNELS_LSTMCELL_H_
#include <algorithm>
#include <limits>
#include <memory>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct LSTMCellFunctor;
template <typename T>
struct LSTMCellFunctor<DeviceType::GPU, T> {
explicit LSTMCellFunctor(T forget_bias) :
forget_bias_(static_cast<T>(forget_bias)) {}
MaceStatus operator()(const Tensor *input,
const Tensor *pre_output,
const Tensor *weight,
const Tensor *bias,
const Tensor *pre_cell,
Tensor *cell,
Tensor *output,
StatsFuture *future);
T forget_bias_;
cl::Kernel kernel_;
uint32_t kwg_size_;
std::unique_ptr<BufferBase> kernel_error_;
std::vector<index_t> input_shape_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_LSTMCELL_H_
...@@ -61,6 +61,11 @@ ...@@ -61,6 +61,11 @@
__constant sampler_t SAMPLER = __constant sampler_t SAMPLER =
CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
inline float4 do_sigmoid(float4 in) {
// native_func not support half
return native_recip(1.0 + native_exp(-in));
}
inline DATA_TYPE4 do_activation(DATA_TYPE4 in, inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
#ifdef USE_PRELU #ifdef USE_PRELU
DATA_TYPE4 prelu_alpha, DATA_TYPE4 prelu_alpha,
...@@ -80,7 +85,7 @@ inline DATA_TYPE4 do_activation(DATA_TYPE4 in, ...@@ -80,7 +85,7 @@ inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
out = tanh(in); out = tanh(in);
#endif #endif
#ifdef USE_SIGMOID #ifdef USE_SIGMOID
out = native_recip((DATA_TYPE)1 + native_exp(-in)); out = do_sigmoid(in);
#endif #endif
return out; return out;
} }
......
#include <common.h>
__kernel void lstmcell(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM2
__read_only image2d_t input,
__read_only image2d_t pre_output,
__read_only image2d_t weight,
__read_only image2d_t bias,
__read_only image2d_t pre_cell,
__private const float forget_bias,
__private const int width,
__write_only image2d_t cell,
__write_only image2d_t output) {
const int w_blk_idx = get_global_id(0);
const int h_idx = get_global_id(1);
#ifndef NON_UNIFORM_WORK_GROUP
if (w_blk_idx >= global_size_dim0 || h_idx >= global_size_dim1) return;
#endif
// fc_res0 -> i
// fc_res1 -> j
// fc_res2 -> f
// fc_res3 -> o
DATA_TYPE4 fc_res0 = 0.0, fc_res1 = 0.0, fc_res2 = 0.0, fc_res3 = 0.0;
DATA_TYPE4 in, pre_h;
DATA_TYPE4 w0, w1, w2, w3;
// concat matmul
for (short i = 0; i < global_size_dim0; ++i) {
in = READ_IMAGET(input, SAMPLER, (int2)(i, h_idx));
short k = 4 * i;
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += in.x * w0;
fc_res1 += in.x * w1;
fc_res2 += in.x * w2;
fc_res3 += in.x * w3;
k = 4 * i + 1;
if (k < width) {
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += in.y * w0;
fc_res1 += in.y * w1;
fc_res2 += in.y * w2;
fc_res3 += in.y * w3;
}
k = 4 * i + 2;
if (k < width) {
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += in.z * w0;
fc_res1 += in.z * w1;
fc_res2 += in.z * w2;
fc_res3 += in.z * w3;
}
k = 4 * i + 3;
if (k < width) {
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += in.w * w0;
fc_res1 += in.w * w1;
fc_res2 += in.w * w2;
fc_res3 += in.w * w3;
}
}
for (short i = 0; i < global_size_dim0; ++i) {
pre_h = READ_IMAGET(pre_output, SAMPLER, (int2)(i, h_idx));
short k = 4 * (i + global_size_dim0);
short k_limit = 4 * global_size_dim0 + width;
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += pre_h.x * w0;
fc_res1 += pre_h.x * w1;
fc_res2 += pre_h.x * w2;
fc_res3 += pre_h.x * w3;
k = 4 * (i + global_size_dim0) + 1;
if (k < k_limit) {
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += pre_h.y * w0;
fc_res1 += pre_h.y * w1;
fc_res2 += pre_h.y * w2;
fc_res3 += pre_h.y * w3;
}
k = 4 * (i + global_size_dim0) + 2;
if (k < k_limit) {
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += pre_h.z * w0;
fc_res1 += pre_h.z * w1;
fc_res2 += pre_h.z * w2;
fc_res3 += pre_h.z * w3;
}
k = 4 * (i + global_size_dim0) + 3;
if (k < k_limit) {
w0 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx, k));
w1 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0, k));
w2 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, k));
w3 = READ_IMAGET(weight, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, k));
fc_res0 += pre_h.w * w0;
fc_res1 += pre_h.w * w1;
fc_res2 += pre_h.w * w2;
fc_res3 += pre_h.w * w3;
}
}
// bias
DATA_TYPE4 b0, b1, b2, b3;
b0 = READ_IMAGET(bias, SAMPLER, (int2)(w_blk_idx, 0));
b1 = READ_IMAGET(bias, SAMPLER, (int2)(w_blk_idx + global_size_dim0, 0));
b2 = READ_IMAGET(bias, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 2, 0));
b3 = READ_IMAGET(bias, SAMPLER, (int2)(w_blk_idx + global_size_dim0 * 3, 0));
fc_res0 += b0;
fc_res1 += b1;
fc_res2 += b2;
fc_res3 += b3;
// gate
DATA_TYPE4 pre_c, c, h;
pre_c = READ_IMAGET(pre_cell, SAMPLER, (int2)(w_blk_idx, h_idx));
c = do_sigmoid(fc_res0) * tanh(fc_res1) + do_sigmoid((fc_res2 + (float4)forget_bias)) * pre_c;
h = do_sigmoid(fc_res3) * tanh(c);
WRITE_IMAGET(cell, (int2)(w_blk_idx, h_idx), c);
WRITE_IMAGET(output, (int2)(w_blk_idx, h_idx), h);
}
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/lstmcell.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
template <typename T>
MaceStatus LSTMCellFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input,
const Tensor *pre_output,
const Tensor *weight,
const Tensor *bias,
const Tensor *pre_cell,
Tensor *cell,
Tensor *output,
StatsFuture *future) {
const index_t height = input->dim(0);
const index_t width = input->dim(1);
auto runtime = OpenCLRuntime::Global();
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
OUT_OF_RANGE_CONFIG(kernel_error_);
NON_UNIFORM_WG_CONFIG;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("lstmcell");
built_options.emplace("-Dlstmcell=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpCompatibleCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpCompatibleCLCMDDt(dt));
MACE_RETURN_IF_ERROR(runtime->BuildKernel("lstmcell", kernel_name,
built_options, &kernel_));
kwg_size_ =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
const index_t width_blocks = RoundUpDiv4(width);
const uint32_t gws[2] = {static_cast<uint32_t>(width_blocks),
static_cast<uint32_t>(height)};
if (!IsVecEqual(input_shape_, input->shape())) {
std::vector<index_t> output_shape_paded = {height, 1, 1, width};
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape_paded, BufferType::IN_OUT_CHANNEL,
&output_image_shape);
MACE_RETURN_IF_ERROR(output->ResizeImage(input->shape(),
output_image_shape));
MACE_RETURN_IF_ERROR(cell->ResizeImage(input->shape(), output_image_shape));
uint32_t idx = 0;
OUT_OF_RANGE_SET_ARG;
SET_2D_GWS_ARGS(kernel_);
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, *(pre_output->opencl_image()));
kernel_.setArg(idx++, *(weight->opencl_image()));
kernel_.setArg(idx++, *(bias->opencl_image()));
kernel_.setArg(idx++, *(pre_cell->opencl_image()));
kernel_.setArg(idx++, static_cast<float>(forget_bias_));
kernel_.setArg(idx++, static_cast<int32_t>(width));
kernel_.setArg(idx++, *(cell->opencl_image()));
kernel_.setArg(idx++, *(output->opencl_image()));
input_shape_ = input->shape();
}
const std::vector<uint32_t> lws = {kwg_size_ / 16, 16, 0};
std::string tuning_key =
Concat("lstmcell_opencl_kernel", output->dim(0), output->dim(1));
MACE_RETURN_IF_ERROR(TuningOrRun2DKernel(kernel_, tuning_key,
gws, lws, future));
OUT_OF_RANGE_VALIDATION(kernel_error_);
return MACE_SUCCESS;
}
template struct LSTMCellFunctor<DeviceType::GPU, float>;
template struct LSTMCellFunctor<DeviceType::GPU, half>;
} // namespace kernels
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_UNSTACK_H_
#define MACE_KERNELS_UNSTACK_H_
#include <algorithm>
#include <functional>
#include <memory>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct UnstackFunctor {
explicit UnstackFunctor(int axis) : axis_(axis) {}
MaceStatus operator()(const Tensor *input,
const std::vector<Tensor *> &outputs,
StatsFuture *future) {
std::vector<index_t> input_shape = input->shape();
MACE_CHECK(axis_ >= -(input->dim_size()) && axis_ < input->dim_size(),
"axis out of bound.");
if (axis_ < 0) {
axis_ += input->dim_size();
}
MACE_CHECK(outputs.size() == input_shape[axis_],
"output size not equal input_shape[axis]");
std::vector<index_t> output_shape = input_shape;
output_shape.erase(output_shape.begin() + axis_);
std::vector<T *> output_data(outputs.size(), nullptr);
for (size_t i = 0; i < input_shape[axis_]; ++i) {
MACE_RETURN_IF_ERROR(outputs[i]->Resize(output_shape));
output_data[i] = outputs[i]->mutable_data<T>();
}
const T *input_data = input->data<T>();
index_t high_dim_elem_size =
std::accumulate(input_shape.begin(), input_shape.begin() + axis_, 1,
std::multiplies<index_t>());
index_t low_dim_elem_size =
std::accumulate(input_shape.begin() + axis_ + 1, input_shape.end(), 1,
std::multiplies<index_t>());
for (index_t h = 0; h < high_dim_elem_size; ++h) {
int input_idx = h * input_shape[axis_] * low_dim_elem_size;
int output_idx = h * low_dim_elem_size;
for (size_t i = 0; i < input_shape[axis_]; ++i) {
memcpy(output_data[i] + output_idx, input_data + input_idx,
sizeof(T) * low_dim_elem_size);
input_idx += low_dim_elem_size;
}
}
SetFutureDefaultWaitFn(future);
return MACE_SUCCESS;
}
int axis_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_UNSTACK_H_
...@@ -38,11 +38,13 @@ cc_library( ...@@ -38,11 +38,13 @@ cc_library(
"*_benchmark.cc", "*_benchmark.cc",
"buffer_to_image.cc", "buffer_to_image.cc",
"image_to_buffer.cc", "image_to_buffer.cc",
"lstmcell.cc",
], ],
) + if_opencl_enabled( ) + if_opencl_enabled(
[ [
"buffer_to_image.cc", "buffer_to_image.cc",
"image_to_buffer.cc", "image_to_buffer.cc",
"lstmcell.cc",
], ],
), ),
hdrs = glob( hdrs = glob(
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/lstmcell.h"
namespace mace {
namespace ops {
void Register_LSTMCell(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("LSTMCell")
.Device(DeviceType::GPU)
.TypeConstraint<float>("T")
.Build(),
LSTMCellOp<DeviceType::GPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("LSTMCell")
.Device(DeviceType::GPU)
.TypeConstraint<half>("T")
.Build(),
LSTMCellOp<DeviceType::GPU, half>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_LSTMCELL_H_
#define MACE_OPS_LSTMCELL_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/lstmcell.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class LSTMCellOp : public Operator<D, T> {
public:
LSTMCellOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(static_cast<T>(
OperatorBase::GetOptionalArg<float>("value", 0.0))) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *pre_output = this->Input(PRE_OUTPUT);
const Tensor *weight = this->Input(WEIGHT);
const Tensor *bias = this->Input(BIAS);
const Tensor *pre_cell = this->Input(PRE_CELL);
Tensor *cell = this->Output(CELL);
Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 2 && input->dim(1) % 4 == 0,
"LSTM step should be a multiple of 4");
return functor_(
input, pre_output, weight, bias, pre_cell, cell, output, future);
};
protected:
kernels::LSTMCellFunctor<D, T> functor_;
MACE_OP_INPUT_TAGS(INPUT, PRE_OUTPUT, WEIGHT, BIAS, PRE_CELL);
MACE_OP_OUTPUT_TAGS(CELL, OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_LSTMCELL_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void LSTMCell(int iters, int batch, int lstm_step) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
if (D == DeviceType::GPU) {
net.AddRandomInput<D, T>("Input", {batch, lstm_step});
net.AddRandomInput<D, T>("PreOutput", {batch, lstm_step});
net.AddRandomInput<D, T>("Weight", {2 * lstm_step, 4 * lstm_step});
net.AddRandomInput<D, T>("Bias", {4 * lstm_step});
net.AddRandomInput<D, T>("PreCell", {batch, lstm_step});
} else {
MACE_NOT_IMPLEMENTED;
}
if (D == DeviceType::GPU) {
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "PreOutput", "PreOutputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Weight", "WeightImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
BufferToImage<D, T>(&net, "PreCell", "PreCellImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("LSTMCell", "LSTMCellTest")
.Input("InputImage")
.Input("PreOutputImage")
.Input("WeightImage")
.Input("BiasImage")
.Input("PreCellImage")
.AddFloatArg("forget_add", 0.0f)
.Output("CellImage")
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
} else {
MACE_NOT_IMPLEMENTED;
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, TYPE, DEVICE) \
static void MACE_BM_LSTMCELL_##N##_##LSTM_STEP##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t macc = \
static_cast<int64_t>(iters) * N * 2 * LSTM_STEP * 4 * LSTM_STEP; \
const int64_t tot = static_cast<int64_t>(iters) * N * LSTM_STEP; \
mace::testing::MaccProcessed(macc); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
LSTMCell<DEVICE, TYPE>(iters, N, LSTM_STEP); \
} \
MACE_BENCHMARK(MACE_BM_LSTMCELL_##N##_##LSTM_STEP##_##TYPE##_##DEVICE)
#define MACE_BM_LSTMCELL(N, LSTM_STEP) \
MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, float, GPU); \
MACE_BM_LSTMCELL_MACRO(N, LSTM_STEP, half, GPU);
MACE_BM_LSTMCELL(1, 200);
MACE_BM_LSTMCELL(20, 200);
MACE_BM_LSTMCELL(20, 320);
MACE_BM_LSTMCELL(32, 400);
MACE_BM_LSTMCELL(32, 640);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/kernels/eltwise.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class LSTMCellTest : public OpsTestBase {};
namespace {
template <typename T>
void LSTMCellCPU(OpsTestNet *net,
const std::string &input_name,
const std::string &pre_output_name,
const std::string &weight_name,
const std::string &bias_name,
const std::string &pre_cell_name,
const float &forget_add_name,
const std::string &cell_name,
const std::string &output_name) {
OpDefBuilder("Concat", "Concat")
.Input(input_name)
.Input(pre_output_name)
.AddIntArg("axis", 1)
.Output("ConcatOutput")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("MatMul", "MatMul")
.Input("ConcatOutput")
.Input(weight_name)
.Output("MatMulOutput")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("BiasAdd", "BiasAdd")
.Input("MatMulOutput")
.Input(bias_name)
.Output("BiasOutput")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Split", "FCSplit")
.Input("BiasOutput")
.AddIntArg("axis", 1)
.Output("SplitOutput0")
.Output("SplitOutput1")
.Output("SplitOutput2")
.Output("SplitOutput3")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "InputSigmoid")
.Input("SplitOutput0")
.AddStringArg("activation", "SIGMOID")
.Output("InputSigmoid")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "NewInputTanh")
.Input("SplitOutput1")
.AddStringArg("activation", "TANH")
.Output("NewInputTanh")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "RememberMul")
.Input("InputSigmoid")
.Input("NewInputTanh")
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::PROD))
.Output("RememberMul")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "ForgetAdd")
.Input("SplitOutput2")
.AddFloatArg("value", forget_add_name)
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::SUM))
.Output("ForgetAdd")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "ForgetSigmoid")
.Input("ForgetAdd")
.AddStringArg("activation", "SIGMOID")
.Output("ForgetSigmoid")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "ForgetMul")
.Input("ForgetSigmoid")
.Input(pre_cell_name)
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::PROD))
.Output("ForgetMulPreCell")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "Cell")
.Input("RememberMul")
.Input("ForgetMulPreCell")
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::SUM))
.Output(cell_name)
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "CellTanh")
.Input(cell_name)
.AddStringArg("activation", "TANH")
.Output("CellTanh")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Activation", "OutputSigmoid")
.Input("SplitOutput3")
.AddStringArg("activation", "SIGMOID")
.Output("OutputSigmoid")
.Finalize(net->AddNewOperatorDef());
OpDefBuilder("Eltwise", "FinalMul")
.Input("OutputSigmoid")
.Input("CellTanh")
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(kernels::EltwiseType::PROD))
.Output(output_name)
.Finalize(net->AddNewOperatorDef());
}
template <DeviceType D, typename T>
void TestLSTMCell(const uint32_t &batch,
const uint32_t &lstm_step,
const float &forget_add) {
// Construct graph
OpsTestNet net;
net.AddRandomInput<D, float>("Input", {batch, lstm_step});
net.AddRandomInput<D, float>("PreOutput", {batch, lstm_step});
net.AddRandomInput<D, float>("Weight", {2 * lstm_step, 4 * lstm_step});
net.AddRandomInput<D, float>("Bias", {4 * lstm_step});
net.AddRandomInput<D, float>("PreCell", {batch, lstm_step});
net.CopyData<DeviceType::CPU, float>("Input", "InputCPU");
net.CopyData<DeviceType::CPU, float>("PreOutput", "PreOutputCPU");
net.CopyData<DeviceType::CPU, float>("Weight", "WeightCPU");
net.CopyData<DeviceType::CPU, float>("Bias", "BiasCPU");
net.CopyData<DeviceType::CPU, float>("PreCell", "PreCellCPU");
// Run on CPU
LSTMCellCPU<float>(&net, "InputCPU", "PreOutputCPU", "WeightCPU", "BiasCPU",
"PreCellCPU", forget_add, "CellCPU", "OutputCPU");
// Run
net.RunOp(DeviceType::CPU);
// Run on GPU
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "PreOutput", "PreOutputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Weight", "WeightImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(&net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
BufferToImage<D, T>(&net, "PreCell", "PreCellImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("LSTMCell", "LSTMCellTest")
.Input("InputImage")
.Input("PreOutputImage")
.Input("WeightImage")
.Input("BiasImage")
.Input("PreCellImage")
.AddFloatArg("forget_add", forget_add)
.Output("CellImage")
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
ImageToBuffer<D, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
ImageToBuffer<D, float>(&net, "CellImage", "Cell",
kernels::BufferType::IN_OUT_CHANNEL);
Tensor expected_cell, expected_output;
expected_cell.Copy(*net.GetOutput("CellCPU"));
expected_output.Copy(*net.GetOutput("OutputCPU"));
if (DataTypeToEnum<T>::value == DT_HALF) {
ExpectTensorNear<float>(expected_cell, *net.GetOutput("Cell"), 1e-3);
ExpectTensorNear<float>(expected_output, *net.GetOutput("Output"), 1e-3);
} else {
ExpectTensorNear<float>(expected_cell, *net.GetOutput("Cell"), 1e-5);
ExpectTensorNear<float>(expected_output, *net.GetOutput("Output"), 1e-5);
}
}
} // namespace
TEST_F(LSTMCellTest, OPENCLRandomHalf) {
TestLSTMCell<GPU, half>(1, 4, 0.0f);
TestLSTMCell<GPU, half>(2, 16, 0.0f);
TestLSTMCell<GPU, half>(2, 200, 0.5f);
TestLSTMCell<GPU, half>(20, 320, 0.5f);
}
TEST_F(LSTMCellTest, OPENCLRandomFloat) {
TestLSTMCell<GPU, float>(1, 4, 0.0f);
TestLSTMCell<GPU, float>(2, 16, 0.0f);
TestLSTMCell<GPU, float>(2, 200, 0.5f);
TestLSTMCell<GPU, float>(20, 320, 0.5f);
}
} // namespace test
} // namespace ops
} // namespace mace
...@@ -53,6 +53,7 @@ extern void Register_Shape(OperatorRegistryBase *op_registry); ...@@ -53,6 +53,7 @@ extern void Register_Shape(OperatorRegistryBase *op_registry);
extern void Register_Split(OperatorRegistryBase *op_registry); extern void Register_Split(OperatorRegistryBase *op_registry);
extern void Register_Softmax(OperatorRegistryBase *op_registry); extern void Register_Softmax(OperatorRegistryBase *op_registry);
extern void Register_Stack(OperatorRegistryBase *op_registry); extern void Register_Stack(OperatorRegistryBase *op_registry);
extern void Register_Unstack(OperatorRegistryBase *op_registry);
extern void Register_StridedSlice(OperatorRegistryBase *op_registry); extern void Register_StridedSlice(OperatorRegistryBase *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistryBase *op_registry); extern void Register_SpaceToBatchND(OperatorRegistryBase *op_registry);
extern void Register_SpaceToDepth(OperatorRegistryBase *op_registry); extern void Register_SpaceToDepth(OperatorRegistryBase *op_registry);
...@@ -64,6 +65,7 @@ extern void Register_WinogradTransform(OperatorRegistryBase *op_registry); ...@@ -64,6 +65,7 @@ extern void Register_WinogradTransform(OperatorRegistryBase *op_registry);
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
extern void Register_BufferToImage(OperatorRegistryBase *op_registry); extern void Register_BufferToImage(OperatorRegistryBase *op_registry);
extern void Register_ImageToBuffer(OperatorRegistryBase *op_registry); extern void Register_ImageToBuffer(OperatorRegistryBase *op_registry);
extern void Register_LSTMCell(OperatorRegistryBase *op_registry);
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
} // namespace ops } // namespace ops
...@@ -105,6 +107,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() { ...@@ -105,6 +107,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
ops::Register_Split(this); ops::Register_Split(this);
ops::Register_Softmax(this); ops::Register_Softmax(this);
ops::Register_Stack(this); ops::Register_Stack(this);
ops::Register_Unstack(this);
ops::Register_StridedSlice(this); ops::Register_StridedSlice(this);
ops::Register_SpaceToBatchND(this); ops::Register_SpaceToBatchND(this);
ops::Register_SpaceToDepth(this); ops::Register_SpaceToDepth(this);
...@@ -116,6 +119,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() { ...@@ -116,6 +119,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
ops::Register_BufferToImage(this); ops::Register_BufferToImage(this);
ops::Register_ImageToBuffer(this); ops::Register_ImageToBuffer(this);
ops::Register_LSTMCell(this);
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
} }
......
...@@ -201,6 +201,21 @@ class OpsTestNet { ...@@ -201,6 +201,21 @@ class OpsTestNet {
} }
} }
template <DeviceType D, typename T>
void CopyData(const std::string &src_name,
const std::string &dst_name) {
Tensor *input = ws_.GetTensor(src_name);
Tensor *output = ws_.CreateTensor(dst_name, GetDeviceAllocator(D),
DataTypeToEnum<T>::v());
const std::vector<index_t> input_shape = input->shape();
output->Resize(input_shape);
Tensor::MappingGuard input_guard(input);
const T *input_data = input->data<T>();
output->CopyBytes(input->raw_data(), input->size() * input->SizeOfType());
}
template <DeviceType D, typename T> template <DeviceType D, typename T>
void TransformDataFormat(const std::string &src_name, void TransformDataFormat(const std::string &src_name,
const DataFormat src_format, const DataFormat src_format,
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/unstack.h"
namespace mace {
namespace ops {
void Register_Unstack(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Unstack")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
UnstackOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Unstack")
.Device(DeviceType::CPU)
.TypeConstraint<int32_t>("T")
.Build(),
UnstackOp<DeviceType::CPU, int32_t>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_UNSTACK_H_
#define MACE_OPS_UNSTACK_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/unstack.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class UnstackOp : public Operator<D, T> {
public:
UnstackOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(OperatorBase::GetOptionalArg<int>("axis", 0)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const std::vector<Tensor *> outputs = this->Outputs();
return functor_(input, outputs, future);
}
private:
kernels::UnstackFunctor<D, T> functor_;
protected:
MACE_OP_OUTPUT_TAGS(INPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_UNSTACK_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class UnstackOpTest : public OpsTestBase {};
namespace {
void TestUnstack(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
int axis,
const std::vector<index_t> &output_shape,
const std::vector<std::vector<float>> &outputs) {
OpsTestNet net;
net.AddInputFromArray<CPU, float>("Input", input_shape, input);
auto op_builder = OpDefBuilder("Unstack", "UnstackOpTest")
.Input("Input")
.AddIntArg("axis", axis);
for (size_t i = 0; i < outputs.size(); ++i) {
op_builder.Output(MakeString("Output", i));
}
op_builder.Finalize(net.NewOperatorDef());
net.RunOp();
for (size_t i = 0; i < outputs.size(); ++i) {
LOG(INFO) << MakeString("Output", i);
net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape,
outputs[i]);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput(MakeString("Output", i).c_str()));
}
}
} // namespace
TEST_F(UnstackOpTest, TestUnstackScalar) {
TestUnstack({3}, {1, 2, 3}, 0, {}, {{1}, {2}, {3}});
}
TEST_F(UnstackOpTest, TestUnstackVector) {
TestUnstack({3, 2}, {1, 4, 2, 5, 3, 6}, 0, {2}, {{1, 4}, {2, 5}, {3, 6}});
TestUnstack({3, 2}, {1, 4, 2, 5, 3, 6}, -2, {2}, {{1, 4}, {2, 5}, {3, 6}});
TestUnstack({2, 3}, {1, 2, 3, 4, 5, 6}, 1, {2}, {{1, 4}, {2, 5}, {3, 6}});
}
TEST_F(UnstackOpTest, TestUnstackHighRank) {
TestUnstack({2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, -3, {2, 3},
{{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}});
TestUnstack({2, 2, 3}, {1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}, 1, {2, 3},
{{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}});
TestUnstack({2, 3, 2}, {1, 7, 2, 8, 3, 9, 4, 10, 5, 11, 6, 12}, 2, {2, 3},
{{1, 2, 3, 4, 5, 6}, {7, 8, 9, 10, 11, 12}});
}
} // namespace test
} // namespace ops
} // namespace mace
...@@ -44,18 +44,26 @@ def calculate_image_shape(buffer_type, shape, winograd_blk_size=0): ...@@ -44,18 +44,26 @@ def calculate_image_shape(buffer_type, shape, winograd_blk_size=0):
image_shape[0] = shape[1] image_shape[0] = shape[1]
image_shape[1] = shape[2] * shape[3] * roundup_div4(shape[0]) image_shape[1] = shape[2] * shape[3] * roundup_div4(shape[0])
elif buffer_type == OpenCLBufferType.IN_OUT_CHANNEL: elif buffer_type == OpenCLBufferType.IN_OUT_CHANNEL:
mace_check(len(shape) == 4, "Conv2D input/output buffer should be 4D") mace_check(len(shape) == 2 or len(shape) == 4,
image_shape[0] = roundup_div4(shape[3]) * shape[2] "input/output buffer should be 2D|4D")
image_shape[1] = shape[0] * shape[1] if len(shape) == 4:
image_shape[0] = roundup_div4(shape[3]) * shape[2]
image_shape[1] = shape[0] * shape[1]
elif len(shape) == 2:
image_shape[0] = roundup_div4(shape[1])
image_shape[1] = shape[0]
elif buffer_type == OpenCLBufferType.ARGUMENT: elif buffer_type == OpenCLBufferType.ARGUMENT:
mace_check(len(shape) == 1, mace_check(len(shape) == 1,
"Argument buffer should be 1D not " + str(shape)) "Argument buffer should be 1D not " + str(shape))
image_shape[0] = roundup_div4(shape[0]) image_shape[0] = roundup_div4(shape[0])
image_shape[1] = 1 image_shape[1] = 1
elif buffer_type == OpenCLBufferType.IN_OUT_HEIGHT: elif buffer_type == OpenCLBufferType.IN_OUT_HEIGHT:
mace_check(len(shape) == 4, "Input/output buffer should be 4D") if len(shape) == 4:
image_shape[0] = shape[2] * shape[3] image_shape[0] = shape[2] * shape[3]
image_shape[1] = shape[0] * roundup_div4(shape[1]) image_shape[1] = shape[0] * roundup_div4(shape[1])
elif len(shape) == 2:
image_shape[0] = shape[0]
image_shape[1] = roundup_div4(shape[1])
elif buffer_type == OpenCLBufferType.IN_OUT_WIDTH: elif buffer_type == OpenCLBufferType.IN_OUT_WIDTH:
mace_check(len(shape) == 4, "Input/output buffer should be 4D") mace_check(len(shape) == 4, "Input/output buffer should be 4D")
image_shape[0] = roundup_div4(shape[2]) * shape[3] image_shape[0] = roundup_div4(shape[2]) * shape[3]
......
...@@ -93,6 +93,7 @@ MaceSupportedOps = [ ...@@ -93,6 +93,7 @@ MaceSupportedOps = [
'Gather', 'Gather',
'Identity', 'Identity',
'LocalResponseNorm', 'LocalResponseNorm',
'LSTMCell',
'MatMul', 'MatMul',
'Pad', 'Pad',
'Pooling', 'Pooling',
...@@ -107,6 +108,7 @@ MaceSupportedOps = [ ...@@ -107,6 +108,7 @@ MaceSupportedOps = [
'Shape', 'Shape',
'Squeeze', 'Squeeze',
'Stack', 'Stack',
'Unstack',
'StridedSlice', 'StridedSlice',
'Softmax', 'Softmax',
'SpaceToBatchND', 'SpaceToBatchND',
...@@ -198,6 +200,9 @@ class TransformerRule(Enum): ...@@ -198,6 +200,9 @@ class TransformerRule(Enum):
QUANTIZE_NODES = 23 QUANTIZE_NODES = 23
ADD_QUANTIZE_TENSOR_RANGE = 24 ADD_QUANTIZE_TENSOR_RANGE = 24
QUANTIZE_WEIGHTS = 25 QUANTIZE_WEIGHTS = 25
TRANSPOSE_MATMUL_WEIGHT = 26
TRANSFORM_LSTMCELL_ZEROSTATE = 27
TRANSFORM_BASIC_LSTMCELL = 28
class ConverterInterface(object): class ConverterInterface(object):
...@@ -336,6 +341,8 @@ class ConverterOption(object): ...@@ -336,6 +341,8 @@ class ConverterOption(object):
# Model structure related transformation # Model structure related transformation
TransformerRule.REMOVE_IDENTITY_OP, TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING, TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
TransformerRule.TRANSFORM_BASIC_LSTMCELL,
TransformerRule.FOLD_RESHAPE, TransformerRule.FOLD_RESHAPE,
TransformerRule.TRANSFORM_MATMUL_TO_FC, TransformerRule.TRANSFORM_MATMUL_TO_FC,
TransformerRule.FOLD_BATCHNORM, TransformerRule.FOLD_BATCHNORM,
......
...@@ -96,6 +96,8 @@ TFSupportedOps = [ ...@@ -96,6 +96,8 @@ TFSupportedOps = [
'Slice', 'Slice',
'Stack', 'Stack',
'Pack', 'Pack',
'Unstack',
'Unpack',
'Cast', 'Cast',
'ArgMax', 'ArgMax',
'Split', 'Split',
...@@ -196,6 +198,8 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -196,6 +198,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Slice.name: self.convert_slice, TFOpType.Slice.name: self.convert_slice,
TFOpType.Pack.name: self.convert_stack, TFOpType.Pack.name: self.convert_stack,
TFOpType.Stack.name: self.convert_stack, TFOpType.Stack.name: self.convert_stack,
TFOpType.Unpack.name: self.convert_unstack,
TFOpType.Unstack.name: self.convert_unstack,
TFOpType.Cast.name: self.convert_cast, TFOpType.Cast.name: self.convert_cast,
TFOpType.ArgMax.name: self.convert_argmax, TFOpType.ArgMax.name: self.convert_argmax,
TFOpType.Split.name: self.convert_split, TFOpType.Split.name: self.convert_split,
...@@ -774,6 +778,17 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -774,6 +778,17 @@ class TensorflowConverter(base_converter.ConverterInterface):
except ValueError: except ValueError:
axis_arg.i = 0 axis_arg.i = 0
def convert_unstack(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Unstack.name
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
try:
axis_arg.i = tf_op.get_attr(MaceKeyword.mace_axis_str)
except ValueError:
axis_arg.i = 0
def convert_cast(self, tf_op): def convert_cast(self, tf_op):
op = self.convert_general_op(tf_op) op = self.convert_general_op(tf_op)
op.type = MaceOp.Cast.name op.type = MaceOp.Cast.name
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import enum import enum
import numpy as np import numpy as np
import re
from mace.proto import mace_pb2 from mace.proto import mace_pb2
from mace.python.tools.converter_tool import base_converter from mace.python.tools.converter_tool import base_converter
...@@ -49,6 +50,10 @@ class Transformer(base_converter.ConverterInterface): ...@@ -49,6 +50,10 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op, TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op,
TransformerRule.TRANSFORM_GLOBAL_POOLING: TransformerRule.TRANSFORM_GLOBAL_POOLING:
self.transform_global_pooling, self.transform_global_pooling,
TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE:
self.transform_lstmcell_zerostate,
TransformerRule.TRANSFORM_BASIC_LSTMCELL:
self.transform_basic_lstmcell,
TransformerRule.FOLD_RESHAPE: self.fold_reshape, TransformerRule.FOLD_RESHAPE: self.fold_reshape,
TransformerRule.TRANSFORM_MATMUL_TO_FC: TransformerRule.TRANSFORM_MATMUL_TO_FC:
self.transform_matmul_to_fc, self.transform_matmul_to_fc,
...@@ -332,6 +337,154 @@ class Transformer(base_converter.ConverterInterface): ...@@ -332,6 +337,154 @@ class Transformer(base_converter.ConverterInterface):
return False return False
def transform_lstmcell_zerostate(self):
net = self._model
zero_state_pattern = \
re.compile(r'^.*BasicLSTMCellZeroState_?[0-9]*/[a-zA-Z]+_?[0-9]*') # noqa
for op in net.op:
if op.type == MaceOp.Fill.name and \
zero_state_pattern.match(op.name):
print("Transform lstm zerostate")
concat_op = self._producer[op.input[0]]
consumer_op = self._consumers[op.output[0]][0]
dims = [self._consts[concat_op.input[0]].int32_data[0],
self._consts[concat_op.input[1]].int32_data[0]]
tensor_def = net.tensors.add()
tensor_def.name = op.output[0].replace('/zeros', '/init_const')
tensor_def.dims.extend(dims)
tensor_def.data_type = self._consts[op.input[1]].data_type
tensor_def.float_data.extend(
[self._consts[op.input[1]].float_data[0]] *
(dims[0] * dims[1]))
for i in range(len(consumer_op.input)):
if zero_state_pattern.match(consumer_op.input[i][:-2]):
consumer_op.input[i] = tensor_def.name
net.tensors.remove(self._consts[op.input[1]])
net.tensors.remove(self._consts[concat_op.input[0]])
net.tensors.remove(self._consts[concat_op.input[1]])
net.op.remove(concat_op)
net.op.remove(op)
return True
def transform_basic_lstmcell(self):
if self._option.device != DeviceType.GPU.value:
return False
net = self._model
basic_lstm_concat_pattern = \
re.compile(r'^.*basic_lstm_cell_?[0-9]*/concat_?[0-9]*')
for op in net.op:
if op.type == MaceOp.Concat.name and \
basic_lstm_concat_pattern.match(op.name):
print("Transform basic lstmcell")
ops_to_delete = []
ops_to_delete.extend([op])
op_def = net.op.add()
op_def.name = op.name.replace('/concat', '/folded_lstmcell')
op_def.type = MaceOp.LSTMCell.name
op_def.arg.extend(op.arg[:-1])
# Concat pre output and cur input
# extend concat inputs
op_def.input.extend([op_input for op_input in op.input])
# lstm MatMul in FC of [pre_output, cur_input]
matmul_op = self._consumers[op.output[0]][0]
ops_to_delete.extend([matmul_op])
# extend MatMul weight input
op_def.input.extend([matmul_op.input[1]])
# lstm BiasAdd in FC of [pre_output, cur_input]
biasadd_op = self._consumers[matmul_op.output[0]][0]
ops_to_delete.extend([biasadd_op])
# extend BiasAdd bias input
op_def.input.extend([biasadd_op.input[1]])
# Split FC output into i, j, f, o
# i = input_gate, j = new_input, f = forget_gate, o = output_gate # noqa
split_op = self._consumers[biasadd_op.output[0]][0]
ops_to_delete.extend([split_op])
# input gate activation
input_gate_op = self._consumers[split_op.output[0]][0]
ops_to_delete.extend([input_gate_op])
# new input gate
new_input_tanh_op = self._consumers[split_op.output[1]][0]
ops_to_delete.extend([new_input_tanh_op])
# forget gate add
forget_add_op = self._consumers[split_op.output[2]][0]
ops_to_delete.extend([forget_add_op])
# output gate activation
output_gate_op = self._consumers[split_op.output[3]][0]
ops_to_delete.extend([output_gate_op])
# extend forget add
mace_check(len(forget_add_op.input) == 1,
'Wrong LSTM format in forget gate inputs')
for arg in forget_add_op.arg:
if arg.name == MaceKeyword.mace_scalar_input_str:
op_def.arg.extend([arg])
# state remember
remember_mul_op = self._consumers[input_gate_op.output[0]][0]
ops_to_delete.extend([remember_mul_op])
mace_check(remember_mul_op.name == self._consumers[
new_input_tanh_op.output[0]][0].name,
'Wrong LSTM format in input sig & input tanh mul')
# forget gate activation
forget_gate_op = self._consumers[forget_add_op.output[0]][0]
ops_to_delete.extend([forget_gate_op])
# Mul `forget` & `pre cell state`
forget_mul_op = self._consumers[forget_gate_op.output[0]][0]
ops_to_delete.extend([forget_mul_op])
# extend pre cell state input
op_def.input.extend([forget_mul_op.input[0]])
# get cur cell state
# Add `forget gate output` & `remember mul output`
remember_forget_add_op = \
self._consumers[remember_mul_op.output[0]][0]
ops_to_delete.extend([remember_forget_add_op])
mace_check(remember_forget_add_op.name ==
self._consumers[forget_mul_op.output[0]][0].name,
'Wrong LSTM format in add forget gate & remember mul') # noqa
op_def.output.extend([remember_forget_add_op.output[0]])
op_def.output_shape.extend(remember_forget_add_op.output_shape)
# cell state output tanh
for consumer in \
self._consumers[remember_forget_add_op.output[0]]:
if consumer.type == MaceOp.Activation.name and \
consumer.name.find('basic_lstm_cell') > 0:
cell_tanh_op = consumer
ops_to_delete.extend([cell_tanh_op])
# final mul, get output
final_mul_op = self._consumers[cell_tanh_op.output[0]][0]
ops_to_delete.extend([final_mul_op])
mace_check(final_mul_op.name ==
self._consumers[output_gate_op.output[0]][0].name,
'Wrong LSTM format in final mul')
op_def.output.extend([final_mul_op.output[0]])
op_def.output_shape.extend(final_mul_op.output_shape)
for op_to_del in ops_to_delete:
net.op.remove(op_to_del)
return True
return False
def fold_conv_and_bn(self): def fold_conv_and_bn(self):
net = self._model net = self._model
for op in net.op: for op in net.op:
...@@ -1156,6 +1309,15 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1156,6 +1309,15 @@ class Transformer(base_converter.ConverterInterface):
if ConverterUtil.get_arg(op, if ConverterUtil.get_arg(op,
MaceKeyword.mace_activation_type_str).s == ActivationType.PRELU.name: # noqa MaceKeyword.mace_activation_type_str).s == ActivationType.PRELU.name: # noqa
self.buffer_to_image(op, 1, OpenCLBufferType.ARGUMENT) self.buffer_to_image(op, 1, OpenCLBufferType.ARGUMENT)
elif op.type == MaceOp.LSTMCell.name:
if op.input[1] in self._consts:
self.buffer_to_image(op, 1,
OpenCLBufferType.IN_OUT_CHANNEL)
self.buffer_to_image(op, 2, OpenCLBufferType.IN_OUT_CHANNEL)
self.buffer_to_image(op, 3, OpenCLBufferType.ARGUMENT)
if op.input[4] in self._consts:
self.buffer_to_image(op, 4,
OpenCLBufferType.IN_OUT_CHANNEL)
# Add OpenCL max image size # Add OpenCL max image size
arg = net.arg.add() arg = net.arg.add()
......
...@@ -240,7 +240,7 @@ class GPUMemoryOptimizer(MemoryOptimizer): ...@@ -240,7 +240,7 @@ class GPUMemoryOptimizer(MemoryOptimizer):
op_type) op_type)
else: else:
if len(output_shape) == 2: # only support fc/softmax if len(output_shape) == 2: # only support fc/softmax
buffer_shape = [output_shape[0], 1, 1, output_shape[1]] buffer_shape = [output_shape[0], output_shape[1]]
elif len(output_shape) == 4: elif len(output_shape) == 4:
buffer_shape = output_shape buffer_shape = output_shape
else: else:
......
...@@ -37,6 +37,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx): ...@@ -37,6 +37,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/depthwise_conv2d.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/depthwise_conv2d.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/eltwise.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/eltwise.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/fully_connected.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/fully_connected.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/lstmcell.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/matmul.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/matmul.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pad.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pad.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pooling.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pooling.cl"))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册