提交 3cace737 编写于 作者: D dangqingqing

Add lstm implementation.

上级 9106a4bb
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/lstm_unit_op.h"
#include "paddle/operators/lstm_op.h"
namespace paddle {
namespace operators {
......@@ -44,8 +44,36 @@ class LSTMOp : public framework::OperatorWithKernel {
"should be the same.");
}
int frame_size = x_dims[1];
auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(w_dims.size(), 2,
"The rank of Input(Weight) should be 2.");
PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
"The first dimension of Input(Weight) "
"should be %d.",
frame_size);
PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
"The second dimension of Input(Weight) "
"should be 4 * %d.",
frame_size);
auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
if (ctx->Attrs().Get<bool>("use_peepholes")) {
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection",
frame_size);
} else {
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
"The second dimension of Input(Bias) should be "
"4 * %d if diable peepholes connection",
frame_size);
}
ctx->SetOutputDim("Hidden", x_dims);
ctx->SetOutputDim("Cell", x_dims);
ctx->SetOutputDim("Hidden", x_dims);
ctx->ShareLoD("Input", "Hidden");
ctx->ShareLoD("Input", "Cell");
}
......@@ -82,6 +110,8 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"2. `use_peepholes = True` "
" - The shape is (1 x 7*D). "
" - Bias = {b_i, b_f, b_c, b_o, W_ic, W_fc, W_oc}.");
AddOutput("Batch", "(LoDTensor) save the reorganized input as batch info. ")
.AsIntermediate();
AddOutput("Hidden",
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
......@@ -92,6 +122,10 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections.")
.SetDefault(true);
AddAttr<bool>("is_reverse",
"(bool, defalut: False) "
"whether to compute reversed LSTM.")
.SetDefault(true);
AddAttr<std::string>(
"gate_activation",
"(string, defalut: sigmoid)"
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
......@@ -25,7 +24,21 @@ using framework::Tensor;
template <typename Place, typename T>
class LSTMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {}
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input_t = ctx.Input<framework::LoDTensor>("Input");
auto* batch_t = ctx.Input<framework::LoDTensor>("Batch");
auto* bias_t = ctx.Input<framework::LoDTensor>("Bias");
bool is_reverse = ctx.Attr<bool>("is_reverse");
LoDTensor2BatchFunctor<Place, T> to_batch(ctx.device_context(), input_t,
batch_t, is_reverse);
auto in_dims = input_t->dims();
int frame_size = in_dims[1];
if (bias_t) {
auto b = EigenMatrix<T>::From(*bias);
}
}
};
template <typename Place, typename T>
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_ACTIVATION_FUNCTIONS_H_
#define HL_ACTIVATION_FUNCTIONS_H_
#include "hl_functions.h"
/**
* Active functions: sigmoid, relu, tanh and linear.
*/
#define HPPL_ACTIVE_FUNCTION \
{ hppl::sigmoid, hppl::relu, hppl::tanh, hppl::linear }
namespace hppl {
/**
* Hppl supports sigmoid, relu, tanh, linear active functions
* for neural networks' forward and backward activation.
*/
template <class T>
class Active {
public:
typedef T (*forward)(T);
typedef T (*backward)(T, T);
};
#ifdef __NVCC__
namespace gpu {
static __device__ Active<float>::forward forward[] = HPPL_ACTIVE_FUNCTION;
static __device__ Active<float>::backward backward[] = HPPL_ACTIVE_FUNCTION;
static __device__ Active<double>::forward forward[] = HPPL_ACTIVE_FUNCTION;
static __device__ Active<double>::backward backward[] = HPPL_ACTIVE_FUNCTION;
} // namespace gpu
#else
namespace cpu {
static Active<float>::forward forward[] = HPPL_ACTIVE_FUNCTION;
static Active<float>::backward backward[] = HPPL_ACTIVE_FUNCTION;
static Active<double>::forward forward[] = HPPL_ACTIVE_FUNCTION;
static Active<double>::backward backward[] = HPPL_ACTIVE_FUNCTION;
} // namespace cpu
#ifdef __AVX__
namespace avx {
static Active<__m256>::forward forward[] = HPPL_ACTIVE_FUNCTION;
static Active<__m256>::backward backward[] = HPPL_ACTIVE_FUNCTION;
} // namespace avx
#endif
#endif
} // namespace hppl
#endif // HL_ACTIVATION_FUNCTIONS_H_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <immintrin.h>
#include "hl_functions.h"
namespace hppl {
extern __m256 exp(__m256 a);
__m256 relu(const __m256 a) {
__m256 tmp = _mm256_set1_ps(0.0f);
return _mm256_max_ps(a, tmp);
}
__m256 sigmoid(const __m256 a) {
__m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX);
__m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN);
__m256 tmp = _mm256_max_ps(a, min);
tmp = _mm256_min_ps(tmp, max);
tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp);
tmp = exp(tmp);
tmp = _mm256_add_ps(_mm256_set1_ps(1.0f), tmp);
tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp);
return tmp;
}
__m256 tanh(const __m256 a) {
__m256 max = _mm256_set1_ps(EXP_MAX_INPUT);
__m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a);
tmp = _mm256_min_ps(tmp, max);
tmp = exp(tmp);
return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f),
_mm256_add_ps(_mm256_set1_ps(1.0f), tmp)),
_mm256_set1_ps(1.0f));
}
__m256 linear(const __m256 a) { return a; }
__m256 relu(const __m256 a, const __m256 b) {
return _mm256_mul_ps(
a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS),
_mm256_set1_ps(1.0f)));
}
__m256 sigmoid(const __m256 a, const __m256 b) {
return _mm256_mul_ps(_mm256_mul_ps(a, b),
_mm256_sub_ps(_mm256_set1_ps(1.0f), b));
}
__m256 tanh(const __m256 a, const __m256 b) {
return _mm256_mul_ps(
a, _mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_mul_ps(b, b)));
}
__m256 linear(const __m256 a, const __m256 b) { return a; }
} // namespace hppl
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_AVX_FUNCTIONS_H_
#define HL_AVX_FUNCTIONS_H_
#include <immintrin.h>
namespace hppl {
__m256 relu(const __m256 a);
__m256 sigmoid(const __m256 a);
__m256 tanh(const __m256 a);
__m256 linear(const __m256 a);
__m256 relu(const __m256 a, const __m256 b);
__m256 sigmoid(const __m256 a, const __m256 b);
__m256 tanh(const __m256 a, const __m256 b);
__m256 linear(const __m256 a, const __m256 b);
} // namespace hppl
#endif // HL_AVX_FUNCTIONS_H_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include "/paddle/operators/math/detail/hl_functions.h"
namespace hppl {
real relu(const real a) { return a > 0.0f ? a : 0.0f; }
real sigmoid(const real a) {
const real min = SIGMOID_THRESHOLD_MIN;
const real max = SIGMOID_THRESHOLD_MAX;
real tmp = (a < min) ? min : ((a > max) ? max : a);
return 1.0 / (1.0 + exp(-tmp));
}
real tanh(const real a) {
real tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
real linear(const real a) { return a; }
real relu(const real a, const real b) { return a * (b > 0.0f ? 1.0f : 0.0f); }
real sigmoid(const real a, const real b) { return a * b * (1 - b); }
real tanh(const real a, const real b) { return a * (1.0f - b * b); }
real linear(const real a, const real b) { return a; }
} // namespace hppl
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_FUNCTIONS_H_
#define HL_FUNCTIONS_H_
/**
* sigmoid threshold maximum
*/
#define SIGMOID_THRESHOLD_MIN -40.0
/**
* sigmoid threshold minimum
*/
#define SIGMOID_THRESHOLD_MAX 13.0
#ifndef __NVCC__
namespace hppl {
/*
* forward activation
*/
template <typename T>
T relu(const T a);
template <typename T>
T sigmoid(const T a);
template <typename T>
T tanh(const T a);
template <typename T>
T linear(const T a);
/*
* backward activation
*/
template <typename T>
T relu(const T a, const T b);
template <typename T>
T sigmoid(const T a, const T b);
template <typename T>
T tanh(const T a, const T b);
template <typename T>
T linear(const T a, const T b);
} // namespace hppl
#ifdef __AVX__
#include "hl_avx_functions.h"
#endif
#else
#include "hl_gpu_functions.h"
#endif
#endif // HL_FUNCTIONS_H_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef HL_GPU_FUNCTIONS_CUH_
#define HL_GPU_FUNCTIONS_CUH_
#include "hl_base.h"
namespace hppl {
template <typename T>
__device__ static T relu(const T a) {
return a > 0.0f ? a : 0.0f;
}
template <>
__device__ static float sigmoid(const float a) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
float tmp = (a < min) ? min : ((a > max) ? max : a);
return __fdividef(1.0f, 1.0f + __expf(-tmp));
}
template <>
__device__ static double sigmoid(const double a) {
const double min = SIGMOID_THRESHOLD_MIN;
const double max = SIGMOID_THRESHOLD_MAX;
double tmp = (a < min) ? min : ((a > max) ? max : a);
return 1.0 / (1.0 + exp(-tmp));
}
template <>
__device__ static float tanh(const float a) {
return __fdividef(2.0f, (1.0f + __expf(-2.0f * a))) - 1.0f;
}
template <>
__device__ static double tanh(const double a) {
return (2.0 / (1.0 + exp(-2.0 * a))) - 1.0;
}
template <typename T>
__device__ static T linear(const T a) {
return a;
}
template <typename T>
__device__ static T relu(const T a, const T b) {
return a * (b > 0.0f ? 1.0f : 0.0f);
}
template <typename T>
__device__ static T sigmoid(const T a, const T b) {
return a * b * (1 - b);
}
template <typename T>
__device__ static T tanh(const T a, const T b) {
return a * (1.0f - b * b);
}
template <typename T>
__device__ static T linear(const T a, const T b) {
return a;
}
} // namespace hppl
#endif // HL_GPU_FUNCTIONS_CUH_
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/math/lstm_compute.h"
namespace paddle {
namespace operators {
namespace math {
namespace detail {
#ifndef __NVCC__
template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, lstm_value value, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rCheckI;
T rCheckF;
T rCheckO;
T rState;
T rPrevState = 0;
T rStateAtv;
T rOut;
T *valueIn = value.gateValue;
T *valueIg = value.gateValue + frameSize;
T *valueFg = value.gateValue + frameSize * 2;
T *valueOg = value.gateValue + frameSize * 3;
for (int i = 0; i < frameSize; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = value.checkIg[i];
rCheckF = value.checkFg[i];
rCheckO = value.checkOg[i];
if (value.prevStateValue) {
rPrevState = value.prevStateValue[i];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, hppl::cpu::forward[active_node],
hppl::cpu::forward[active_gate], hppl::cpu::forward[active_state]);
valueIn[i] = rValueIn;
valueIg[i] = rValueIg;
valueFg[i] = rValueFg;
valueOg[i] = rValueOg;
value.stateValue[i] = rState;
value.stateActiveValue[i] = rStateAtv;
value.outputValue[i] = rOut;
}
}
template <class T, class Op>
void naive_lstm_backward_one_sequence(Op op, lstm_value value, lstm_grad grad,
int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rGradIn;
T rGradIg;
T rGradFg;
T rGradOg;
T rPrevState = 0;
T rPrevStateGrad;
T rState;
T rStateGrad;
T rStateAtv;
T rOutputGrad;
T rCheckI;
T rCheckF;
T rCheckO;
T rCheckIGrad;
T rCheckFGrad;
T rCheckOGrad;
T *valueIn = value.gateValue;
T *valueIg = value.gateValue + frameSize;
T *valueFg = value.gateValue + frameSize * 2;
T *valueOg = value.gateValue + frameSize * 3;
T *gradIn = grad.gateGrad;
T *gradIg = grad.gateGrad + frameSize;
T *gradFg = grad.gateGrad + frameSize * 2;
T *gradOg = grad.gateGrad + frameSize * 3;
for (int i = 0; i < frameSize; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = value.checkIg[i];
rCheckF = value.checkFg[i];
rCheckO = value.checkOg[i];
rState = value.stateValue[i];
rStateAtv = value.stateActiveValue[i];
rOutputGrad = grad.outputGrad[i];
rStateGrad = grad.stateGrad[i];
if (value.prevStateValue) {
rPrevState = value.prevStateValue[i];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
rCheckOGrad, hppl::cpu::backward[active_node],
hppl::cpu::backward[active_gate], hppl::cpu::backward[active_state]);
gradIn[i] = rGradIn;
gradIg[i] = rGradIg;
gradFg[i] = rGradFg;
gradOg[i] = rGradOg;
grad.stateGrad[i] = rStateGrad;
if (grad.prevStateGrad) grad.prevStateGrad[i] = rPrevStateGrad;
if (value.prevStateValue) {
if (grad.checkIgGrad) grad.checkIgGrad[i] += rCheckIGrad;
if (grad.checkFgGrad) grad.checkFgGrad[i] += rCheckFGrad;
}
if (grad.checkOgGrad) grad.checkOgGrad[i] += rCheckOGrad;
}
}
template <class Op>
void avx_lstm_forward_one_sequence(Op op, lstm_value value, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
#ifdef __AVX__
__m256 rValueIn;
__m256 rValueIg;
__m256 rValueFg;
__m256 rValueOg;
__m256 rCheckI;
__m256 rCheckF;
__m256 rCheckO;
__m256 rState;
__m256 rPrevState = _mm256_set1_ps(0.0f);
__m256 rStateAtv;
__m256 rOut;
__m256 *valueIn = (__m256 *)value.gateValue;
__m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
__m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
__m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
for (int i = 0; i < frameSize / 8; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = ((__m256 *)value.checkIg)[i];
rCheckF = ((__m256 *)value.checkFg)[i];
rCheckO = ((__m256 *)value.checkOg)[i];
if (value.prevStateValue) {
rPrevState = ((__m256 *)value.prevStateValue)[i];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, hppl::avx::forward[active_node],
hppl::avx::forward[active_gate], hppl::avx::forward[active_state]);
valueIn[i] = rValueIn;
valueIg[i] = rValueIg;
valueFg[i] = rValueFg;
valueOg[i] = rValueOg;
((__m256 *)value.stateValue)[i] = rState;
((__m256 *)value.stateActiveValue)[i] = rStateAtv;
((__m256 *)value.outputValue)[i] = rOut;
}
#endif
}
template <class Op>
void avx_lstm_backward_one_sequence(Op op, lstm_value value, lstm_grad grad,
int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
#ifdef __AVX__
__m256 rValueIn;
__m256 rValueIg;
__m256 rValueFg;
__m256 rValueOg;
__m256 rGradIn;
__m256 rGradIg;
__m256 rGradFg;
__m256 rGradOg;
__m256 rPrevState = _mm256_set1_ps(0.0f);
__m256 rPrevStateGrad;
__m256 rStateGrad;
__m256 rState;
__m256 rStateAtv;
__m256 rOutputGrad;
__m256 rCheckI;
__m256 rCheckF;
__m256 rCheckO;
__m256 rCheckIGrad;
__m256 rCheckFGrad;
__m256 rCheckOGrad;
__m256 *valueIn = (__m256 *)value.gateValue;
__m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
__m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
__m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
__m256 *gradIn = (__m256 *)grad.gateGrad;
__m256 *gradIg = (__m256 *)(grad.gateGrad + frameSize);
__m256 *gradFg = (__m256 *)(grad.gateGrad + frameSize * 2);
__m256 *gradOg = (__m256 *)(grad.gateGrad + frameSize * 3);
for (int i = 0; i < frameSize / 8; i++) {
rValueIn = valueIn[i];
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = ((__m256 *)value.checkIg)[i];
rCheckF = ((__m256 *)value.checkFg)[i];
rCheckO = ((__m256 *)value.checkOg)[i];
rState = ((__m256 *)value.stateValue)[i];
rStateAtv = ((__m256 *)value.stateActiveValue)[i];
rOutputGrad = ((__m256 *)grad.outputGrad)[i];
rStateGrad = ((__m256 *)grad.stateGrad)[i];
if (value.prevStateValue) {
rPrevState = ((__m256 *)value.prevStateValue)[i];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
rCheckOGrad, hppl::avx::backward[active_node],
hppl::avx::backward[active_gate], hppl::avx::backward[active_state]);
gradIn[i] = rGradIn;
gradIg[i] = rGradIg;
gradFg[i] = rGradFg;
gradOg[i] = rGradOg;
((__m256 *)grad.stateGrad)[i] = rStateGrad;
if (grad.prevStateGrad) ((__m256 *)grad.prevStateGrad)[i] = rPrevStateGrad;
if (value.prevStateValue) {
if (grad.checkIgGrad) ((__m256 *)grad.checkIgGrad)[i] += rCheckIGrad;
if (grad.checkFgGrad) ((__m256 *)grad.checkFgGrad)[i] += rCheckFGrad;
}
if (grad.checkOgGrad) ((__m256 *)grad.checkOgGrad)[i] += rCheckOGrad;
}
#endif
}
template <class T, class Op>
void cpu_lstm_forward(Op op, lstm_value value, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
if (Op::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
avx_lstm_forward_one_sequence(op, value, frameSize, active_node,
active_gate, active_state);
} else {
naive_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
active_gate, active_state);
}
}
template <class T, class Op>
void cpu_lstm_backward(Op op, lstm_value value, lstm_grad grad, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
if (Op::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
avx_lstm_backward_one_sequence(op, value, grad, frameSize, active_node,
active_gate, active_state);
} else {
naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
active_gate, active_state);
}
}
#endif
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/math/detail/lstm_kernel.h"
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
namespace detail {
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template <class T, class Op, bool isBatch>
__global__ void KeLstmForward(Op op, lstm_value value, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
value.gateValue += batchIdx * frameSize * 4;
value.outputValue += batchIdx * frameSize;
value.stateValue += batchIdx * frameSize;
value.stateActiveValue += batchIdx * frameSize;
}
T rState;
T rPrevState = 0;
T rStateAtv;
T rOut;
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rCheckI = value.checkIg[frameIdx];
T rCheckF = value.checkFg[frameIdx];
T rCheckO = value.checkOg[frameIdx];
rValueIn = value.gateValue[frameIdx];
rValueIg = value.gateValue[frameIdx + frameSize];
rValueFg = value.gateValue[frameIdx + frameSize * 2];
rValueOg = value.gateValue[frameIdx + frameSize * 3];
if (value.prevStateValue) {
if (isBatch) value.prevStateValue += batchIdx * frameSize;
rPrevState = value.prevStateValue[frameIdx];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, hppl::gpu::forward[active_node],
hppl::gpu::forward[active_gate], hppl::gpu::forward[active_state]);
value.gateValue[frameIdx] = rValueIn;
value.gateValue[frameIdx + frameSize] = rValueIg;
value.gateValue[frameIdx + frameSize * 2] = rValueFg;
value.gateValue[frameIdx + frameSize * 3] = rValueOg;
value.stateValue[frameIdx] = rState;
value.stateActiveValue[frameIdx] = rStateAtv;
value.outputValue[frameIdx] = rOut;
}
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template <class T, class Op, bool isBatch>
__global__ void KeLstmBackward(Op op, lstm_value value, lstm_grad grad,
int frameSize, int batchSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
value.gateValue += batchIdx * frameSize * 4;
value.stateValue += batchIdx * frameSize;
value.stateActiveValue += batchIdx * frameSize;
grad.gateGrad += batchIdx * frameSize * 4;
grad.stateGrad += batchIdx * frameSize;
grad.outputGrad += batchIdx * frameSize;
}
T rValueIn;
T rValueIg;
T rValueFg;
T rValueOg;
T rGradIn;
T rGradIg;
T rGradFg;
T rGradOg;
T rPrevState = 0;
T rPrevStateGrad;
T rState;
T rStateGrad;
T rStateAtv;
T rOutputGrad;
T rCheckI = value.checkIg[frameIdx];
T rCheckF = value.checkFg[frameIdx];
T rCheckO = value.checkOg[frameIdx];
T rCheckIGrad;
T rCheckFGrad;
T rCheckOGrad;
rValueIn = value.gateValue[frameIdx];
rValueIg = value.gateValue[frameIdx + frameSize];
rValueFg = value.gateValue[frameIdx + frameSize * 2];
rValueOg = value.gateValue[frameIdx + frameSize * 3];
rState = value.stateValue[frameIdx];
rStateAtv = value.stateActiveValue[frameIdx];
rOutputGrad = grad.outputGrad[frameIdx];
rStateGrad = grad.stateGrad[frameIdx];
if (value.prevStateValue) {
if (isBatch) value.prevStateValue += batchIdx * frameSize;
rPrevState = value.prevStateValue[frameIdx];
}
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg,
rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad,
rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad,
hppl::gpu::backward[active_node], hppl::gpu::backward[active_gate],
hppl::gpu::backward[active_state]);
grad.gateGrad[frameIdx] = rGradIn;
grad.gateGrad[frameIdx + frameSize] = rGradIg;
grad.gateGrad[frameIdx + frameSize * 2] = rGradFg;
grad.gateGrad[frameIdx + frameSize * 3] = rGradOg;
grad.stateGrad[frameIdx] = rStateGrad;
if (grad.prevStateGrad) {
if (isBatch) grad.prevStateGrad += batchIdx * frameSize;
grad.prevStateGrad[frameIdx] = rPrevStateGrad;
}
if (isBatch) {
if (value.prevStateValue) {
if (grad.checkIgGrad)
paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx,
rCheckIGrad);
if (grad.checkFgGrad)
paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx,
rCheckFGrad);
}
if (grad.checkOgGrad)
paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad);
} else {
if (value.prevStateValue) {
if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad;
if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad;
}
if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad;
}
}
template <class T, class Op>
void gpu_lstm_forward(Op op, lstm_value value, int frameSize, int batchSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
dim3 threads;
dim3 grid;
if (batchSize == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1);
grid = dim3(frameBlocks, 1);
} else {
/* framePerBlock = 32 batchPerBlock = 32 */
threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
}
if (batchSize == 1) {
KeLstmForward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, STREAM_DEFAULT>>>(
op, value, frameSize, batchSize, active_node, active_gate,
active_state);
} else {
KeLstmForward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, STREAM_DEFAULT>>>(
op, value, frameSize, batchSize, active_node, active_gate,
active_state);
}
}
template <class T, class Op>
void gpu_lstm_backward(Op op, lstm_value value, lstm_grad grad, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
dim3 threads;
dim3 grid;
if (batchSize == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1);
grid = dim3(frameBlocks, 1);
} else {
/* framePerBlock = 32 batchPerBlock = 32 */
threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
}
if (batchSize == 1) {
KeLstmBackward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, STREAM_DEFAULT>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state);
} else {
KeLstmBackward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, STREAM_DEFAULT>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state);
}
}
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hl_activation_functions.h"
#ifdef __CUDA_ARCH__
#define INLINE __device__ inline
#else
#define INLINE inline
#endif
namespace paddle {
namespace operators {
namespace math {
namespace detail {
namespace forward {
template <class T>
class lstm {
public:
INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &prevState, T &state, T &stateAtv, T &output,
T &checkI, T &checkF, T &checkO,
Active<T>::forward actInput,
Active<T>::forward actGate,
Active<T>::forward actState) {
valueIn = actInput(valueIn);
valueIg = actGate(valueIg + prevState * checkI);
valueFg = actGate(valueFg + prevState * checkF);
state = valueIn * valueIg + prevState * valueFg;
valueOg = actGate(valueOg + state * checkO);
stateAtv = actState(state);
output = valueOg * stateAtv;
}
#ifndef __NVCC__
#ifndef __AVX__
static const bool avx = false;
#else
static const bool avx = true;
INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
__m256 &valueOg, __m256 &prevState, __m256 &state,
__m256 &stateAtv, __m256 &output, __m256 &checkI,
__m256 &checkF, __m256 &checkO,
Active<__m256>::forward actInput,
Active<__m256>::forward actGate,
Active<__m256>::forward actState) {
valueIn = actInput(valueIn);
valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)));
valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)));
state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg),
_mm256_mul_ps(prevState, valueFg));
valueOg = actGate(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)));
stateAtv = actState(state);
output = _mm256_mul_ps(valueOg, stateAtv);
}
#endif
#endif
};
} // namespace forward
namespace backward {
template <class T>
class lstm {
public:
INLINE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &gradIn, T &gradIg, T &gradFg, T &gradOg,
T &prevState, T &prevStateGrad, T &state, T &stateGrad,
T &stateAtv, T &outputGrad, T &checkI, T &checkF,
T &checkO, T &checkIGrad, T &checkFGrad, T &checkOGrad,
Active<T>::backward actInput,
Active<T>::backward actGate,
Active<T>::backward actState) {
gradOg = actGate(outputGrad * stateAtv, valueOg);
stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
gradIn = actInput(stateGrad * valueIg, valueIn);
gradIg = actGate(stateGrad * valueIn, valueIg);
gradFg = actGate(stateGrad * prevState, valueFg);
prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
checkIGrad = gradIg * prevState;
checkFGrad = gradFg * prevState;
checkOGrad = gradOg * state;
}
#ifndef __NVCC__
#ifndef __AVX__
static const bool avx = false;
#else
static const bool avx = true;
INLINE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
__m256 &valueOg, __m256 &gradIn, __m256 &gradIg,
__m256 &gradFg, __m256 &gradOg, __m256 &prevState,
__m256 &prevStateGrad, __m256 &state,
__m256 &stateGrad, __m256 &stateAtv,
__m256 &outputGrad, __m256 &checkI, __m256 &checkF,
__m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad,
__m256 &checkOGrad, Active<__m256>::backward actInput,
Active<__m256>::backward actGate,
Active<__m256>::backward actState) {
gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg);
stateGrad = _mm256_add_ps(
actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad);
stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad);
gradIn = actInput(_mm256_mul_ps(stateGrad, valueIg), valueIn);
gradIg = actGate(_mm256_mul_ps(stateGrad, valueIn), valueIg);
gradFg = actGate(_mm256_mul_ps(stateGrad, prevState), valueFg);
prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI),
_mm256_mul_ps(gradFg, checkF));
prevStateGrad =
_mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad);
checkIGrad = _mm256_mul_ps(gradIg, prevState);
checkFGrad = _mm256_mul_ps(gradFg, prevState);
checkOGrad = _mm256_mul_ps(gradOg, state);
}
#endif
#endif
};
} // namespace backward
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
#endif /* HL_LSTM_OPS_CUH_ */
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "LstmCompute.h"
#include "paddle/operators/math/detail/lstm_cpu_kernel.h"
#include "paddle/operators/math/detail/lstm_kernel.h"
namespace paddle {
namespace operators {
namespace math {
template <class T>
struct LstmUnitFunctor<platform::CPUPlace, T> {
static void compute(lstm_value value, int frame_size, int batch_size,
std::string gate_act, std::string cell_act,
std::string cand_act) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frameSize,
ActiveType(cand_act), ActiveType(gate_act),
ActiveType(cell_act));
value.gateValue += frameSize * 4;
value.stateValue += frameSize;
value.stateActiveValue += frameSize;
value.outputValue += frameSize;
if (value.prevStateValue) {
value.prevStateValue += frameSize;
}
}
}
};
template <class T>
struct LstmUnitGradFunctor<platform::CPUPlace, T> {
static void compute(lstm_value value, lstm_grad grad, int frame_size,
int batch_size, std::string gate_act,
std::string cell_act, std::string cand_act) {
for (int b = 0; b < batchSize; b++) {
detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
frameSize, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
value.gateValue += frameSize * 4;
value.stateValue += frameSize;
value.stateActiveValue += frameSize;
value.outputValue += frameSize;
if (value.prevStateValue) {
value.prevStateValue += frameSize;
}
grad.gateGrad += frameSize * 4;
grad.stateGrad += frameSize;
grad.stateActiveGrad += frameSize;
grad.outputGrad += frameSize;
if (grad.prevStateGrad) {
grad.prevStateGrad += frameSize;
}
}
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "LstmCompute.h"
#include "paddle/operators/math/detail/lstm_cpu_kernel.h"
#include "paddle/operators/math/detail/lstm_kernel.h"
namespace paddle {
namespace operators {
namespace math {
template <class T>
struct LstmUnitFunctor<platform::GPUPlace, T> {
static void compute(lstm_value value, int frame_size, int batch_size,
std::string gate_act, std::string cell_act,
std::string cand_act) {
for (int b = 0; b < batch_size; b++) {
detail::gpu_lstm_forward(detail::forward::lstm<T>(), value, frameSize,
ActiveType(cand_act), ActiveType(gate_act),
ActiveType(cell_act));
value.gateValue += frameSize * 4;
value.stateValue += frameSize;
value.stateActiveValue += frameSize;
value.outputValue += frameSize;
if (value.prevStateValue) {
value.prevStateValue += frameSize;
}
}
}
};
template <class T>
struct LstmUnitGradFunctor<platform::GPUPlace, T> {
static void compute(lstm_value value, lstm_grad grad, int frame_size,
int batch_size, std::string gate_act,
std::string cell_act, std::string cand_act) {
for (int b = 0; b < batchSize; b++) {
detail::gpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
frameSize, ActiveType(cand_act),
ActiveType(gate_act), ActiveType(cell_act));
value.gateValue += frameSize * 4;
value.stateValue += frameSize;
value.stateActiveValue += frameSize;
value.outputValue += frameSize;
if (value.prevStateValue) {
value.prevStateValue += frameSize;
}
grad.gateGrad += frameSize * 4;
grad.stateGrad += frameSize;
grad.stateActiveGrad += frameSize;
grad.outputGrad += frameSize;
if (grad.prevStateGrad) {
grad.prevStateGrad += frameSize;
}
}
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/macros.h"
namespace paddle {
namespace operators {
namespace math {
typedef enum {
HL_ACTIVATION_SIGMOID = 0,
HL_ACTIVATION_RELU = 1,
HL_ACTIVATION_TANH = 2,
HL_ACTIVATION_LINEAR = 3,
HL_ACTIVATION_END
} activation_mode_t;
template <T>
struct lstm_value {
real *gateValue;
real *prevStateValue;
real *stateValue;
real *stateActiveValue;
real *outputValue;
real *checkIg;
real *checkFg;
real *checkOg;
};
template <T>
struct lstm_grad {
real *gateGrad;
real *prevStateGrad;
real *stateGrad;
real *stateActiveGrad;
real *outputGrad;
real *checkIgGrad;
real *checkFgGrad;
real *checkOgGrad;
};
activation_mode_t ActiveType(const std::string &type) {
if (type == "sigmoid") {
return HL_ACTIVATION_SIGMOID;
} else if (type == "relu") {
return HL_ACTIVATION_RELU;
} else if (type == "tanh") {
return HL_ACTIVATION_TANH;
} else if (type == "linear" || type == "") {
return HL_ACTIVATION_LINEAR;
} else {
PADDLE_THROW("Do not support activation type.");
}
}
template <typename Place, typename T>
class LstmUnitFunctor {
public:
static void compute(lstm_value value, int frame_size, int batch_size,
std::string gate_act, std::string cell_act,
std::string cand_act);
};
template <typename Place, typename T>
class LstmUnitGradFunctor {
public:
static void compute(lstm_value value, lstm_grad grad, int frame_size,
int batch_size, std::string gate_act,
std::string cell_act, std::string cand_act);
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -18,6 +18,37 @@ namespace paddle {
namespace operators {
namespace math {
template <typename T>
class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& src, const size_t* index,
framework::Tensor& dst, bool is_src_index) {
auto src_dims = src.dims();
auto dst_dims = dst.dims();
PADDLE_ENFORCE(src_dims.size(), 2, "The src must be matrix with rank 2.");
PADDLE_ENFORCE(dst_dims.size(), 2, "The dst must be matrix with rank 2.");
PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1],
"The width of src and dst must be same.");
auto height = dst_dims[0];
auto width = dst_dims[1];
auto* src_data = src.data<T>();
auto* dst_data = dst.data<T>();
for (int i = 0; i < height; ++i) {
if (is_src_index) {
memcpy(dst_data + i * width, src_data + index[i] * width,
width * sizeof(T));
} else {
memcpy(dst_data + index[i] * width, src_data + i * width,
width * sizeof(T));
}
}
}
};
template class CopyMatrixRowsFunctor<platform::CPUPlace, float>;
template class CopyMatrixRowsFunctor<platform::CPUPlace, double>;
template class LoDTensor2BatchFunctor<platform::CPUPlace, float>;
template class Batch2LoDTensor2Functor<platform::CPUPlace, float>;
......
......@@ -18,6 +18,53 @@ namespace paddle {
namespace operators {
namespace math {
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const int* index,
int height, int width,
const bool is_src_index) {
int idx = threadIdx.x;
int idy = threadIdx.y;
int id = blockIdx.x + idy * GridDimX;
while (id < height) {
int src_idx = is_src_index ? index[id] : id;
int dst_idx = is_src_index ? id : index[id];
T* src_data = src + src_idx * width;
T* dst_data = dst + dst_idx * width;
for (int i = idx; i < width; i += BlockDimX) {
dst_data[i] = src_data[i];
}
id += BlockDimY * GridDimX;
}
}
template <typename T>
class CopyMatrixRowsFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& src, const size_t* index,
framework::Tensor& dst, bool is_src_index) {
auto src_dims = src.dims();
auto dst_dims = dst.dims();
PADDLE_ENFORCE(src_dims.size(), 2, "The src must be matrix with rank 2.");
PADDLE_ENFORCE(dst_dims.size(), 2, "The dst must be matrix with rank 2.");
PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1],
"The width of src and dst must be same.");
auto height = dst_dims[0];
auto width = dst_dims[1];
auto* src_data = src.data<T>();
auto* dst_data = dst.data<T>();
dim3 threads(128, 8);
dim3 grid(8, 1);
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(context);
CopyMatrixRowsKernel<T, 128, 8, 8><<<grid, threads, 0, stream>>>(
src_data, dst_data, index, height, width);
}
};
template class CopyMatrixRowsFunctor<platform::GPUPlace, float>;
template class CopyMatrixRowsFunctor<platform::GPUPlace, double>;
template class LoDTensor2BatchFunctor<platform::GPUPlace, float>;
template class Batch2LoDTensor2Functor<platform::GPUPlace, float>;
......
......@@ -16,6 +16,19 @@ namespace paddle {
namespace operators {
namespace math {
template <typename Place, typename T>
class CopyMatrixRowsFunctor {
public:
// If is_src_index is true,
// copy the indexed rows of input src to the output dst.
// If is_src_index is false,
// copy the input src to the indexed rows of output dst.
// The indexed rows are based on the input index.
void operator()(const platform::DeviceContext& context,
const framework::Tensor& src, const size_t* index,
framework::Tensor& dst, const bool is_src_index);
};
template <typename Place, typename T>
class LoDTensor2BatchFunctor {
public:
......@@ -97,8 +110,11 @@ class LoDTensor2BatchFunctor {
}
batch_starts[n + 1] = batch_id;
}
CopyMatrixRowsFunctor<Place, T> to_batch;
to_batch(context, lod_tensor, batch, true);
}
}
};
template <typename Place, typename T>
class Batch2LoDTensor2Functor {
......@@ -107,6 +123,7 @@ class Batch2LoDTensor2Functor {
const framework::LoDTensor& batch,
framework::LoDTensor& lod_tensor,
const bool is_reverse) const;
};
} // namespace math
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册