Commit 0780433b authored by Robert David, committed by TensorFlower Gardener

LSTM: Split projection calculations to separate functions.

PiperOrigin-RevId: 318157584
Change-Id: I9bbee585b18fc508e6afb166e394e8d4bf50940d
Parent 690a77bf
@@ -127,6 +127,210 @@ inline float GetTensorScale(const TfLiteTensor* tensor) {
return tensor == nullptr ? 1.0f : tensor->params.scale;
}
// Calculates the output state tensor of an LSTM step.
//
// Implements the following formula:
// output_no_projection = output_gate .* activate(cell_state)
// (elementwise vector product)
// If no projection is used:
// output = output_state = output_no_projection
// With projection:
// output = output_state = clip(W*output_no_projection + bias)
//
// The output tensor may use a different batch stride than n_output, so the
// caller copies output_state into it.
//
// Parameters:
// - n_batch: the number of distinct vectors (batches) in each array.
// - n_cell, n_output: sizes of vectors.
// - cell_state, output_gate: input vectors, size n_batch*n_cell.
// - activation: the activation applied to the cell state (typically tanh).
// - projection_weights, projection_bias:
// constant inputs, describing projection matrix and bias.
// - proj_clip: if > 0, clip the output of the projection.
// - output_state: output vector, size n_batch*n_output. Must be contiguous.
// - scratch: scratch area, size n_batch*n_cell.
// LINT.IfChange
void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output,
const float* cell_state, const float* output_gate,
TfLiteFusedActivation activation,
const float* projection_weights,
const float* projection_bias,
const float proj_clip, float* output_state,
float* scratch) {
tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
activation, scratch);
tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell,
scratch);
const bool use_projection = (projection_weights != nullptr);
const bool use_projection_bias = (projection_bias != nullptr);
if (use_projection) {
if (use_projection_bias) {
tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
output_state);
} else {
std::fill_n(output_state, n_batch * n_output, 0.0f);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights, n_output, n_cell, scratch, n_batch, output_state);
if (proj_clip > 0.0f) {
tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
}
} else {
std::copy_n(scratch, n_batch * n_output, output_state);
}
}
// LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc)
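For reference, the formula documented above can also be written as a plain scalar loop. The sketch below is illustrative only and is not part of this commit; it assumes a tanh activation, row-major projection weights of shape n_output x n_cell, and a hypothetical helper name:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Unoptimized reference for the float output computation:
//   output_no_projection = output_gate .* tanh(cell_state)
//   output_state = clip(W * output_no_projection + bias)   (with projection)
//   output_state = output_no_projection                    (without projection)
void ReferenceLstmOutputFloat(int n_batch, int n_cell, int n_output,
                              const float* cell_state, const float* output_gate,
                              const float* projection_weights,  // may be null
                              const float* projection_bias,     // may be null
                              float proj_clip, float* output_state) {
  std::vector<float> hidden(n_batch * n_cell);
  for (int i = 0; i < n_batch * n_cell; ++i) {
    hidden[i] = output_gate[i] * std::tanh(cell_state[i]);
  }
  if (projection_weights == nullptr) {
    // Without projection, n_output == n_cell and the hidden vector is copied.
    std::copy_n(hidden.data(), n_batch * n_output, output_state);
    return;
  }
  for (int b = 0; b < n_batch; ++b) {
    for (int o = 0; o < n_output; ++o) {
      float acc = projection_bias != nullptr ? projection_bias[o] : 0.0f;
      for (int c = 0; c < n_cell; ++c) {
        acc += projection_weights[o * n_cell + c] * hidden[b * n_cell + c];
      }
      if (proj_clip > 0.0f) {
        acc = std::min(std::max(acc, -proj_clip), proj_clip);
      }
      output_state[b * n_output + o] = acc;
    }
  }
}
```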
// Calculates the output state tensor of an LSTM step. See Float version too.
//
// Parameters:
// - n_batch: the number of distinct vectors (batches) in each array.
// - n_cell, n_output: sizes of vectors.
// - cell_state, output_gate: input vectors, size n_batch*n_cell.
// - projection_weights, projection_weights_scale, projection_bias:
// constant inputs, describing projection matrix and bias.
// - proj_clip: if > 0, clip the output of the projection.
// - output_state: output vector, size n_batch*n_output. Must be contiguous.
// - asymmetric_quantize_inputs: whether the inputs are quantized asymmetrically.
// - projection_weights_row_sums, compute_row_sums, context: Data for optimized
// MatrixBatchVectorMultiplyAccumulate.
// - scratch0: scratch area of size n_batch*n_cell
// - scratch1: scratch area of size n_batch*n_cell
// - scratch2: scratch area of size n_batch
// - scratch3: scratch area of size n_batch
// - scratch4: scratch area used by MatrixBatchVectorMultiplyAccumulate
void CalculateLstmOutputHybrid(
int n_batch, int n_cell, int n_output, const float* cell_state,
const float* output_gate, TfLiteFusedActivation activation,
const int8_t* projection_weights, float projection_weights_scale,
const float* projection_bias, const float proj_clip, float* output_state,
bool asymmetric_quantize_inputs, int32_t* projection_weights_row_sums,
bool* compute_row_sums, CpuBackendContext* context, float* scratch0,
int8_t* scratch1, float* scratch2, int32_t* scratch3, int32_t* scratch4) {
tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
activation, scratch0);
tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0,
n_batch * n_cell, scratch0);
const bool use_projection = (projection_weights != nullptr);
const bool use_projection_bias = (projection_bias != nullptr);
if (use_projection) {
if (use_projection_bias) {
tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
output_state);
} else {
std::fill_n(output_state, n_batch * n_output, 0.0f);
}
if (!tensor_utils::IsZeroVector(scratch0, n_batch * n_cell)) {
// Skip the quantization and matmul when the pre-projection output is all zeros.
tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell, scratch1,
scratch2, scratch3,
asymmetric_quantize_inputs);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights, n_output, n_cell, scratch1,
projection_weights_scale, scratch2, n_batch, output_state,
/*per_channel_scale=*/nullptr,
asymmetric_quantize_inputs ? scratch3 : nullptr, scratch4,
projection_weights_row_sums, compute_row_sums, scratch2, context);
}
if (proj_clip > 0.0f) {
tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
}
} else {
std::copy_n(scratch0, n_batch * n_output, output_state);
}
}
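For context on the hybrid path above: before the int8 projection matmul, BatchQuantizeFloats quantizes each batch row of the float hidden vector and records a per-row scaling factor that is folded back into the accumulated result. The sketch below shows only the symmetric per-row quantization idea under simplifying assumptions (no asymmetric mode, zero points, or row-sum corrections); the function name and rounding scheme are illustrative, not the library's implementation:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Simplified, symmetric per-row quantization of a float activation matrix of
// shape [n_batch, n_cell] into int8, producing one scaling factor per row.
void QuantizeBatchSymmetric(const float* values, int n_batch, int n_cell,
                            int8_t* quantized, float* scaling_factors) {
  for (int b = 0; b < n_batch; ++b) {
    const float* row = values + b * n_cell;
    float max_abs = 0.0f;
    for (int i = 0; i < n_cell; ++i) {
      max_abs = std::max(max_abs, std::fabs(row[i]));
    }
    const float scale = max_abs > 0.0f ? max_abs / 127.0f : 1.0f;
    scaling_factors[b] = scale;
    for (int i = 0; i < n_cell; ++i) {
      const int32_t q = static_cast<int32_t>(std::round(row[i] / scale));
      quantized[b * n_cell + i] =
          static_cast<int8_t>(std::min(127, std::max(-128, q)));
    }
  }
}
```

The IsZeroVector check in the function above exists precisely so that this quantization and the subsequent matmul can be skipped when the hidden vector is all zeros.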
// Calculates the output state tensor of an LSTM step. See Float and hybrid
// versions as well.
//
// Parameters:
// - n_batch: the number of distinct vectors (batches) in each array.
// - n_cell, n_output: sizes of vectors.
// - cell_state, output_gate: input vectors, size n_batch*n_cell.
// - cell_state_scale: scaling of cell_state.
// - effective_hidden_scale_[a|b]: effective scale of cell_state.*output_gate
// - hidden_zp: zero_point for cell_state.*output_gate
// - projection_weights, effective_proj_scale_[a|b], projection_effective_bias:
// constant inputs, describing projection matrix and bias.
// - output_state_zp: zero point of output_state. (Input, calibrated value.)
// - quantized_proj_clip: if > 0, clip the output of the projection.
// - output_state: output vector, size n_batch*n_output. Must be contiguous.
// - context: data for optimized MatrixBatchVectorMultiplyAccumulate.
// - scratch0: scratch area of size n_batch*n_cell
// - scratch1: scratch area of size n_batch*n_cell
// - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
void CalculateLstmOutputInteger8x8_16(
int n_batch, int n_cell, int n_output, const int16_t* cell_state,
int32_t cell_state_scale, const int16_t* output_gate,
int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b,
int32_t hidden_zp, const int8_t* projection_weights,
int32_t effective_proj_scale_a, int32_t effective_proj_scale_b,
const int32_t* projection_effective_bias, int32_t output_state_zp,
int8_t quantized_proj_clip, int8_t* output_state,
CpuBackendContext* context, int16_t* scratch0, int8_t* scratch1,
int32_t* scratch2) {
// Note: unlike float/hybrid, the activation is always Tanh.
tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch, n_cell,
scratch0);
tensor_utils::CwiseMul(output_gate, scratch0, effective_hidden_scale_a,
effective_hidden_scale_b, n_batch, n_cell, hidden_zp,
scratch1);
const bool use_projection = (projection_weights != nullptr);
if (use_projection) {
// Note: unlike float/hybrid, there is no separate bias-assign step; the
// effective bias is applied inside the matmul below.
std::fill_n(output_state, n_batch * n_output, 0);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
scratch1, projection_effective_bias, projection_weights,
effective_proj_scale_a, effective_proj_scale_b, n_batch, n_cell,
n_output, output_state_zp, scratch2, output_state, context);
if (quantized_proj_clip > 0) {
tensor_utils::CwiseClipping(output_state, n_batch * n_output,
quantized_proj_clip);
}
} else {
std::copy_n(scratch1, n_batch * n_output, output_state);
}
}
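To make the quantization parameters of the 8x8_16 variant concrete, here is a float reference of what the fixed-point hidden computation represents: dequantize the cell state and output gate, apply tanh and the gate, then requantize to int8 around hidden_zp. This is only an illustrative sketch under common assumptions (Q0.15 gates, a power-of-two cell-state scale); the helper name and the real-valued scale parameters are hypothetical stand-ins for effective_hidden_scale_[a|b]:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Float reference for the quantized hidden computation:
//   hidden_real = output_gate_real * tanh(cell_state_real)
//   hidden_int8 = round(hidden_real / hidden_real_scale) + hidden_zp
void HiddenFloatReference(int n, const int16_t* cell_state,
                          float cell_state_real_scale,  // e.g. 2^cell_state_scale
                          const int16_t* output_gate,   // assumed Q0.15
                          float hidden_real_scale, int32_t hidden_zp,
                          int8_t* hidden) {
  for (int i = 0; i < n; ++i) {
    const float c = cell_state[i] * cell_state_real_scale;
    const float o = output_gate[i] / 32768.0f;  // Q0.15 -> real value
    const float h = o * std::tanh(c);
    const int32_t q =
        static_cast<int32_t>(std::round(h / hidden_real_scale)) + hidden_zp;
    hidden[i] = static_cast<int8_t>(std::min(127, std::max(-128, q)));
  }
}
```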
// Calculates the output state tensor of an LSTM step. See Float and hybrid
// versions as well.
//
// Parameters:
// - n_batch: the number of distinct vectors (batches) in each array.
// - n_cell, n_output: sizes of vectors.
// - cell_state, output_gate: input vectors, size n_batch*n_cell.
// - projection_weights, effective_proj_scale_[a|b], projection_bias:
// constant inputs, describing projection matrix and bias.
// - output_state_zp: zero point of the output state.
// - quantized_proj_clip: if > 0, clip the output of the projection.
// - output_state: output vector, size n_batch*n_output. Must be contiguous.
// - scratch: scratch area of size n_batch*n_cell
void CalculateLstmOutputInteger8x8_8(
int n_batch, int n_cell, int n_output, const int16_t* cell_state,
const int16_t* output_gate, const int8_t* projection_weights,
int32_t effective_proj_scale_a, int32_t effective_proj_scale_b,
const int32_t* projection_bias, int32_t output_state_zp,
int32_t quantized_proj_clip, int8_t* output_state, int16_t* scratch) {
// Note: unlike float/hybrid, the activation is always Tanh.
tensor_utils::ApplyTanhFloat(cell_state, n_batch, n_cell, -15, scratch);
tensor_utils::CwiseMul(output_gate, scratch, n_batch, n_cell, 15 + 15 - 15,
scratch);
// Note: unlike float/hybrid, there is no separate bias-assign step; the bias
// is passed to MatrixBatchVectorMultiply below.
tensor_utils::MatrixBatchVectorMultiply(
scratch, projection_weights, effective_proj_scale_a,
effective_proj_scale_b, projection_bias, n_batch, n_cell, n_output,
output_state_zp, output_state);
if (quantized_proj_clip > 0) {
tensor_utils::CwiseClipping(output_state, n_batch * n_output,
quantized_proj_clip);
}
}
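Similarly, the projection in the 8x8_8 variant is an integer matmul whose int32 accumulation (plus bias) is rescaled and offset by the output zero point, with the optional clip applied afterwards as shown above. The sketch below is a simplified reference, not the library's MatrixBatchVectorMultiply: the hypothetical float effective_scale stands in for the (multiplier, shift) pair effective_proj_scale_[a|b]:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Simplified reference for the quantized projection:
//   out[b][o] = clamp(round((bias[o] + sum_c W[o][c] * hidden[b][c])
//                           * effective_scale) + output_zp)
void ReferenceQuantizedProjection(int n_batch, int n_cell, int n_output,
                                  const int16_t* hidden, const int8_t* weights,
                                  const int32_t* bias,  // may be null
                                  float effective_scale, int32_t output_zp,
                                  int8_t* output) {
  for (int b = 0; b < n_batch; ++b) {
    for (int o = 0; o < n_output; ++o) {
      int64_t acc = bias != nullptr ? bias[o] : 0;
      for (int c = 0; c < n_cell; ++c) {
        acc += static_cast<int64_t>(weights[o * n_cell + c]) *
               hidden[b * n_cell + c];
      }
      const int32_t rescaled =
          static_cast<int32_t>(std::round(acc * effective_scale)) + output_zp;
      output[b * n_output + o] =
          static_cast<int8_t>(std::min(127, std::max(-128, rescaled)));
    }
  }
}
```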
// Performs an LSTM batch inference step for input specified by input_ptr.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
// biases (*_bias_ptr), and buffers (*_scratch), along with additional
@@ -395,32 +599,12 @@ inline void LstmStepFloat(
}
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
params->activation, cell_gate_scratch);
tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_gate_scratch,
n_batch * n_cell, output_gate_scratch);
const bool use_projection_weight = (projection_weights_ptr != nullptr);
const bool use_projection_bias = (projection_bias_ptr != nullptr);
CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr,
output_gate_scratch, params->activation,
projection_weights_ptr, projection_bias_ptr,
params->proj_clip, output_state_ptr, scratch2);
// For each batch: update output_state.
if (use_projection_weight) {
if (use_projection_bias) {
tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
n_batch, output_state_ptr);
} else {
std::fill_n(output_state_ptr, n_batch * n_output, 0.0f);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
output_state_ptr);
if (params->proj_clip > 0.0) {
tensor_utils::CwiseClipping(output_state_ptr, n_batch * n_output,
params->proj_clip);
}
} else {
std::copy_n(output_gate_scratch, n_batch * n_output, output_state_ptr);
}
// Copy output_state to the output. Note that the output batch rows may not be
// contiguous (output_batch_leading_dim != n_output).
for (int b = 0; b < n_batch; b++) {
@@ -861,44 +1045,17 @@ inline void LstmStepHybrid(
}
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
params->activation, cell_gate_scratch);
tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_gate_scratch,
n_batch * n_cell, output_gate_scratch);
const bool use_projection_weight = (projection_weights_ptr != nullptr);
const bool use_projection_bias = (projection_bias_ptr != nullptr);
CalculateLstmOutputHybrid(
n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
params->activation, projection_weights_ptr, projection_weights_scale,
projection_bias_ptr, params->proj_clip, output_state_ptr,
asymmetric_quantize_inputs, projection_weights_row_sums, compute_row_sums,
context, scratch2, quantized_output_scratch, scaling_factors, zero_points,
accum_scratch_ptr);
// For each batch: update the projection and output_state. Note that since
// the output batch rows may not be contiguous (output_batch_leading_dim !=
// n_output), we unroll the batched operations.
if (use_projection_weight) {
if (use_projection_bias) {
tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
n_batch, output_state_ptr);
} else {
std::fill_n(output_state_ptr, n_batch * n_output, 0.0f);
}
if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
// Save quantization and matmul computation for all zero input.
tensor_utils::BatchQuantizeFloats(
output_gate_scratch, n_batch, n_cell, quantized_output_scratch,
scaling_factors, zero_points, asymmetric_quantize_inputs);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights_ptr, n_output, n_cell, quantized_output_scratch,
projection_weights_scale, scaling_factors, n_batch, output_state_ptr,
/*per_channel_scale=*/nullptr,
asymmetric_quantize_inputs ? zero_points : nullptr, accum_scratch_ptr,
projection_weights_row_sums, compute_row_sums,
scaling_factors_scratch, context);
}
if (params->proj_clip > 0.0) {
tensor_utils::CwiseClipping(output_state_ptr, n_batch * n_output,
params->proj_clip);
}
} else {
std::copy_n(output_gate_scratch, n_batch * n_output, output_state_ptr);
}
// Copy output_state_ptr to the output. Note that the output batch rows may
// not be contiguous (output_batch_leading_dim != n_output).
for (int b = 0; b < n_batch; b++) {
std::copy_n(output_state_ptr + b * n_output, n_output,
output_ptr + b * output_batch_leading_dim);
@@ -1071,7 +1228,6 @@ inline void LstmStepInteger8x8_16(
const bool use_cifg = (input_to_input_weight_ptr == nullptr);
const bool use_peephole = (cell_to_output_weight_ptr != nullptr);
const bool use_layer_norm = (layer_norm_forget_weight_ptr != nullptr);
const bool use_projection = (projection_weight_ptr != nullptr);
// Check for nullptrs.
TFLITE_DCHECK(input_to_forget_effective_bias);
@@ -1219,28 +1375,17 @@ inline void LstmStepInteger8x8_16(
tensor_utils::ApplySigmoid(output_gate_scratch, n_batch, n_cell,
output_gate_scratch);
// Hidden.
tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state_ptr, n_batch,
n_cell, input_gate_scratch);
tensor_utils::CwiseMul(output_gate_scratch, input_gate_scratch,
effective_hidden_scale_a, effective_hidden_scale_b,
n_batch, n_cell, hidden_zp, scratch4);
// Projection.
if (use_projection) {
std::fill_n(output_ptr, n_batch * n_output, 0);
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
scratch4, projection_effective_bias, projection_weight_ptr,
effective_proj_scale_a, effective_proj_scale_b, n_batch, n_cell,
n_output, output_state_zp, scratch5, output_ptr, context);
if (quantized_proj_clip > 0) {
tensor_utils::CwiseClipping(output_ptr, n_batch * n_output,
quantized_proj_clip);
}
} else {
std::copy_n(scratch4, n_batch * n_output, output_ptr);
}
std::copy_n(output_ptr, n_batch * n_output, output_state_ptr);
CalculateLstmOutputInteger8x8_16(
n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale,
output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b,
hidden_zp, projection_weight_ptr, effective_proj_scale_a,
effective_proj_scale_b, projection_effective_bias, output_state_zp,
quantized_proj_clip, output_state_ptr, context, scratch0, scratch4,
scratch5);
// Copy output state to the output. Note that unlike float or hybrid, output
// is always contiguous.
std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
}
// Fully quantized lstm kernel for 8 bit gate matmul output.
@@ -1502,27 +1647,15 @@ inline void LstmStepInteger8x8_8(
quantized_cell_clip);
}
// Cell to hidden.
tensor_utils::ApplyTanhFloat(cell_state_ptr, n_batch, n_cell, -15,
forget_gate_scratch);
tensor_utils::CwiseMul(output_gate_scratch, forget_gate_scratch, n_batch,
n_cell, 15 + 15 - 15, cell_gate_scratch);
// Projection.
tensor_utils::MatrixBatchVectorMultiply(
cell_gate_scratch, projection_weight_ptr, effective_proj_scale_a,
effective_proj_scale_b, projection_bias_ptr, n_batch, n_cell, n_output,
output_state_zp, output_ptr);
// Projection clipping.
if (quantized_proj_clip > 0) {
tensor_utils::CwiseClipping(output_ptr, n_batch * n_output,
quantized_proj_clip);
}
CalculateLstmOutputInteger8x8_8(
n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
projection_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
projection_bias_ptr, output_state_zp, quantized_proj_clip,
output_state_ptr, scratch2);
// Copy output to output state.
std::copy_n(output_ptr, n_batch * n_output, output_state_ptr);
// Copy output state to the output. Note that unlike float or hybrid, output
// is always contiguous.
std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
}
} // namespace
......
@@ -37,6 +37,41 @@ namespace builtin {
namespace {
void CalculateLstmOutputFloat(
int n_batch, int n_cell, int n_output, const float* cell_state,
const float* output_gate, TfLiteFusedActivation activation,
const float* projection_weights, const float* projection_bias,
const float proj_clip, float* output_state, float* scratch, Logger* logger,
const std::vector<int>& intermediate_tensor_indexes,
ErrorReporter* error_reporter) {
tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
activation, scratch);
tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell,
scratch);
logger->LogTensorValue(intermediate_tensor_indexes[4], scratch,
n_cell * n_batch, error_reporter);
const bool use_projection = (projection_weights != nullptr);
const bool use_projection_bias = (projection_bias != nullptr);
if (use_projection) {
if (use_projection_bias) {
tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
output_state);
} else {
std::fill_n(output_state, n_batch * n_output, 0.0f);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights, n_output, n_cell, scratch, n_batch, output_state);
if (proj_clip > 0.0f) {
tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
}
} else {
std::copy_n(scratch, n_batch * n_output, output_state);
}
}
inline void LstmStepWithAuxInput(
const float* input_ptr, const float* input_to_input_weights_ptr,
const float* input_to_forget_weights_ptr,
@@ -245,35 +280,13 @@ inline void LstmStepWithAuxInput(
}
tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
params->activation, cell_gate_scratch);
tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_gate_scratch,
n_batch * n_cell, output_gate_scratch);
logger->LogTensorValue(intermediate_tensor_indexes[4], output_gate_scratch,
n_cell * n_batch, error_reporter);
const bool use_projection_weight = (projection_weights_ptr != nullptr);
const bool use_projection_bias = (projection_bias_ptr != nullptr);
CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr,
output_gate_scratch, params->activation,
projection_weights_ptr, projection_bias_ptr,
params->proj_clip, output_state_ptr, scratch2,
logger, intermediate_tensor_indexes, error_reporter);
// For each batch: update output_state.
if (use_projection_weight) {
if (use_projection_bias) {
tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
n_batch, output_state_ptr);
} else {
std::fill_n(output_state_ptr, n_batch * n_output, 0.0f);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
output_state_ptr);
if (params->proj_clip > 0.0) {
tensor_utils::CwiseClipping(output_state_ptr, n_batch * n_output,
params->proj_clip);
}
} else {
std::copy_n(output_gate_scratch, n_batch * n_output, output_state_ptr);
}
// Copy output_state to the output. Note that the output batch rows may not be
// contiguous (output_batch_leading_dim != n_output).
for (int b = 0; b < n_batch; b++) {
......