Unverified commit 085260f3, authored by Jack Zhou, committed by GitHub

Add Eigen GRU and fix the dropout bug in the RNN

Parent 545df287
@@ -14,6 +14,8 @@ limitations under the License. */
 #pragma once
 #include <type_traits>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/operators/activation_op.h"
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
 #include "paddle/fluid/operators/math/gru_compute.h"
@@ -21,6 +23,10 @@ namespace paddle {
 namespace operators {
 namespace math {
 namespace detail {
+using Array1 = Eigen::DSizes<int64_t, 1>;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
 #ifndef __NVCC__
@@ -242,23 +248,46 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
 #endif
 }
+template <typename T>
+inline void forward_reset_outputV2(const platform::CPUDeviceContext &context,
+                                   GRUMetaValue<T> value, int frame_size) {
+  auto &place = *context.eigen_device();
+  auto value_reset_gate =
+      typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
+  auto value_update_gate = typename EigenVector<T>::Type(
+      value.gate_value + frame_size, Array1(frame_size));
+  auto value_reset_output = typename EigenVector<T>::Type(
+      value.reset_output_value, Array1(frame_size));
+  auto value_reset_bias =
+      typename EigenVector<T>::ConstType(value.reset_bias, Array1(frame_size));
+  SigmoidFunctor<T>()(place, value_reset_gate, value_reset_gate);
+  SigmoidFunctor<T>()(place, value_update_gate, value_update_gate);
+  value_reset_output.device(place) =
+      (value_reset_output + value_reset_bias) * value_reset_gate;
+}
 template <class OpResetOutput, typename T>
-inline void forward_reset_output(OpResetOutput op_reset_output,
-                                 GRUMetaValue<T> value, int frame_size,
-                                 int batch_size, ActivationType active_gate,
-                                 bool old_version = true) {
+inline void forward_reset_output(
+    OpResetOutput op_reset_output, GRUMetaValue<T> value, int frame_size,
+    int batch_size, ActivationType active_gate, bool old_version = true,
+    const platform::CPUDeviceContext *context = nullptr) {
   for (int b = 0; b < batch_size; b++) {
-    if (OpResetOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
-        (sizeof(T) == 4)) {
-      hl_avx_gru_forward_reset_output(
-          op_reset_output, value.gate_value, value.reset_output_value,
-          value.prev_out_value, frame_size, active_gate, old_version,
-          value.reset_bias);
+    if (!old_version) {
+      // use eigen
+      forward_reset_outputV2(*context, value, frame_size);
     } else {
-      hl_naive_gru_forward_reset_output(
-          op_reset_output, value.gate_value, value.reset_output_value,
-          value.prev_out_value, frame_size, active_gate, old_version,
-          value.reset_bias);
+      if (OpResetOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
+          (sizeof(T) == 4)) {
+        hl_avx_gru_forward_reset_output(
+            op_reset_output, value.gate_value, value.reset_output_value,
+            value.prev_out_value, frame_size, active_gate, old_version,
+            value.reset_bias);
+      } else {
+        hl_naive_gru_forward_reset_output(
+            op_reset_output, value.gate_value, value.reset_output_value,
+            value.prev_out_value, frame_size, active_gate, old_version,
+            value.reset_bias);
+      }
     }
     value.gate_value += frame_size * 3;
     value.reset_output_value += frame_size;
@@ -268,25 +297,51 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
   }
 }
+template <typename T>
+inline void forward_final_outputV2(const platform::CPUDeviceContext &context,
+                                   GRUMetaValue<T> value, int frame_size) {
+  auto &place = *context.eigen_device();
+  auto value_update_gate = typename EigenVector<T>::Type(
+      value.gate_value + frame_size, Array1(frame_size));
+  auto value_frame_state = typename EigenVector<T>::Type(
+      value.gate_value + 2 * frame_size, Array1(frame_size));
+  auto value_output =
+      typename EigenVector<T>::Type(value.output_value, Array1(frame_size));
+  TanhFunctor<T>()(place, value_frame_state, value_frame_state);
+  value_output.device(place) =
+      (static_cast<T>(1.0) - value_update_gate) * value_frame_state;
+  if (value.prev_out_value) {
+    auto value_prev_out = typename EigenVector<T>::ConstType(
+        value.prev_out_value, Array1(frame_size));
+    value_output.device(place) =
+        value_output + value_update_gate * value_prev_out;
+  }
+}
 template <class OpFinalOutput, typename T>
-inline void forward_final_output(OpFinalOutput op_final_output,
-                                 GRUMetaValue<T> value, int frame_size,
-                                 int batch_size, ActivationType active_node,
-                                 bool origin_mode, bool old_version = true) {
+inline void forward_final_output(
+    OpFinalOutput op_final_output, GRUMetaValue<T> value, int frame_size,
+    int batch_size, ActivationType active_node, bool origin_mode,
+    bool old_version = true,
+    const platform::CPUDeviceContext *context = nullptr) {
   for (int b = 0; b < batch_size; b++) {
-    if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
-        (sizeof(T) == 4)) {
-      hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
-                                      value.prev_out_value, value.output_value,
-                                      frame_size, active_node, origin_mode,
-                                      old_version);
+    if (!old_version) {
+      // eigen
+      forward_final_outputV2(*context, value, frame_size);
     } else {
-      hl_naive_gru_forward_final_output(op_final_output, value.gate_value,
-                                        value.prev_out_value,
-                                        value.output_value, frame_size,
-                                        active_node, origin_mode, old_version);
+      if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
+          (sizeof(T) == 4)) {
+        hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
+                                        value.prev_out_value,
+                                        value.output_value, frame_size,
+                                        active_node, origin_mode, old_version);
+      } else {
+        hl_naive_gru_forward_final_output(
+            op_final_output, value.gate_value, value.prev_out_value,
+            value.output_value, frame_size, active_node, origin_mode,
+            old_version);
+      }
    }
     value.gate_value += frame_size * 3;
     value.output_value += frame_size;
     if (value.prev_out_value) {
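Note on the new Eigen path: the two V2 helpers added above evaluate the v2 GRU formulation, in which the hidden-to-hidden bias is applied to the reset output before gating. Written out from the code (the symbols below are labels chosen here, not names from the source, and the candidate pre-activation is assumed to already contain the gated reset output by the time forward_final_outputV2 runs):

\begin{aligned}
r_t &= \sigma(\hat r_t), \qquad u_t = \sigma(\hat u_t) \\
s_t &\leftarrow r_t \odot (s_t + b_{hr}) \\
c_t &= \tanh(\hat c_t) \\
h_t &= (1 - u_t) \odot c_t + u_t \odot h_{t-1}
\end{aligned}

where \hat r_t, \hat u_t, \hat c_t are the pre-activations stored in gate_value, s_t is reset_output_value, and b_{hr} is reset_bias.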
@@ -664,23 +719,70 @@ inline void backward_reset_grad(OpResetGrad op_reset_grad,
   }
 }
+template <typename T>
+inline void gru_backward(const platform::CPUDeviceContext &context,
+                         GRUMetaValue<T> value, GRUMetaGrad<T> grad,
+                         int frame_size) {
+  auto &place = *context.eigen_device();
+  auto value_reset_gate =
+      typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
+  auto grad_reset_gate =
+      typename EigenVector<T>::Type(grad.gate_grad, Array1(frame_size));
+  auto value_update_gate = typename EigenVector<T>::Type(
+      value.gate_value + frame_size, Array1(frame_size));
+  auto grad_update_gate = typename EigenVector<T>::Type(
+      grad.gate_grad + frame_size, Array1(frame_size));
+  auto value_frame_state = typename EigenVector<T>::Type(
+      value.gate_value + frame_size * 2, Array1(frame_size));
+  auto grad_frame_state = typename EigenVector<T>::Type(
+      grad.gate_grad + frame_size * 2, Array1(frame_size));
+  auto grad_output =
+      typename EigenVector<T>::Type(grad.output_grad, Array1(frame_size));
+  auto value_reset_output = typename EigenVector<T>::Type(
+      value.reset_output_value, Array1(frame_size));
+  auto grad_reset_output =
+      typename EigenVector<T>::Type(grad.reset_output_grad, Array1(frame_size));
+  if (value.prev_out_value) {
+    auto value_prev_out = typename EigenVector<T>::ConstType(
+        value.prev_out_value, Array1(frame_size));
+    SigmoidGradFunctor<T>()(place, 1 /*useless*/, value_update_gate,
+                            (value_prev_out - value_frame_state) * grad_output,
+                            grad_update_gate);
+  } else {
+    SigmoidGradFunctor<T>()(
+        place, 1 /*useless*/, value_update_gate,
+        static_cast<T>(-1) * value_frame_state * grad_output, grad_update_gate);
+  }
+  if (grad.prev_out_grad) {
+    auto grad_prev_out =
+        typename EigenVector<T>::Type(grad.prev_out_grad, Array1(frame_size));
+    grad_prev_out.device(place) =
+        grad_prev_out + grad_output * value_update_gate;
+  }
+  TanhGradFunctor<T>()(place, 1 /*useless*/, value_frame_state,
+                       grad_output * (static_cast<T>(1.0) - value_update_gate),
+                       grad_frame_state);
+  SigmoidGradFunctor<T>()(
+      place, 1 /*useless*/, value_reset_gate,
+      value_reset_output / value_reset_gate * grad_frame_state,
+      grad_reset_gate);
+  if (value.prev_out_value && grad.prev_out_grad) {
+    grad_reset_output.device(place) = value_reset_gate * grad_frame_state;
+  }
+}
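The Eigen gru_backward above inverts those forward relations step by step. In the same notation, with \delta h_t the incoming output gradient and s_t the already rescaled reset output (symbols again chosen here):

\begin{aligned}
\delta u_t &= u_t(1-u_t) \odot (h_{t-1} - c_t) \odot \delta h_t \\
\delta h_{t-1} &\mathrel{+}= u_t \odot \delta h_t \\
\delta c_t &= (1 - c_t^2) \odot (1 - u_t) \odot \delta h_t \\
\delta r_t &= r_t(1-r_t) \odot (s_t / r_t) \odot \delta c_t \\
\delta s_t &= r_t \odot \delta c_t
\end{aligned}

When there is no previous hidden state, the update-gate term reduces to -c_t \odot \delta h_t, matching the else branch in the code.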
 template <class OpGruGrad, typename T>
-inline void cpu_gru_backward(OpGruGrad op_gru_grad, GRUMetaValue<T> value,
+inline void cpu_gru_backward(const platform::CPUDeviceContext &context,
+                             OpGruGrad op_gru_grad, GRUMetaValue<T> value,
                              GRUMetaGrad<T> grad, int frame_size,
                              int batch_size, ActivationType active_node,
                              ActivationType active_gate) {
   for (int b = 0; b < batch_size; ++b) {
-    if (OpGruGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
-      hl_avx_gru_backward(
-          op_gru_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
-          grad.prev_out_grad, value.reset_output_value, grad.reset_output_grad,
-          grad.output_grad, frame_size, active_node, active_gate);
-    } else {
-      hl_naive_gru_backward(
-          op_gru_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
-          grad.prev_out_grad, value.reset_output_value, grad.reset_output_grad,
-          grad.output_grad, frame_size, active_node, active_gate);
-    }
+    // eigen
+    gru_backward(context, value, grad, frame_size);
     value.gate_value += frame_size * 3;
     value.reset_output_value += frame_size;
......
@@ -42,7 +42,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
     }
     detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
-                                 frame_size, batch_size, active_gate);
+                                 frame_size, batch_size, active_gate, true,
+                                 &context);
     if (value.prev_out_value) {
       blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
@@ -53,7 +54,7 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
     detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
                                  frame_size, batch_size, active_node,
-                                 origin_mode);
+                                 origin_mode, &context);
 #endif
   }
 };
@@ -116,7 +117,8 @@ struct GRUUnitFunctorV2<platform::CPUDeviceContext, T> {
                      value.reset_output_value);
     }
     detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
-                                 frame_size, batch_size, active_gate, false);
+                                 frame_size, batch_size, active_gate, false,
+                                 &context);
     T *cell_state_value = value.gate_value + 2 * frame_size;
     T *reset_output_value = value.reset_output_value;
@@ -129,7 +131,7 @@ struct GRUUnitFunctorV2<platform::CPUDeviceContext, T> {
     detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
                                  frame_size, batch_size, active_node, true,
-                                 false);
+                                 false, &context);
 #endif
   }
 };
@@ -144,8 +146,50 @@ struct GRUUnitGradFunctorV2<platform::CPUDeviceContext, T> {
 #ifndef __NVCC__
     // calculate grad_update_gate, grad_frame_state,
     // grad_reset_output, grad_reset_gate
-    detail::cpu_gru_backward(detail::backward::gru<T>(), value, grad,
+    detail::cpu_gru_backward(context, detail::backward::gru<T>(), value, grad,
                              frame_size, batch_size, active_node, active_gate);
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+    if (grad.prev_out_grad && value.prev_out_value) {
+      // update prev_out_grad
+      blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
+                grad.gate_grad, frame_size * 3, value.gate_weight, frame_size,
+                1, grad.prev_out_grad, frame_size);
+      blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
+                grad.gate_grad + frame_size, frame_size * 3,
+                value.gate_weight + frame_size * frame_size, frame_size, 1,
+                grad.prev_out_grad, frame_size);
+      blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
+                grad.reset_output_grad, frame_size, value.state_weight,
+                frame_size, 1, grad.prev_out_grad, frame_size);
+      // update weight_hh_grad
+      if (grad.gate_weight_grad) {
+        // reset gate
+        blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
+                  grad.gate_grad, frame_size * 3, value.prev_out_value,
+                  frame_size, 1, grad.gate_weight_grad, frame_size);
+        // update gate
+        blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
+                  grad.gate_grad + frame_size, frame_size * 3,
+                  value.prev_out_value, frame_size, 1,
+                  grad.gate_weight_grad + frame_size * frame_size, frame_size);
+        // cell state
+        blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
+                  grad.reset_output_grad, frame_size, value.prev_out_value,
+                  frame_size, 1, grad.state_weight_grad, frame_size);
+      }
+    }
+    // update bias_hh_grad
+    T *gate_grad = grad.gate_grad;
+    T *bias_hh_grad = grad.bias_hh_grad;
+    T *state_bias_grad = grad.bias_hh_grad + 2 * frame_size;
+    T *reset_output_grad = grad.reset_output_grad;
+    for (int b = 0; b < batch_size; ++b) {
+      blas.VADD(2 * frame_size, bias_hh_grad, gate_grad, bias_hh_grad);
+      blas.VADD(frame_size, state_bias_grad, reset_output_grad,
+                state_bias_grad);
+      gate_grad += 3 * frame_size;
+      reset_output_grad += frame_size;
+    }
 #endif
   }
 };
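The BLAS calls added above complete the GRU backward pass that gru_backward starts: they push the per-gate gradients through the hidden-to-hidden weights and accumulate the weight and bias gradients. In matrix form, with W_r, W_u, W_c the three frame_size by frame_size blocks of weight_hh and \delta R, \delta U, \delta S the batched gradients of the reset gate, update gate and reset output (block labels chosen here, reading the GEMM arguments literally):

\begin{aligned}
\delta H_{t-1} &\mathrel{+}= \delta R\,W_r + \delta U\,W_u + \delta S\,W_c \\
\delta W_r &\mathrel{+}= \delta R^{\top} H_{t-1}, \quad
\delta W_u \mathrel{+}= \delta U^{\top} H_{t-1}, \quad
\delta W_c \mathrel{+}= \delta S^{\top} H_{t-1} \\
\delta b_{hh}[r,u] &\mathrel{+}= \textstyle\sum_{b} [\delta R, \delta U]_b, \qquad
\delta b_{hh}[c] \mathrel{+}= \textstyle\sum_{b} \delta S_b
\end{aligned}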
......
@@ -38,7 +38,7 @@ struct GRUMetaGrad {
   T *reset_output_grad;
   T *output_grad;
   T *prev_out_grad;
-  T *state_bias_grad;
+  T *bias_hh_grad;
 };
 template <typename DeviceContext, typename T>
......
@@ -210,66 +210,58 @@ struct LSTMCell : Cell<T> {
   }
 };
+template <typename T>
+void dropout_helper(const framework::ExecutionContext& context, Tensor* x,
+                    Tensor* y, const Tensor* mask, const float& dropout_prob) {
+  auto& place = *context.template device_context<platform::CPUDeviceContext>()
+                     .eigen_device();
+  auto dropout_mask = EigenVector<uint8_t>::Flatten(*mask);
+  auto in = EigenVector<T>::Flatten(*x);
+  auto out = EigenVector<T>::Flatten(*y);
+  if (dropout_prob == 1.0f) {
+    out.device(place) = static_cast<T>(0) * in;
+  } else {
+    out.device(place) =
+        in * dropout_mask.cast<T>() / static_cast<T>(1.0f - dropout_prob);
+  }
+}
 template <typename T>
 void dropout_cpu_function_inplace(const framework::ExecutionContext& context,
-                                  Tensor* x, Tensor* mask,
+                                  Tensor* x, Tensor* y, Tensor* mask,
                                   const float& dropout_prob,
                                   const int& seed_number, const bool& is_test,
                                   bool* is_has_reset) {
   if (is_test) {
     return;
   }
-  auto* x_data = x->data<T>();
   size_t size = framework::product(x->dims());
   auto* mask_data = mask->data<uint8_t>();
   if (!(*is_has_reset)) {
     // Special case when dropout_prob is 1.0
     if (dropout_prob == 1.0f) {
-      std::fill(x_data, x_data + size, static_cast<T>(0));
-      std::fill(mask_data, mask_data + size, static_cast<T>(0));
-      *is_has_reset = true;
-      return;
-    }
-    auto engine = framework::GetCPURandomEngine(seed_number);
-    std::uniform_real_distribution<float> dist(0, 1);
-    for (size_t i = 0; i < size; ++i) {
-      if (dist(*engine) < dropout_prob) {
-        mask_data[i] = 0;
-        x_data[i] = static_cast<T>(0);
-      } else {
-        mask_data[i] = 1;
-        x_data[i] /= static_cast<T>(1.0f - dropout_prob);
+      std::fill(mask_data, mask_data + size, static_cast<uint8_t>(0));
+    } else {
+      auto engine = framework::GetCPURandomEngine(seed_number);
+      std::uniform_real_distribution<float> dist(0, 1);
+      for (size_t i = 0; i < size; ++i) {
+        if (dist(*engine) < dropout_prob) {
+          mask_data[i] = 0;
+        } else {
+          mask_data[i] = 1;
+        }
       }
     }
     *is_has_reset = true;
-  } else {
-    if (dropout_prob == 1.0f) {
-      std::fill(x_data, x_data + size, static_cast<T>(0));
-      return;
-    }
-    for (size_t i = 0; i < size; ++i) {
-      if (mask_data[i] == 0) {
-        x_data[i] = static_cast<T>(0);
-      } else {
-        x_data[i] /= static_cast<T>(1.0f - dropout_prob);
-      }
-    }
   }
+  dropout_helper<T>(context, x, y, mask, dropout_prob);
 }
 template <typename T>
 void dropout_cpu_grad_function_inplace(
     const framework::ExecutionContext& context, Tensor* grad_x,
     const Tensor* mask, const float& dropout_prob) {
-  auto& place = *context.template device_context<platform::CPUDeviceContext>()
-                     .eigen_device();
-  auto M = EigenVector<uint8_t>::Flatten(*mask);
-  auto dX = EigenVector<T>::Flatten(*grad_x);
-  if (dropout_prob == 1.0f) {
-    dX.device(place) = static_cast<T>(0) * dX;
-  } else {
-    dX.device(place) = dX * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
-  }
+  dropout_helper<T>(context, grad_x, grad_x, mask, dropout_prob);
 }
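The dropout fix above separates mask generation from mask application: the mask is sampled once per layer, and both the forward pass over later time steps and the backward pass only rescale by that cached mask instead of re-dropping or dividing the data in place. A minimal standalone sketch of that contract (plain C++, not the operator code; the function names here are made up for illustration):

// Standalone sketch of the semantics targeted by dropout_helper and
// dropout_cpu_function_inplace; it does not use any Paddle types.
#include <cstdint>
#include <random>
#include <vector>

// Sample the keep/drop mask once (mirrors the !(*is_has_reset) branch).
std::vector<uint8_t> make_dropout_mask(size_t n, float p, uint64_t seed) {
  std::vector<uint8_t> mask(n, 1);
  if (p == 0.0f) return mask;
  std::mt19937_64 engine(seed);
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  for (size_t i = 0; i < n; ++i) mask[i] = dist(engine) < p ? 0 : 1;
  return mask;
}

// Mirrors dropout_helper: out = in * mask / (1 - p), never modifying `in`,
// so the same routine can serve both the forward and the gradient rescaling.
void apply_dropout(const std::vector<float>& in,
                   const std::vector<uint8_t>& mask, float p,
                   std::vector<float>* out) {
  out->resize(in.size());
  if (p == 1.0f) {
    out->assign(in.size(), 0.0f);
    return;
  }
  const float scale = 1.0f / (1.0f - p);
  for (size_t i = 0; i < in.size(); ++i)
    (*out)[i] = mask[i] ? in[i] * scale : 0.0f;
}

Because the input is never modified in place, the backward pass can reuse the exact same routine, which is what the one-line body of dropout_cpu_grad_function_inplace now does.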
 template <typename T, typename CellType>
@@ -298,14 +290,13 @@ struct Layer {
     blas.MatMul(*input, mat_dim_a, weight, mat_dim_b, static_cast<T>(1.0),
                 cache_input, static_cast<T>(0));
-    auto eigen_in = framework::EigenMatrix<T>::Reshape(
+    auto in = framework::EigenMatrix<T>::Reshape(
         *cache_input, cache_input->dims().size() - 1);
-    auto eigen_bias_ih = framework::EigenMatrix<T>::From(
+    auto bias_ih_tmp = framework::EigenMatrix<T>::From(
         bias_ih, framework::make_ddim({1, bias_ih.dims()[0]}));
     const int& row_num =
         framework::product(cache_input->dims()) / cache_input->dims()[2];
-    eigen_in =
-        eigen_in + eigen_bias_ih.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
+    in = in + bias_ih_tmp.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
     if (is_gru(context)) {
       // reset_gate update_gate cell_gate = [1, 1, 0]
       Tensor bias_hh_tmp;
@@ -317,15 +308,13 @@ struct Layer {
       math::SetConstant<platform::CPUDeviceContext, T> zero;
       zero(dev_ctx, &bias_hh_tmp_unbind[2], static_cast<T>(0.0));
-      auto eigen_bias_hh_tmp = framework::EigenMatrix<T>::From(
+      auto bias_hh_after_mask = framework::EigenMatrix<T>::From(
           bias_hh_tmp, framework::make_ddim({1, bias_hh.dims()[0]}));
-      eigen_in = eigen_in +
-                 eigen_bias_hh_tmp.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
+      in = in + bias_hh_after_mask.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
     } else {
-      auto eigen_bias_hh = framework::EigenMatrix<T>::From(
+      auto bias_hh_no_mask = framework::EigenMatrix<T>::From(
          bias_hh, framework::make_ddim({1, bias_hh.dims()[0]}));
-      eigen_in =
-          eigen_in + eigen_bias_hh.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
+      in = in + bias_hh_no_mask.broadcast(Eigen::DSizes<int, 2>(row_num, 1));
     }
   }
@@ -335,27 +324,26 @@ struct Layer {
     // in the output, if mask flag is 0, we will retun the zero data
     auto& place = *context.template device_context<platform::CPUDeviceContext>()
                        .eigen_device();
-    auto eigen_output =
+    auto out =
         framework::EigenMatrix<T>::Reshape(*output, output->dims().size() - 1);
-    auto eigen_mask = framework::EigenMatrix<T>::From(
+    auto mask = framework::EigenMatrix<T>::From(
         mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
-    auto eigen_init_h =
+    auto pre_h =
         framework::EigenMatrix<T>::Reshape(*init_h, init_h->dims().size() - 1);
-    auto eigen_last_h =
+    auto curr_h =
         framework::EigenMatrix<T>::Reshape(*last_h, last_h->dims().size() - 1);
-    auto eigen_mask_broadcast =
-        eigen_mask.broadcast(Eigen::DSizes<int, 2>(1, output->dims()[2]));
-    eigen_last_h.device(place) = eigen_output * eigen_mask_broadcast +
-                                 eigen_init_h * (1 - eigen_mask_broadcast);
-    eigen_output.device(place) = eigen_output * eigen_mask_broadcast;
+    auto mask_broadcast =
+        mask.broadcast(Eigen::DSizes<int, 2>(1, output->dims()[2]));
+    curr_h.device(place) = out * mask_broadcast + pre_h * (1 - mask_broadcast);
+    out.device(place) = out * mask_broadcast;
     if (is_lstm(context)) {
-      auto eigen_init_c = framework::EigenMatrix<T>::Reshape(
+      auto pre_c = framework::EigenMatrix<T>::Reshape(
           *init_c, init_c->dims().size() - 1);
-      auto eigen_last_c = framework::EigenMatrix<T>::Reshape(
+      auto curr_c = framework::EigenMatrix<T>::Reshape(
          *last_c, last_c->dims().size() - 1);
-      eigen_last_c.device(place) = eigen_last_c * eigen_mask_broadcast +
-                                   eigen_init_c * (1 - eigen_mask_broadcast);
+      curr_c.device(place) =
+          curr_c * mask_broadcast + pre_c * (1 - mask_broadcast);
     }
   }
@@ -910,16 +898,18 @@ void RnnFunc(const framework::ExecutionContext& ctx, const Tensor* input,
       }
       if (!is_test) {
         prev_hidden_data = hidden_data.Slice(i - 1, i);
-        input_holder = &prev_hidden_data;
         input_holder->Resize(output->dims());
+        if (dropout_prob != 0) {
+          dropout_cpu_function_inplace<T>(ctx, &prev_hidden_data, input_holder,
+                                          dropout_mask, dropout_prob, seed,
+                                          is_test, &has_dropout_reset);
+        } else {
+          input_holder = &prev_hidden_data;
+          input_holder->Resize(output->dims());
+        }
       } else {
         SwapPoniter(&output_holder, &input_holder);
       }
-      if (dropout_prob != 0 && (!is_test)) {
-        dropout_cpu_function_inplace<T>(ctx, input_holder, dropout_mask,
-                                        dropout_prob, seed, is_test,
-                                        &has_dropout_reset);
-      }
     }
     const Tensor* input_temp_holder = input;
     if (i > 0) {
@@ -1040,53 +1030,6 @@ void create_tensor_by_list(const framework::ExecutionContext& context,
   }
 }
-template <typename T>
-void make_grad_gate_buf(const framework::ExecutionContext& context,
-                        Tensor* grad_gate, Tensor* grad_gate_buf,
-                        Tensor* reset_output_grad = nullptr) {
-  int dim_size = grad_gate->dims().size();
-  int batch_size = grad_gate->dims()[dim_size - 2];
-  int frame_size = grad_gate->dims()[dim_size - 1];
-  Tensor grad_gate_mask;
-  create_tensor_by_list<T>(context, &grad_gate_mask, {1, 1, 0});
-  auto& place = *context.template device_context<platform::CPUDeviceContext>()
-                     .eigen_device();
-  auto eigen_grad_gate_mask = framework::EigenMatrix<T>::From(
-      grad_gate_mask, framework::make_ddim({3, 1}));
-  auto eigen_grad_gate_mask_broadcast =
-      eigen_grad_gate_mask.broadcast(Eigen::DSizes<int, 2>(1, frame_size / 3))
-          .reshape(Eigen::DSizes<int, 1>(frame_size))
-          .broadcast(Eigen::DSizes<int, 2>(batch_size, 1));
-  auto eigen_grad_gate_buf = framework::EigenMatrix<T>::From(
-      *grad_gate_buf, framework::make_ddim({batch_size, frame_size}));
-  auto eigen_grad_gate = framework::EigenMatrix<T>::From(
-      *grad_gate, framework::make_ddim({batch_size, frame_size}));
-  eigen_grad_gate_buf.device(place) =
-      eigen_grad_gate * eigen_grad_gate_mask_broadcast;
-  if (reset_output_grad) {
-    Tensor grad_reset_output_mask;
-    create_tensor_by_list<T>(context, &grad_reset_output_mask, {0, 0, 1});
-    auto eigen_grad_reset_output_mask = framework::EigenMatrix<T>::From(
-        grad_reset_output_mask, framework::make_ddim({3, 1}));
-    auto eigen_grad_reset_output_mask_broadcast =
-        eigen_grad_reset_output_mask
-            .broadcast(Eigen::DSizes<int, 2>(1, frame_size / 3))
-            .reshape(Eigen::DSizes<int, 1>(frame_size))
-            .broadcast(Eigen::DSizes<int, 2>(batch_size, 1));
-    auto eigen_grad_reset_output =
-        framework::EigenMatrix<T>::Reshape(*reset_output_grad,
-                                           reset_output_grad->dims().size() - 1)
-            .broadcast(Eigen::DSizes<int, 3>(1, 3, 1))
-            .reshape(Eigen::DSizes<int, 2>(batch_size, frame_size));
-    eigen_grad_gate_buf.device(place) =
-        eigen_grad_gate_buf +
-        eigen_grad_reset_output_mask_broadcast * eigen_grad_reset_output;
-  }
-}
 template <typename T, typename GradCellType>
 struct GradLayer {
   explicit GradLayer(const GradCellType& cell) : cell_(cell) {}
@@ -1196,12 +1139,10 @@ struct GradLayer {
     Tensor* pre_hidden = nullptr;
     Tensor* pre_state = nullptr;
     Tensor* hidden = nullptr;
-    Tensor grad_gate_buf;
-    TensorList grad_gate_buf_unbind;
     if (is_gru(context)) {
-      grad_gate_buf.Resize(layer_grad_gate_tensor->dims());
-      grad_gate_buf.mutable_data<T>(context.GetPlace());
-      grad_gate_buf_unbind = Unbind(grad_gate_buf);
+      zero(device_ctx,
+           &((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 3]),
+           static_cast<T>(0.0));
     }
     for (int i = time_step - 1; i >= 0; --i) {
       if (has_sequence_length) {
@@ -1232,7 +1173,7 @@ struct GradLayer {
           &(parameter_lists[layer_idx][current_reverse_idx * 4 + 1]),
           pre_hidden, pre_state, dynamic_grad_last_h, dynamic_grad_last_c,
          &(*layer_grad_gate_tensor_unbind)[i], weight_grad, dynamic_grad_pre_h,
-          dynamic_grad_pre_c, &grad_gate_buf_unbind[i],
+          dynamic_grad_pre_c,
          &((*weight_list_grad)[layer_idx][current_reverse_idx * 4 + 3]),
          mask_tensor_list[i], has_sequence_length);
       SwapPoniter(&dynamic_grad_last_h, &dynamic_grad_pre_h);
@@ -1241,8 +1182,7 @@ struct GradLayer {
     // postproces for gradient for w_hi, X, bias_hi, bias_hh
     this->postprocess(context, *layer_grad_gate_tensor, *input, input_grad,
                       parameter_lists[layer_idx],
-                      &((*weight_list_grad)[layer_idx]), &grad_gate_buf,
-                      is_reverse);
+                      &((*weight_list_grad)[layer_idx]), is_reverse);
     // copy the gradient to init_c init_h
     if ((*init_h_grad_unbind).size() > 0 && time_step % 2 == 0) {
@@ -1268,16 +1208,17 @@ struct GradLayer {
       TensorList* init_h_grad_unbind, TensorList* init_c_grad_unbind,
       const std::vector<TensorList>& weight_list_grad, const int& layer_idx,
       const int& gate_num) {}
   void preprocess(const framework::ExecutionContext& context,
                   const Tensor* grad_output, Tensor* grad_last_h) {
     auto& place = *context.template device_context<platform::CPUDeviceContext>()
                        .eigen_device();
-    auto eigen_grad_output = framework::EigenMatrix<T>::Reshape(
+    auto output_grad = framework::EigenMatrix<T>::Reshape(
        *grad_output, grad_output->dims().size() - 1);
-    auto eigen_grad_last_h = framework::EigenMatrix<T>::Reshape(
+    auto last_h_grad = framework::EigenMatrix<T>::Reshape(
        *grad_last_h, grad_last_h->dims().size() - 1);
     // the output gradient contribute the gradient to last_h
-    eigen_grad_last_h.device(place) = eigen_grad_last_h + eigen_grad_output;
+    last_h_grad.device(place) = last_h_grad + output_grad;
   }
   void mask_preprocess(const framework::ExecutionContext& context,
@@ -1286,40 +1227,35 @@ struct GradLayer {
                        Tensor* grad_pre_c, const Tensor& mask_tensor) {
     auto& place = *context.template device_context<platform::CPUDeviceContext>()
                        .eigen_device();
-    auto eigen_mask = framework::EigenMatrix<T>::From(
+    auto mask = framework::EigenMatrix<T>::From(
        mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
-    auto eigen_mask_broadcast =
-        eigen_mask.broadcast(Eigen::DSizes<int, 2>(1, grad_output->dims()[2]));
-    auto eigen_grad_last_h = framework::EigenMatrix<T>::Reshape(
+    auto mask_broadcast =
+        mask.broadcast(Eigen::DSizes<int, 2>(1, grad_output->dims()[2]));
+    auto last_h_grad = framework::EigenMatrix<T>::Reshape(
        *grad_last_h, grad_last_h->dims().size() - 1);
-    auto eigen_grad_pre_h = framework::EigenMatrix<T>::Reshape(
+    auto pre_h_grad = framework::EigenMatrix<T>::Reshape(
        *grad_pre_h, grad_pre_h->dims().size() - 1);
-    auto eigen_grad_output = framework::EigenMatrix<T>::Reshape(
+    auto output_grad = framework::EigenMatrix<T>::Reshape(
        *grad_output, grad_output->dims().size() - 1);
-    eigen_grad_last_h.device(place) =
-        eigen_grad_last_h + eigen_grad_output * eigen_mask_broadcast;
-    eigen_grad_pre_h.device(place) =
-        (1 - eigen_mask_broadcast) * eigen_grad_last_h;
-    eigen_grad_last_h.device(place) = eigen_mask_broadcast * eigen_grad_last_h;
+    last_h_grad.device(place) = last_h_grad + output_grad * mask_broadcast;
+    pre_h_grad.device(place) = (1 - mask_broadcast) * last_h_grad;
+    last_h_grad.device(place) = mask_broadcast * last_h_grad;
     if (grad_last_c && grad_pre_c && is_lstm(context)) {
-      auto eigen_grad_last_c = framework::EigenMatrix<T>::Reshape(
+      auto last_c_grad = framework::EigenMatrix<T>::Reshape(
          *grad_last_c, grad_last_c->dims().size() - 1);
-      auto eigen_grad_pre_c = framework::EigenMatrix<T>::Reshape(
+      auto pre_c_grad = framework::EigenMatrix<T>::Reshape(
          *grad_pre_c, grad_pre_c->dims().size() - 1);
-      eigen_grad_pre_c.device(place) =
-          (1 - eigen_mask_broadcast) * eigen_grad_last_c;
-      eigen_grad_last_c.device(place) =
-          eigen_mask_broadcast * eigen_grad_last_c;
+      pre_c_grad.device(place) = (1 - mask_broadcast) * last_c_grad;
+      last_c_grad.device(place) = mask_broadcast * last_c_grad;
     }
   }
   void postprocess(const framework::ExecutionContext& context,
                    const Tensor& grad_gate, const Tensor& input,
                    Tensor* input_grad, const TensorList& parameters,
-                   TensorList* grad_parameters, Tensor* grad_gate_buf,
-                   const int& is_reverse) {
+                   TensorList* grad_parameters, const int& is_reverse) {
     // we get the grad_gate step by step, and need to bradocast the grad to the
     // grad_w_hi, grad_bias_hi, grad_bias_hh
     int begin_idx = 0;
@@ -1360,10 +1296,7 @@ struct GradLayer {
        {grad_gate.dims()[0] * grad_gate.dims()[1], grad_gate.dims()[2]});
     col_sum(device_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 2]));
     // Bias_hh
-    if (is_gru(context)) {
-      grad_gate_buf->Resize(tmp_grad_gate.dims());
-      col_sum(device_ctx, *grad_gate_buf, &((*grad_parameters)[begin_idx + 3]));
-    } else {
+    if (!is_gru(context)) {
       col_sum(device_ctx, tmp_grad_gate, &((*grad_parameters)[begin_idx + 3]));
     }
   }
@@ -1600,64 +1533,69 @@ struct GradCell {
                       Tensor* pre_state, Tensor* grad_hidden,
                       Tensor* grad_state, Tensor* grad_gate,
                       Tensor* grad_weight_hh, Tensor* grad_pre_hidden,
-                      Tensor* grad_pre_state, Tensor* grad_gate_buf,
-                      Tensor* grad_bias_hh, const Tensor& mask_tensor,
+                      Tensor* grad_pre_state, Tensor* grad_bias_hh,
+                      const Tensor& mask_tensor,
                       bool has_sequence_length) const {}
+  void postprocess_pre_hidden_grad(const framework::ExecutionContext& context,
+                                   Tensor* grad_pre_hidden,
+                                   Tensor* grad_pre_hidden_bak,
+                                   Tensor* grad_pre_state,
+                                   Tensor* grad_pre_state_bak,
+                                   const Tensor& mask_tensor,
+                                   bool has_sequence_length) const {
+    if (has_sequence_length) {
+      auto& place =
+          *context.template device_context<platform::CPUDeviceContext>()
+               .eigen_device();
+      auto mask = framework::EigenMatrix<T>::From(
+          mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
+      auto mask_broadcast =
+          mask.broadcast(Eigen::DSizes<int, 2>(1, grad_pre_hidden->dims()[2]));
+      auto pre_hidden_grad = framework::EigenMatrix<T>::Reshape(
+          *grad_pre_hidden, grad_pre_hidden->dims().size() - 1);
+      auto pre_hidden_bak_grad = framework::EigenMatrix<T>::Reshape(
+          *grad_pre_hidden_bak, grad_pre_hidden_bak->dims().size() - 1);
+      pre_hidden_grad.device(place) =
+          (1 - mask_broadcast) * pre_hidden_bak_grad +
+          pre_hidden_grad * mask_broadcast;
+      if (grad_pre_state) {
+        auto pre_state_grad = framework::EigenMatrix<T>::Reshape(
+            *grad_pre_state, grad_pre_state->dims().size() - 1);
+        auto pre_state_bak_grad = framework::EigenMatrix<T>::Reshape(
+            *grad_pre_state_bak, grad_pre_state_bak->dims().size() - 1);
+        pre_state_grad.device(place) =
+            (1 - mask_broadcast) * pre_state_bak_grad +
+            pre_state_grad * mask_broadcast;
+      }
+    }
+  }
   virtual void update_pre_hidden_grad(
       const framework::ExecutionContext& context, Tensor* grad_gate,
       const Tensor* weight_hh, Tensor* grad_pre_hidden,
       Tensor* grad_pre_hidden_bak, Tensor* grad_pre_state,
-      Tensor* grad_pre_state_bak, Tensor* grad_gate_buf,
-      const Tensor& mask_tensor, bool has_sequence_length) const {
+      Tensor* grad_pre_state_bak, const Tensor& mask_tensor,
+      bool has_sequence_length) const {
     auto& device_ctx =
         context.template device_context<platform::CPUDeviceContext>();
     auto blas = math::GetBlas<platform::CPUDeviceContext, T>(device_ctx);
-    T beta = 0;
     Tensor* grad_gate_tmp = grad_gate;
-    if (is_gru(context)) {
-      beta = 1.0;
-      grad_gate_tmp = grad_gate_buf;
-    }
     auto mat_dim_a =
         math::CreateMatrixDescriptor(grad_gate_tmp->dims(), 0, false);
     mat_dim_a.height_ *= mat_dim_a.batch_size_;
     mat_dim_a.batch_size_ = 0;
     auto mat_dim_b = math::CreateMatrixDescriptor(weight_hh->dims(), 0, false);
     blas.MatMul(*grad_gate_tmp, mat_dim_a, *weight_hh, mat_dim_b,
-                static_cast<T>(1.0), grad_pre_hidden, beta);
-    if (has_sequence_length) {
-      auto& place =
-          *context.template device_context<platform::CPUDeviceContext>()
-               .eigen_device();
-      auto eigen_mask = framework::EigenMatrix<T>::From(
-          mask_tensor, framework::make_ddim({mask_tensor.dims()[1], 1}));
-      auto eigen_mask_broadcast = eigen_mask.broadcast(
-          Eigen::DSizes<int, 2>(1, grad_pre_hidden->dims()[2]));
-      auto eigen_grad_pre_hidden = framework::EigenMatrix<T>::Reshape(
-          *grad_pre_hidden, grad_pre_hidden->dims().size() - 1);
-      auto eigen_grad_pre_hidden_bak = framework::EigenMatrix<T>::Reshape(
-          *grad_pre_hidden_bak, grad_pre_hidden_bak->dims().size() - 1);
-      eigen_grad_pre_hidden.device(place) =
-          (1 - eigen_mask_broadcast) * eigen_grad_pre_hidden_bak +
-          eigen_grad_pre_hidden * eigen_mask_broadcast;
-      if (grad_pre_state) {
-        auto eigen_grad_pre_state = framework::EigenMatrix<T>::Reshape(
-            *grad_pre_state, grad_pre_state->dims().size() - 1);
-        auto eigen_grad_pre_state_bak = framework::EigenMatrix<T>::Reshape(
-            *grad_pre_state_bak, grad_pre_state_bak->dims().size() - 1);
-        eigen_grad_pre_state.device(place) =
-            (1 - eigen_mask_broadcast) * eigen_grad_pre_state_bak +
-            eigen_grad_pre_state * eigen_mask_broadcast;
-      }
-    }
+                static_cast<T>(1.0), grad_pre_hidden, 0);
+    postprocess_pre_hidden_grad(context, grad_pre_hidden, grad_pre_hidden_bak,
+                                grad_pre_state, grad_pre_state_bak, mask_tensor,
+                                has_sequence_length);
   }
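The postprocess_pre_hidden_grad helper factored out above (so GRUGradCell can reuse it) only blends the freshly computed pre-hidden gradient with the value saved before the step wherever the sequence mask marks the step as padding. With m the per-step mask broadcast over the hidden width:

\begin{aligned}
\delta h_{t-1} &\leftarrow m \odot \delta h_{t-1} + (1 - m) \odot \delta h_{t-1}^{\,saved} \\
\delta c_{t-1} &\leftarrow m \odot \delta c_{t-1} + (1 - m) \odot \delta c_{t-1}^{\,saved} \qquad \text{(LSTM only)}
\end{aligned}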
   virtual void update_weight_hh_grad(const framework::ExecutionContext& context,
                                      Tensor* grad_gate, Tensor* pre_hidden,
-                                     Tensor* grad_weight_hh,
-                                     Tensor* grad_gate_buf) const {
+                                     Tensor* grad_weight_hh) const {
     auto& device_ctx =
         context.template device_context<platform::CPUDeviceContext>();
     auto blas = math::GetBlas<platform::CPUDeviceContext, T>(device_ctx);
@@ -1667,11 +1605,7 @@ struct GradCell {
     auto mat_dim_d = math::CreateMatrixDescriptor(pre_hidden->dims(), 0, false);
     mat_dim_d.height_ *= mat_dim_d.batch_size_;
     mat_dim_d.batch_size_ = 0;
-    Tensor* grad_gate_tmp = grad_gate;
-    if (is_gru(context)) {
-      grad_gate_tmp = grad_gate_buf;
-    }
-    blas.MatMul(*grad_gate_tmp, mat_dim_c, *pre_hidden, mat_dim_d,
+    blas.MatMul(*grad_gate, mat_dim_c, *pre_hidden, mat_dim_d,
                 static_cast<T>(1.0), grad_weight_hh, static_cast<T>(1.0));
   }
 };
@@ -1685,8 +1619,7 @@ struct SimpleRNNGradCell : GradCell<T> {
                       Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
                       Tensor* grad_gate, Tensor* grad_weight_hh,
                       Tensor* grad_pre_hidden, Tensor* grad_pre_state,
-                      Tensor* grad_gate_buf, Tensor* grad_bias_hh,
-                      const Tensor& mask_tensor,
+                      Tensor* grad_bias_hh, const Tensor& mask_tensor,
                       bool has_sequence_length) const override {
     auto& device_ctx =
         context.template device_context<platform::CPUDeviceContext>();
@@ -1711,11 +1644,10 @@ struct SimpleRNNGradCell : GradCell<T> {
     functor(*place, z, h, dh, dz);
     // update grad_weight_hh, grad_pre_hidden
-    this->update_pre_hidden_grad(
-        context, grad_gate, weight_hh, grad_pre_hidden, &grad_pre_hidden_bak,
-        nullptr, nullptr, grad_gate_buf, mask_tensor, has_sequence_length);
-    this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh,
-                                grad_gate_buf);
+    this->update_pre_hidden_grad(context, grad_gate, weight_hh, grad_pre_hidden,
+                                 &grad_pre_hidden_bak, nullptr, nullptr,
+                                 mask_tensor, has_sequence_length);
+    this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh);
   }
 };
@@ -1728,8 +1660,7 @@ struct GRUGradCell : GradCell<T> {
                       Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
                       Tensor* grad_gate, Tensor* grad_weight_hh,
                       Tensor* grad_pre_hidden, Tensor* grad_pre_state,
-                      Tensor* grad_gate_buf, Tensor* grad_bias_hh,
-                      const Tensor& mask_tensor,
+                      Tensor* grad_bias_hh, const Tensor& mask_tensor,
                       bool has_sequence_length) const override {
     auto& device_ctx =
         context.template device_context<platform::CPUDeviceContext>();
@@ -1747,6 +1678,8 @@ struct GRUGradCell : GradCell<T> {
     gru_value.gate_value = gate_tensor->data<T>();
     gru_value.prev_out_value = pre_hidden->data<T>();
     gru_value.reset_output_value = state_tensor->data<T>();
+    gru_value.state_weight = weight_hh->data<T>() + 2 * frame_size * frame_size;
+    gru_value.gate_weight = weight_hh->data<T>();
     gru_grad.gate_grad = grad_gate->data<T>();
     gru_grad.reset_output_grad = grad_state->data<T>();
@@ -1755,7 +1688,7 @@ struct GRUGradCell : GradCell<T> {
     gru_grad.gate_weight_grad = grad_weight_hh->data<T>();
     gru_grad.state_weight_grad =
         grad_weight_hh->data<T>() + 2 * frame_size * frame_size;
-    gru_grad.state_bias_grad = grad_bias_hh->data<T>() + 2 * frame_size;
+    gru_grad.bias_hh_grad = grad_bias_hh->data<T>();
     auto act_gate = math::detail::GetActivationType("sigmoid_v2");
     auto act_node = math::detail::GetActivationType("tanh_v2");
@@ -1763,13 +1696,9 @@ struct GRUGradCell : GradCell<T> {
         device_ctx, gru_value, gru_grad, frame_size, batch_size, act_node,
         act_gate);
-    make_grad_gate_buf<T>(context, grad_gate, grad_gate_buf, grad_state);
-    this->update_pre_hidden_grad(
-        context, grad_gate, weight_hh, grad_pre_hidden, &grad_pre_hidden_bak,
-        nullptr, nullptr, grad_gate_buf, mask_tensor, has_sequence_length);
-    this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh,
-                                grad_gate_buf);
+    this->postprocess_pre_hidden_grad(context, grad_pre_hidden,
+                                      &grad_pre_hidden_bak, nullptr, nullptr,
+                                      mask_tensor, has_sequence_length);
   }
 };
@@ -1782,8 +1711,7 @@ struct LSTMGradCell : GradCell<T> {
                       Tensor* pre_state, Tensor* grad_hidden, Tensor* grad_state,
                       Tensor* grad_gate, Tensor* grad_weight_hh,
                       Tensor* grad_pre_hidden, Tensor* grad_pre_state,
-                      Tensor* grad_gate_buf, Tensor* grad_bias_hh,
-                      const Tensor& mask_tensor,
+                      Tensor* grad_bias_hh, const Tensor& mask_tensor,
                       bool has_sequence_length) const override {
     auto& device_ctx =
         context.template device_context<platform::CPUDeviceContext>();
@@ -1822,12 +1750,10 @@ struct LSTMGradCell : GradCell<T> {
     math::LstmUnitGradFunctor<platform::CPUDeviceContext, T>::compute(
         device_ctx, lstm_value, lstm_grad, frame_size, batch_size, cell_clip,
         gate_act, state_act, cand_act, false);
-    this->update_pre_hidden_grad(context, grad_gate, weight_hh, grad_pre_hidden,
-                                 &grad_pre_hidden_bak, grad_pre_state,
-                                 &grad_pre_state_bak, grad_gate_buf,
-                                 mask_tensor, has_sequence_length);
-    this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh,
-                                grad_gate_buf);
+    this->update_pre_hidden_grad(
+        context, grad_gate, weight_hh, grad_pre_hidden, &grad_pre_hidden_bak,
+        grad_pre_state, &grad_pre_state_bak, mask_tensor, has_sequence_length);
+    this->update_weight_hh_grad(context, grad_gate, pre_hidden, grad_weight_hh);
   }
 };
@@ -2001,7 +1927,12 @@ void RnnGradFunc(const framework::ExecutionContext& context,
   for (int i = num_layers - 1; i >= 0; --i) {
     // the layer input output had saved, just use the data
     if (i > 0) {
-      layer_input.ShareDataWith(hidden_tensor_unbind[i - 1]);
+      if (layer_input.numel() == 0) {
+        layer_input.Resize(hidden_tensor_unbind[i - 1].dims());
+        layer_input.mutable_data<T>(context.GetPlace());
+      }
+      dropout_helper<T>(context, &hidden_tensor_unbind[i - 1], &layer_input,
+                        dropout_state, dropout_prob);
     } else {
       layer_input.ShareDataWith(*input);
     }
......
@@ -294,7 +294,6 @@ def unstack(array, axis=0):
 def dropout(array, p=0.5):
     if p == 0.0:
         return array
     mask = (np.random.uniform(size=array.shape) < (1 - p)).astype(array.dtype)
     return array * (mask / (1 - p))
@@ -390,11 +389,12 @@ class RNNMixin(LayerListMixin):
         states = split_states(initial_states, self.num_directions == 2,
                               self.state_components)
         final_states = []
+        input_temp = inputs
         for i, rnn_layer in enumerate(self):
             if i > 0:
-                inputs = dropout(inputs, self.dropout)
-            outputs, final_state = rnn_layer(inputs, states[i], sequence_length)
+                input_temp = dropout(inputs, self.dropout)
+            outputs, final_state = rnn_layer(input_temp, states[i],
+                                             sequence_length)
             final_states.append(final_state)
             inputs = outputs
......
@@ -53,6 +53,7 @@ class TestRNNOp(OpTest):
         self.is_bidirec = False
         self.mode = "LSTM"
         self.is_test = False
+        self.dropout = 0.0
         self.set_attrs()
         self.direction_num = 2 if self.is_bidirec else 1
@@ -76,7 +77,8 @@ class TestRNNOp(OpTest):
             hidden_size,
             num_layers=self.num_layers,
             time_major=True,
-            direction=direction)
+            direction=direction,
+            dropout=self.dropout)
         flat_w = get_params_for_net(rnn1)
         output, (last_hidden, last_cell) = rnn1(
@@ -101,7 +103,7 @@ class TestRNNOp(OpTest):
             'PreState': [('init_h', init_h), ('init_c', init_c)],
         }
         self.attrs = {
-            'dropout_prob': 0.0,
+            'dropout_prob': self.dropout,
             'is_bidirec': self.is_bidirec,
             'input_size': input_size,
             'hidden_size': hidden_size,
......