/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include "paddle/operators/math/detail/activation_functions.h" #include "paddle/operators/math/gru_compute.h" namespace paddle { namespace operators { namespace math { namespace detail { #ifndef __NVCC__ template void hl_naive_gru_forward_reset_output(OpResetOutput opResetOutput, T *gateValue, T *resetOutputValue, T *prevOutputValue, int frameSize, activation_mode_t active_gate) { T rValueUpdateGate; T rValueResetGate; T rValueResetOutput; T rPrevOut = 0; T *updateGate = gateValue; T *resetGate = gateValue + frameSize; for (int i = 0; i < frameSize; i++) { rValueUpdateGate = updateGate[i]; rValueResetGate = resetGate[i]; if (prevOutputValue) { rPrevOut = prevOutputValue[i]; } opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, rValueResetOutput, active_gate); updateGate[i] = rValueUpdateGate; resetGate[i] = rValueResetGate; resetOutputValue[i] = rValueResetOutput; } } template void hl_naive_gru_forward_final_output(OpFinalOutput opFinalOutput, T *gateValue, T *prevOutputValue, T *outputValue, int frameSize, activation_mode_t active_node) { T rValueUpdateGate; T rValueFrameState; T rPrevOut = 0; T rOutput; T *updateGate = gateValue; T *frameState = gateValue + frameSize * 2; for (int i = 0; i < frameSize; i++) { rValueUpdateGate = updateGate[i]; rValueFrameState = frameState[i]; if (prevOutputValue) { rPrevOut = prevOutputValue[i]; } opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput, active_node); frameState[i] = rValueFrameState; outputValue[i] = rOutput; } } template void hl_avx_gru_forward_reset_output(OpResetOutput opResetOutput, T *gateValue, T *resetOutputValue, T *prevOutputValue, int frameSize, activation_mode_t active_gate) { #ifdef __AVX__ __m256 rValueUpdateGate; __m256 rValueResetGate; __m256 rValueResetOutput; __m256 rPrevOut = _mm256_set1_ps(0.0f); __m256 *updateGate = (__m256 *)gateValue; __m256 *resetGate = (__m256 *)(gateValue + frameSize); for (int i = 0; i < frameSize / 8; i++) { rValueUpdateGate = updateGate[i]; rValueResetGate = resetGate[i]; if (prevOutputValue) { rPrevOut = ((__m256 *)prevOutputValue)[i]; } opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, rValueResetOutput, active_gate); updateGate[i] = rValueUpdateGate; resetGate[i] = rValueResetGate; ((__m256 *)resetOutputValue)[i] = rValueResetOutput; } #endif } template void hl_avx_gru_forward_final_output(OpFinalOutput opFinalOutput, T *gateValue, T *prevOutputValue, T *outputValue, int frameSize, activation_mode_t active_node) { #ifdef __AVX__ __m256 rValueUpdateGate; __m256 rValueFrameState; __m256 rPrevOut = _mm256_set1_ps(0.0f); __m256 rOutput; __m256 *updateGate = (__m256 *)gateValue; __m256 *frameState = (__m256 *)(gateValue + frameSize * 2); for (int i = 0; i < frameSize / 8; i++) { rValueUpdateGate = updateGate[i]; rValueFrameState = frameState[i]; if (prevOutputValue) { rPrevOut = ((__m256 *)prevOutputValue)[i]; } opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput, active_node); frameState[i] = rValueFrameState; ((__m256 *)outputValue)[i] = rOutput; } #endif } template inline void forward_reset_output(OpResetOutput opResetOutput, hl_gru_value value, int frameSize, int batchSize, activation_mode_t active_gate) { for (int b = 0; b < batchSize; b++) { if (OpResetOutput::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_forward_reset_output( opResetOutput, value.gateValue, value.resetOutputValue, value.prevOutValue, frameSize, active_gate); } else { hl_naive_gru_forward_reset_output( opResetOutput, value.gateValue, value.resetOutputValue, value.prevOutValue, frameSize, active_gate); } value.gateValue += frameSize * 3; value.resetOutputValue += frameSize; if (value.prevOutValue) { value.prevOutValue += frameSize; } } } template inline void forward_final_output(OpFinalOutput opFinalOutput, hl_gru_value value, int frameSize, int batchSize, activation_mode_t active_node) { for (int b = 0; b < batchSize; b++) { if (OpFinalOutput::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_forward_final_output(opFinalOutput, value.gateValue, value.prevOutValue, value.outputValue, frameSize, active_node); } else { hl_naive_gru_forward_final_output(opFinalOutput, value.gateValue, value.prevOutValue, value.outputValue, frameSize, active_node); } value.gateValue += frameSize * 3; value.outputValue += frameSize; if (value.prevOutValue) { value.prevOutValue += frameSize; } } } template void hl_naive_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue, T *gateGrad, T *prevOutValue, T *prevOutGrad, T *outputGrad, int frameSize, activation_mode_t active_node) { T rUpdateGateValue; T rUpdateGateGrad; T rFrameStateValue; T rFrameStateGrad; T rOutGrad; T rPrevOutValue = 0; T rPrevOutGrad = 0; T *updateGateValue = gateValue; T *updateGateGrad = gateGrad; T *frameStateValue = gateValue + frameSize * 2; T *frameStateGrad = gateGrad + frameSize * 2; for (int i = 0; i < frameSize; i++) { rUpdateGateValue = updateGateValue[i]; rFrameStateValue = frameStateValue[i]; rOutGrad = outputGrad[i]; if (prevOutValue) { rPrevOutValue = prevOutValue[i]; } if (prevOutGrad) { rPrevOutGrad = prevOutGrad[i]; } opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue, rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad, active_node); updateGateGrad[i] = rUpdateGateGrad; frameStateGrad[i] = rFrameStateGrad; if (prevOutGrad) { prevOutGrad[i] = rPrevOutGrad; } } } template void hl_naive_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue, T *gateGrad, T *prevOutValue, T *prevOutGrad, T *resetOutputGrad, int frameSize, activation_mode_t active_gate) { T rUpdateGateValue; T rUpdateGateGrad; T rResetGateValue; T rResetGateGrad; T rResetOutputGrad = 0; T rPrevOutValue = 0; T rPrevOutGrad = 0; T *updateGateValue = gateValue; T *updateGateGrad = gateGrad; T *resetGateValue = gateValue + frameSize; T *resetGateGrad = gateGrad + frameSize; for (int i = 0; i < frameSize; i++) { rUpdateGateValue = updateGateValue[i]; rUpdateGateGrad = updateGateGrad[i]; rResetGateValue = resetGateValue[i]; if (prevOutValue && prevOutGrad) { rResetOutputGrad = resetOutputGrad[i]; } if (prevOutValue) { rPrevOutValue = prevOutValue[i]; } if (prevOutGrad) { rPrevOutGrad = prevOutGrad[i]; } opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue, rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad, active_gate); updateGateGrad[i] = rUpdateGateGrad; resetGateGrad[i] = rResetGateGrad; if (prevOutGrad) { prevOutGrad[i] = rPrevOutGrad; } } } template void hl_avx_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue, T *gateGrad, T *prevOutValue, T *prevOutGrad, T *outputGrad, int frameSize, activation_mode_t active_node) { #ifdef __AVX__ __m256 rUpdateGateValue; __m256 rUpdateGateGrad; __m256 rFrameStateValue; __m256 rFrameStateGrad; __m256 rOutGrad; __m256 rPrevOutValue = _mm256_set1_ps(0.0f); __m256 rPrevOutGrad = _mm256_set1_ps(0.0f); __m256 *updateGateValue = (__m256 *)gateValue; __m256 *updateGateGrad = (__m256 *)gateGrad; __m256 *frameStateValue = (__m256 *)(gateValue + frameSize * 2); __m256 *frameStateGrad = (__m256 *)(gateGrad + frameSize * 2); for (int i = 0; i < frameSize / 8; i++) { rUpdateGateValue = updateGateValue[i]; rFrameStateValue = frameStateValue[i]; rOutGrad = ((__m256 *)outputGrad)[i]; if (prevOutValue) { rPrevOutValue = ((__m256 *)prevOutValue)[i]; } if (prevOutGrad) { rPrevOutGrad = ((__m256 *)prevOutGrad)[i]; } opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue, rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad, active_node); updateGateGrad[i] = rUpdateGateGrad; frameStateGrad[i] = rFrameStateGrad; if (prevOutGrad) { ((__m256 *)prevOutGrad)[i] = rPrevOutGrad; } } #endif } template void hl_avx_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue, T *gateGrad, T *prevOutValue, T *prevOutGrad, T *resetOutputGrad, int frameSize, activation_mode_t active_gate) { #ifdef __AVX__ __m256 rUpdateGateValue; __m256 rUpdateGateGrad; __m256 rResetGateValue; __m256 rResetGateGrad; __m256 rResetOutputGrad = _mm256_set1_ps(0.0f); __m256 rPrevOutValue = _mm256_set1_ps(0.0f); __m256 rPrevOutGrad = _mm256_set1_ps(0.0f); __m256 *updateGateValue = (__m256 *)gateValue; __m256 *updateGateGrad = (__m256 *)gateGrad; __m256 *resetGateValue = (__m256 *)(gateValue + frameSize); __m256 *resetGateGrad = (__m256 *)(gateGrad + frameSize); for (int i = 0; i < frameSize / 8; i++) { rUpdateGateValue = updateGateValue[i]; rUpdateGateGrad = updateGateGrad[i]; rResetGateValue = resetGateValue[i]; if (prevOutValue && prevOutGrad) { rResetOutputGrad = ((__m256 *)resetOutputGrad)[i]; } if (prevOutValue) { rPrevOutValue = ((__m256 *)prevOutValue)[i]; } if (prevOutGrad) { rPrevOutGrad = ((__m256 *)prevOutGrad)[i]; } opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue, rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad, active_gate); updateGateGrad[i] = rUpdateGateGrad; resetGateGrad[i] = rResetGateGrad; if (prevOutGrad) { ((__m256 *)prevOutGrad)[i] = rPrevOutGrad; } } #endif } template inline void backward_state_grad(OpStateGrad opStateGrad, hl_gru_value value, hl_gru_grad grad, int frameSize, int batchSize, activation_mode_t active_node) { for (int b = 0; b < batchSize; b++) { if (OpStateGrad::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_backward_state_grad( opStateGrad, value.gateValue, grad.gateGrad, value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize, active_node); } else { hl_naive_gru_backward_state_grad( opStateGrad, value.gateValue, grad.gateGrad, value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize, active_node); } value.gateValue += frameSize * 3; if (value.prevOutValue) { value.prevOutValue += frameSize; } grad.gateGrad += frameSize * 3; grad.outputGrad += frameSize; if (grad.prevOutGrad) { grad.prevOutGrad += frameSize; } } } template inline void backward_reset_grad(OpResetGrad opResetGrad, hl_gru_value value, hl_gru_grad grad, int frameSize, int batchSize, activation_mode_t active_gate) { for (int b = 0; b < batchSize; b++) { if (OpResetGrad::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_backward_reset_grad( opResetGrad, value.gateValue, grad.gateGrad, value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize, active_gate); } else { hl_naive_gru_backward_reset_grad( opResetGrad, value.gateValue, grad.gateGrad, value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize, active_gate); } value.gateValue += frameSize * 3; if (value.prevOutValue) { value.prevOutValue += frameSize; } grad.gateGrad += frameSize * 3; grad.resetOutputGrad += frameSize; if (grad.prevOutGrad) { grad.prevOutGrad += frameSize; } } } #endif } // namespace detail } // namespace math } // namespace operators } // namespace paddle