提交 8b15ac82 编写于 作者: Liu Yiqun

Move the definition of hl_cpu_gru_forward and hl_cpu_gru_backward to function/GruFunctor.h.

上级 c54c7d91
......@@ -18,14 +18,6 @@ limitations under the License. */
#ifndef __NVCC__
#include "paddle/math/MathFunctions.h"
// #ifndef PADDLE_TYPE_DOUBLE
// #define CBLAS_GEMM paddle::gemm<float>
// #else
// #define CBLAS_GEMM paddle::gemm<double>
// #endif
template<class OpResetOutput>
void hl_naive_gru_forward_reset_output(OpResetOutput opResetOutput,
real *gateValue,
......@@ -210,51 +202,6 @@ inline void forward_final_output(OpFinalOutput opFinalOutput,
}
}
// CPU forward pass of one GRU step: (disabled) gate projection, reset-output
// activation, (disabled) candidate projection, final output activation.
//
// NOTE(review): both CBLAS_GEMM calls below are commented out, so when
// value.prevOutValue is set its contribution is never accumulated into
// value.gateValue / the candidate slice — the two `if (value.prevOutValue)`
// branches are dead. This function is superseded by
// GruFunctor::compute (function/GruFunctor.h), which performs the same two
// projections via BlasGemm.
template<class OpResetOutput, class OpFinalOutput>
void hl_cpu_gru_forward(OpResetOutput opResetOutput,
OpFinalOutput opFinalOutput,
hl_gru_value value,
int frameSize,
int batchSize,
hl_activation_mode_t active_node,
hl_activation_mode_t active_gate) {
if (value.prevOutValue) {
// Intended: gateValue += prevOutValue * gateWeight
// (batchSize x frameSize) * (frameSize x 2*frameSize), row stride 3*frameSize.
// CBLAS_GEMM(CblasNoTrans,
// CblasNoTrans,
// batchSize,
// 2 * frameSize,
// frameSize,
// 1,
// value.prevOutValue,
// frameSize,
// value.gateWeight,
// frameSize * 2,
// 1,
// value.gateValue,
// frameSize * 3);
}
// Apply the gate activation (active_gate) to the reset/update gates.
forward_reset_output(opResetOutput, value, frameSize, batchSize, active_gate);
if (value.prevOutValue) {
// Intended: candidate slice (gateValue + 2*frameSize) +=
// resetOutputValue * stateWeight.
// CBLAS_GEMM(CblasNoTrans,
// CblasNoTrans,
// batchSize,
// frameSize,
// frameSize,
// 1,
// value.resetOutputValue,
// frameSize,
// value.stateWeight,
// frameSize,
// 1,
// value.gateValue + frameSize * 2,
// frameSize * 3);
}
// Apply the node activation (active_node) and produce the step output.
forward_final_output(opFinalOutput, value, frameSize, batchSize, active_node);
}
template<class OpStateGrad>
void hl_naive_gru_backward_state_grad(OpStateGrad opStateGrad,
real *gateValue,
......@@ -524,87 +471,6 @@ inline void backward_reset_grad(OpResetGrad opResetGrad,
}
}
}
// CPU backward pass of one GRU step: state-gradient computation, (disabled)
// GEMMs for reset-output/state-weight gradients, reset-gradient computation,
// and (disabled) GEMMs for previous-output/gate-weight gradients.
//
// NOTE(review): every CBLAS_GEMM call below is commented out, so the guarded
// branches are dead and no weight/previous-output gradients are accumulated
// here. This function is superseded by GruGradFunctor::compute
// (function/GruFunctor.h), which performs the same four projections via
// BlasGemm.
template<class OpStateGrad, class OpResetGrad>
void hl_cpu_gru_backward(OpStateGrad opStateGrad,
OpResetGrad opResetGrad,
hl_gru_value value,
hl_gru_grad grad,
int frameSize,
int batchSize,
hl_activation_mode_t active_node,
hl_activation_mode_t active_gate) {
// Backprop through the final-output stage (uses active_node).
backward_state_grad(opStateGrad, value, grad,
frameSize, batchSize, active_node);
if (value.prevOutValue && grad.prevOutGrad) {
// Intended: resetOutputGrad = candidateGateGrad * stateWeight^T (beta 0).
// CBLAS_GEMM(CblasNoTrans,
// CblasTrans,
// batchSize,
// frameSize,
// frameSize,
// 1,
// grad.gateGrad + frameSize * 2,
// frameSize * 3,
// value.stateWeight,
// frameSize,
// 0,
// grad.resetOutputGrad,
// frameSize);
if (grad.stateWeightGrad) {
// Intended: stateWeightGrad += resetOutputValue^T * candidateGateGrad.
// CBLAS_GEMM(CblasTrans,
// CblasNoTrans,
// frameSize,
// frameSize,
// batchSize,
// 1,
// value.resetOutputValue,
// frameSize,
// grad.gateGrad + frameSize * 2,
// frameSize * 3,
// 1,
// grad.stateWeightGrad,
// frameSize);
}
}
// Backprop through the reset/update gates (uses active_gate).
backward_reset_grad(opResetGrad, value, grad,
frameSize, batchSize, active_gate);
if (grad.prevOutGrad && value.prevOutValue) {
// Intended: prevOutGrad += gateGrad * gateWeight^T.
// CBLAS_GEMM(CblasNoTrans,
// CblasTrans,
// batchSize,
// frameSize,
// frameSize * 2,
// 1,
// grad.gateGrad,
// frameSize * 3,
// value.gateWeight,
// frameSize * 2,
// 1,
// grad.prevOutGrad,
// frameSize);
if (grad.gateWeightGrad) {
// Intended: gateWeightGrad += prevOutValue^T * gateGrad.
// CBLAS_GEMM(CblasTrans,
// CblasNoTrans,
// frameSize,
// frameSize * 2,
// batchSize,
// 1,
// value.prevOutValue,
// frameSize,
// grad.gateGrad,
// frameSize * 3,
// 1,
// grad.gateWeightGrad,
// frameSize * 2);
}
}
}
#endif
#endif // HL_CPU_GRU_CUH_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "GemmFunctor.h"
#include "GruFunctor.h"
#include "hl_cpu_gru.cuh"
namespace paddle {
// CPU forward pass of one GRU step. Mirrors the CBLAS_GEMM calls that were
// commented out in hl_cpu_gru.cuh, now expressed through BlasGemm.
template <DeviceType Device, class T>
struct GruFunctor {
  template <class OpResetOutput, class OpFinalOutput>
  static void compute(OpResetOutput opResetOutput,
                      OpFinalOutput opFinalOutput,
                      hl_gru_value value,
                      int frameSize,
                      int batchSize,
                      hl_activation_mode_t active_node,
                      hl_activation_mode_t active_gate) {
#ifndef __NVCC__
    // gateValue rows hold [update, reset, candidate] slices, hence the
    // 3*frameSize leading dimension; the gate-weight matrix spans the first
    // two slices (2*frameSize columns).
    const int gatePairWidth = frameSize * 2;
    const int gateRowStride = frameSize * 3;
    if (value.prevOutValue) {
      // gateValue += prevOutValue * gateWeight (alpha = beta = 1).
      BlasGemm<Device, T>::compute(false,
                                   false,
                                   batchSize,
                                   gatePairWidth,
                                   frameSize,
                                   1,
                                   value.prevOutValue,
                                   frameSize,
                                   value.gateWeight,
                                   gatePairWidth,
                                   1,
                                   value.gateValue,
                                   gateRowStride);
    }
    // Activate the reset/update gates with active_gate.
    forward_reset_output(
        opResetOutput, value, frameSize, batchSize, active_gate);
    if (value.prevOutValue) {
      // candidate slice += resetOutputValue * stateWeight.
      BlasGemm<Device, T>::compute(false,
                                   false,
                                   batchSize,
                                   frameSize,
                                   frameSize,
                                   1,
                                   value.resetOutputValue,
                                   frameSize,
                                   value.stateWeight,
                                   frameSize,
                                   1,
                                   value.gateValue + gatePairWidth,
                                   gateRowStride);
    }
    // Activate the candidate with active_node and emit the step output.
    forward_final_output(
        opFinalOutput, value, frameSize, batchSize, active_node);
#endif
  }
};
// CPU backward pass of one GRU step: propagates gradients through the final
// output, the candidate projection, the gates, and the gate projection.
template <DeviceType Device, class T>
struct GruGradFunctor {
  template <class OpStateGrad, class OpResetGrad>
  static void compute(OpStateGrad opStateGrad,
                      OpResetGrad opResetGrad,
                      hl_gru_value value,
                      hl_gru_grad grad,
                      int frameSize,
                      int batchSize,
                      hl_activation_mode_t active_node,
                      hl_activation_mode_t active_gate) {
#ifndef __NVCC__
    // Same layout as the forward pass: gate rows are 3*frameSize wide and
    // the candidate slice starts at offset 2*frameSize.
    const int gatePairWidth = frameSize * 2;
    const int gateRowStride = frameSize * 3;
    T* candidateGrad = grad.gateGrad + gatePairWidth;
    // Backprop through the final-output stage (active_node).
    backward_state_grad(
        opStateGrad, value, grad, frameSize, batchSize, active_node);
    if (value.prevOutValue && grad.prevOutGrad) {
      // resetOutputGrad = candidateGrad * stateWeight^T (beta = 0 overwrites).
      BlasGemm<Device, T>::compute(false,
                                   true,
                                   batchSize,
                                   frameSize,
                                   frameSize,
                                   1,
                                   candidateGrad,
                                   gateRowStride,
                                   value.stateWeight,
                                   frameSize,
                                   0,
                                   grad.resetOutputGrad,
                                   frameSize);
      if (grad.stateWeightGrad) {
        // stateWeightGrad += resetOutputValue^T * candidateGrad.
        BlasGemm<Device, T>::compute(true,
                                     false,
                                     frameSize,
                                     frameSize,
                                     batchSize,
                                     1,
                                     value.resetOutputValue,
                                     frameSize,
                                     candidateGrad,
                                     gateRowStride,
                                     1,
                                     grad.stateWeightGrad,
                                     frameSize);
      }
    }
    // Backprop through the reset/update gates (active_gate).
    backward_reset_grad(
        opResetGrad, value, grad, frameSize, batchSize, active_gate);
    if (grad.prevOutGrad && value.prevOutValue) {
      // prevOutGrad += gateGrad * gateWeight^T.
      BlasGemm<Device, T>::compute(false,
                                   true,
                                   batchSize,
                                   frameSize,
                                   gatePairWidth,
                                   1,
                                   grad.gateGrad,
                                   gateRowStride,
                                   value.gateWeight,
                                   gatePairWidth,
                                   1,
                                   grad.prevOutGrad,
                                   frameSize);
      if (grad.gateWeightGrad) {
        // gateWeightGrad += prevOutValue^T * gateGrad.
        BlasGemm<Device, T>::compute(true,
                                     false,
                                     frameSize,
                                     gatePairWidth,
                                     batchSize,
                                     1,
                                     value.prevOutValue,
                                     frameSize,
                                     grad.gateGrad,
                                     gateRowStride,
                                     1,
                                     grad.gateWeightGrad,
                                     gatePairWidth);
      }
    }
#endif
  }
};
} // namespace paddle
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "GruCompute.h"
#include "hl_recurrent_apply.cuh"
#include "paddle/function/GruFunctor.h"
#include "paddle/utils/Util.h"
namespace paddle {
......@@ -25,13 +26,13 @@ void GruCompute::init(LayerConfig &config) {
// CPU specialization of GruCompute::forward.
// NOTE(review): this diff view fuses the removed hl_cpu_gru_forward call and
// its replacement GruFunctor<DEVICE_TYPE_CPU, real>::compute into one body;
// in the post-commit file only the GruFunctor call remains. Do not keep both.
template <>
void GruCompute::forward<0>(hl_gru_value value, int frameSize, int batchSize) {
hl_cpu_gru_forward(hppl::forward::gru_resetOutput(),
hppl::forward::gru_finalOutput(),
value,
frameSize,
batchSize,
activeNode_,
activeGate_);
GruFunctor<DEVICE_TYPE_CPU, real>::compute(hppl::forward::gru_resetOutput(),
hppl::forward::gru_finalOutput(),
value,
frameSize,
batchSize,
activeNode_,
activeGate_);
}
template <>
......@@ -39,14 +40,15 @@ void GruCompute::backward<0>(hl_gru_value value,
hl_gru_grad grad,
int frameSize,
int batchSize) {
hl_cpu_gru_backward(hppl::backward::gru_stateGrad(),
hppl::backward::gru_resetGrad(),
value,
grad,
frameSize,
batchSize,
activeNode_,
activeGate_);
GruGradFunctor<DEVICE_TYPE_CPU, real>::compute(
hppl::backward::gru_stateGrad(),
hppl::backward::gru_resetGrad(),
value,
grad,
frameSize,
batchSize,
activeNode_,
activeGate_);
}
} // namespace paddle
......@@ -2,25 +2,8 @@
# Echo commands and abort on the first failing command.
set -xe
# NOTE(review): this diff view fuses removed and added lines. The COMPILER /
# USE_EIGEN variables and the two `if` blocks below are the OLD code being
# deleted; the plain BUILD_ROOT/DEST_ROOT assignments at the end are the NEW
# code. In the post-commit script only the suffix-less paths remain.
COMPILER=gcc
USE_EIGEN=ON
if [ $COMPILER == clang ]; then
SUFFIX=_clang
C_COMPILER=clang
CXX_COMPILER=clang++
else
SUFFIX=_gcc
C_COMPILER=gcc
CXX_COMPILER=g++
fi
if [ $USE_EIGEN == ON ]; then
SUFFIX=${SUFFIX}_eigen
else
SUFFIX=${SUFFIX}_openblas
fi
# Old (removed) suffix-based locations.
BUILD_ROOT=/paddle/build_android$SUFFIX
DEST_ROOT=/paddle/install$SUFFIX
# New fixed locations; these assignments win since they come last.
BUILD_ROOT=/paddle/build_android
DEST_ROOT=/paddle/install
# Start from a clean build tree; ignore failures (e.g. directory absent).
rm -rf $BUILD_ROOT 2>/dev/null || true
mkdir -p $BUILD_ROOT
......@@ -41,7 +24,7 @@ if [ $ANDROID_ABI == "armeabi-v7a" ]; then
-DCMAKE_INSTALL_PREFIX=$DEST_ROOT \
-DTHIRD_PARTY_PATH=$THIRD_PARTY_PATH \
-DCMAKE_BUILD_TYPE=Release \
-DUSE_EIGEN_FOR_BLAS=${USE_EIGEN} \
-DUSE_EIGEN_FOR_BLAS=ON \
-DWITH_C_API=ON \
-DWITH_SWIG_PY=OFF \
-DWITH_STYLE_CHECK=OFF \
......@@ -58,7 +41,7 @@ elif [ $ANDROID_ABI == "arm64-v8a" ]; then
-DCMAKE_INSTALL_PREFIX=$DEST_ROOT \
-DTHIRD_PARTY_PATH=$THIRD_PARTY_PATH \
-DCMAKE_BUILD_TYPE=Release \
-DUSE_EIGEN_FOR_BLAS=${USE_EIGEN} \
-DUSE_EIGEN_FOR_BLAS=OFF \
-DWITH_C_API=ON \
-DWITH_SWIG_PY=OFF \
..
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册