Unverified commit 66cf8b08, authored by zyfncg, committed by GitHub

[Phi] Move Rnn Op from fluid to phi (#41007)

* move rnn kernel to phi

* move infershape of rnn to phi

* fix HIP bug

* rename function

* fix HIP bug

* fix hip bug
Parent 59c4fdac
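At a glance, the change removes the fluid-side RNNOp::InferShape and routes shape inference to phi::RnnInferMeta through an infer-shape functor, while the CPU kernels move from REGISTER_OP_CPU_KERNEL registrations to phi kernel libraries. For orientation, the registration pattern below is copied from the rnn_op.cc hunk further down (it relies on the infershape_utils.h, infermeta_utils.h, and multiary.h includes added in the same hunk):

    DECLARE_INFER_SHAPE_FUNCTOR(rnn, RnnInferShapeFunctor,
                                PD_INFER_META(phi::RnnInferMeta));
    REGISTER_OPERATOR(rnn, ops::RNNOp, ops::RNNOpMaker,
                      ops::RNNGradOpMaker<paddle::framework::OpDesc>,
                      ops::RNNGradOpMaker<paddle::imperative::OpBase>,
                      RnnInferShapeFunctor);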
@@ -396,7 +396,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
             frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
       }
-      phi::funcs::detail::forward_reset_output(
+      phi::funcs::detail::forward_reset_output<DeviceContext>(
           phi::funcs::detail::forward::gru_resetOutput<T>(), gru_value,
           frame_size, cur_batch_size, active_gate);
@@ -408,7 +408,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
             frame_size * 3);
       }
-      phi::funcs::detail::forward_final_output(
+      phi::funcs::detail::forward_final_output<DeviceContext>(
          phi::funcs::detail::forward::gru_finalOutput<T>(), gru_value,
          frame_size, cur_batch_size, active_node, origin_mode);
......
@@ -12,11 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/rnn_op.h"
 #include <memory>
 #include <string>
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/multiary.h"
 namespace paddle {
 namespace operators {
@@ -25,69 +27,6 @@ class RNNOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "RNN");
OP_INOUT_CHECK(ctx->HasInputs("PreState"), "Input", "PreState", "RNN");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "RNN");
OP_INOUT_CHECK(ctx->HasOutputs("State"), "Output", "State", "RNN");
auto in_dims = ctx->GetInputDim("Input");
auto pre_state_dims = ctx->GetInputsDim("PreState");
PADDLE_ENFORCE_EQ(in_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of Input in RNN must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
if (ctx->HasInput("SequenceLength")) {
auto seq_dims = ctx->GetInputDim("SequenceLength");
PADDLE_ENFORCE_EQ(
in_dims[1], seq_dims[0],
platform::errors::InvalidArgument(
"The size of SequenceLength has to equal the batch_size. But "
"received batch_size is %d and the size of SequenceLength is %d.",
in_dims[1], seq_dims[0]));
}
PADDLE_ENFORCE_EQ(pre_state_dims[0].size(), 3,
platform::errors::InvalidArgument(
"The rank of PreState in RNN must be 3. But "
"the received rank is %d.",
pre_state_dims[0].size()));
size_t i = 0;
for (; i < pre_state_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(
in_dims[1], pre_state_dims[i][1],
platform::errors::InvalidArgument(
"The second dimension size (representing for batch size) of "
"Input and PreState should be equal. But received %d and %d.",
in_dims[1], pre_state_dims[i][1]));
PADDLE_ENFORCE_EQ(
pre_state_dims[0], pre_state_dims[i],
platform::errors::InvalidArgument(
"The dims of all tensors in PreState should be same. But "
"received PreState[0] is %s and PreState[%d] is %s.",
pre_state_dims[0], i, pre_state_dims[i]));
}
auto mode = ctx->Attrs().Get<std::string>("mode");
size_t num_state = mode == "LSTM" ? 2 : 1;
PADDLE_ENFORCE_EQ(
i, num_state,
platform::errors::InvalidArgument(
"The number of tensors in PreState of %s should be %d, "
"but received %d.",
mode, 2, i));
auto out_dims = in_dims;
auto hidden_size = ctx->Attrs().Get<int>("hidden_size");
bool is_bidirec = ctx->Attrs().Get<bool>("is_bidirec");
out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputsDim("State", pre_state_dims);
}
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -249,15 +188,11 @@ class NotImpleKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(rnn, RnnInferShapeFunctor,
+                            PD_INFER_META(phi::RnnInferMeta));
 REGISTER_OPERATOR(rnn, ops::RNNOp, ops::RNNOpMaker,
                   ops::RNNGradOpMaker<paddle::framework::OpDesc>,
-                  ops::RNNGradOpMaker<paddle::imperative::OpBase>);
+                  ops::RNNGradOpMaker<paddle::imperative::OpBase>,
+                  RnnInferShapeFunctor);
 REGISTER_OPERATOR(rnn_grad, ops::RNNGradOp);
-REGISTER_OP_CPU_KERNEL(
-    rnn, ops::RNNCPUKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::RNNCPUKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    rnn_grad, ops::RNNCPUGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::RNNCPUGradKernel<paddle::platform::CPUDeviceContext, double>);
(Two file diffs in this commit are collapsed and not shown.)
@@ -647,7 +647,6 @@ void BindImperative(py::module *m_ptr) {
             } else {
               act_name = name.cast<std::string>();
             }
-            VLOG(4) << "Init VarBase :" << act_name;
             new (&self) imperative::VarBase(act_name);
             self.SetPersistable(persistable);
             self.SetType(type);
......
@@ -1082,6 +1082,91 @@ void PsroiPoolInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }
void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list,
paddle::optional<const MetaTensor&> sequence_length,
float dropout_prob,
bool is_bidirec,
int input_size,
int hidden_size,
int num_layers,
const std::string& mode,
int seed,
bool is_test,
MetaTensor* out,
MetaTensor* dropout_state,
std::vector<MetaTensor*> state,
MetaTensor* reserve) {
auto in_dims = x.dims();
PADDLE_ENFORCE_EQ(
in_dims.size(),
3,
phi::errors::InvalidArgument("The rank of Input in RNN must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
if (sequence_length) {
auto seq_dims = sequence_length->dims();
PADDLE_ENFORCE_EQ(
in_dims[1],
seq_dims[0],
phi::errors::InvalidArgument(
"The size of SequenceLength has to equal the batch_size. But "
"received batch_size is %d and the size of SequenceLength is %d.",
in_dims[1],
seq_dims[0]));
}
PADDLE_ENFORCE_EQ(pre_state[0]->dims().size(),
3,
phi::errors::InvalidArgument(
"The rank of PreState in RNN must be 3. But "
"the received rank is %d.",
pre_state[0]->dims().size()));
size_t i = 0;
for (; i < pre_state.size(); ++i) {
PADDLE_ENFORCE_EQ(
in_dims[1],
pre_state[i]->dims()[1],
phi::errors::InvalidArgument(
"The second dimension size (representing for batch size) of "
"Input and PreState should be equal. But received %d and %d.",
in_dims[1],
pre_state[i]->dims()[1]));
PADDLE_ENFORCE_EQ(
pre_state[0]->dims(),
pre_state[i]->dims(),
phi::errors::InvalidArgument(
"The dims of all tensors in PreState should be same. But "
"received PreState[0] is %s and PreState[%d] is %s.",
pre_state[0]->dims(),
i,
pre_state[i]->dims()));
}
size_t num_state = mode == "LSTM" ? 2 : 1;
PADDLE_ENFORCE_EQ(i,
num_state,
phi::errors::InvalidArgument(
"The number of tensors in PreState of %s should be %d, "
"but received %d.",
mode,
2,
i));
auto out_dims = in_dims;
out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
out->set_dims(out_dims);
out->set_dtype(x.dtype());
int state_num = pre_state.size();
for (int i = 0; i < state_num; ++i) {
state[i]->set_dims(pre_state[i]->dims());
state[i]->set_dtype(x.dtype());
}
}
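As a quick sanity check of the output-shape rule above (illustrative only, not part of the commit; hidden_size and is_bidirec are hypothetical values): Out keeps the first two dimensions of Input, and its last dimension is doubled when the network is bidirectional.

    #include <cassert>

    int main() {
      const int hidden_size = 16;    // hypothetical
      const bool is_bidirec = true;  // hypothetical
      // Mirrors: out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
      const int out_feature = is_bidirec ? hidden_size * 2 : hidden_size;
      assert(out_feature == 32);  // Input [seq_len, batch, in] -> Out [seq_len, batch, 32]
      return 0;
    }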
 void WarpctcInferMeta(const MetaTensor& logits,
                       const MetaTensor& label,
                       const paddle::optional<const MetaTensor&> logits_length,
......
@@ -214,6 +214,23 @@ void PsroiPoolInferMeta(const MetaTensor& x,
                         float spatial_scale,
                         MetaTensor* out);
void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list,
paddle::optional<const MetaTensor&> sequence_length,
float dropout_prob,
bool is_bidirec,
int input_size,
int hidden_size,
int num_layers,
const std::string& mode,
int seed,
bool is_test,
MetaTensor* out,
MetaTensor* dropout_state,
std::vector<MetaTensor*> state,
MetaTensor* reserve);
 void WarpctcInferMeta(const MetaTensor& logits,
                       const MetaTensor& label,
                       const paddle::optional<const MetaTensor&> logits_length,
......
@@ -32,7 +32,7 @@ set(MANUAL_BUILD_KERNELS adam_kernel adamw_kernel deformable_conv_kernel deforma
     matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel
     put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel
     softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel
-    triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel warpctc_kernel warpctc_grad_kernel)
+    triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel rnn_kernel rnn_grad_kernel warpctc_kernel warpctc_grad_kernel)
 kernel_library(adam_kernel DEPS gflags glog flags ${COMMON_KERNEL_DEPS} selected_rows_functor threadpool jit_kernel_helper)
 kernel_library(adamw_kernel DEPS ${COMMON_KERNEL_DEPS} adam_kernel)
 kernel_library(deformable_conv_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor)
@@ -58,6 +58,8 @@ kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_
 kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
 kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce)
 kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)
+kernel_library(rnn_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
+kernel_library(rnn_grad_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
 kernel_library(warpctc_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale)
 kernel_library(warpctc_grad_kernel DEPS ${COMMON_KERNEL_DEPS} sequence_padding sequence_scale)
@@ -73,5 +75,5 @@ copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
 # For strings kernels
 add_subdirectory(strings)
 # 5. kernel autotune
 add_subdirectory(autotune)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/utils.h"
namespace phi {
#define DEFINE_MODE_DETECTOR(MODE_NAME, MODE_STR) \
inline bool is_##MODE_NAME(const std::string& mode) { \
return mode == #MODE_STR; \
}
DEFINE_MODE_DETECTOR(lstm, LSTM);
DEFINE_MODE_DETECTOR(gru, GRU);
DEFINE_MODE_DETECTOR(rnn_relu, RNN_RELU);
DEFINE_MODE_DETECTOR(rnn_tanh, RNN_TANH);
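For clarity: DEFINE_MODE_DETECTOR only stamps out string-comparison helpers; written out by hand, the first two invocations above expand to the equivalent of the stand-alone sketch below (shown for illustration, not part of the committed file).

    #include <string>

    inline bool is_lstm(const std::string& mode) { return mode == "LSTM"; }
    inline bool is_gru(const std::string& mode) { return mode == "GRU"; }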
inline void SwapPoniter(DenseTensor** a, DenseTensor** b) {
DenseTensor* c = *a;
*a = *b;
*b = c;
}
template <typename T>
void CreateMaskMatrix(const CPUContext& dev_ctx,
const DenseTensor* sequence_length,
DenseTensor* mask_matrix,
const bool& is_reverse,
int* min_seq_len) {
const auto& seq_len_vec =
paddle::operators::GetDataFromTensor<int>(sequence_length);
const int table_width = mask_matrix->dims()[0];
DenseTensor temp =
Empty<T>(dev_ctx, {mask_matrix->dims()[1], mask_matrix->dims()[0]});
T* data_temp = temp.data<T>();
std::fill(data_temp, data_temp + mask_matrix->numel(), static_cast<T>(1.0));
*min_seq_len = table_width;
for (unsigned int i = 0; i < seq_len_vec.size(); i++) {
// reset the mask matrix
*min_seq_len = std::min(seq_len_vec[i], *min_seq_len);
if (seq_len_vec[i] == table_width) {
continue;
}
if (is_reverse) {
std::fill(data_temp + i * table_width,
data_temp + (i + 1) * table_width - seq_len_vec[i],
static_cast<T>(0));
} else {
std::fill(data_temp + i * table_width + seq_len_vec[i],
data_temp + (i + 1) * table_width,
static_cast<T>(0));
}
}
dev_ctx.Alloc<T>(mask_matrix);
std::vector<int> trans_vec;
trans_vec.emplace_back(1);
trans_vec.emplace_back(0);
funcs::TransCompute<CPUContext, T>(2, dev_ctx, temp, mask_matrix, trans_vec);
}
template <typename TensorType>
void ResetParameterVector(const std::vector<TensorType>& raw_params_vec,
int num_layers,
int gate_num,
bool is_bidirec,
std::vector<std::vector<DenseTensor>>* params_vec) {
  // the parameter raw sequence is [FWhi, FWhh, BWhi, BWhh] * num_layers
// + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to
// ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers
const int& direction_num = is_bidirec ? 2 : 1;
const int& layer_weight_size = 4 * direction_num;
const int& all_weight_size = num_layers * layer_weight_size;
const int& bias_start_idx = all_weight_size / 2;
for (int i = 0; i < num_layers; i++) {
std::vector<DenseTensor> tensor_list;
tensor_list.reserve(layer_weight_size);
for (int j = 0; j < layer_weight_size; j++) {
DenseTensor tensor_holder;
tensor_list.emplace_back(tensor_holder);
}
for (int j = 0; j < layer_weight_size; j++) {
int k = j % 4;
const int& section = j / 4;
int tensor_idx = i * 2 * direction_num + section * 2 + k % 2;
if (k >= 2) {
tensor_idx += bias_start_idx;
}
tensor_list[j].ShareDataWith(*raw_params_vec[tensor_idx]);
}
params_vec->emplace_back(tensor_list);
}
}
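To make the layout comment at the top of ResetParameterVector concrete, the stand-alone sketch below (illustrative only, not part of the committed file) replays its index arithmetic for a single bidirectional layer and prints how the raw order [FWhi, FWhh, BWhi, BWhh, FBhi, FBhh, BBhi, BBhh] is regrouped into [FWhi, FWhh, FBhi, FBhh, BWhi, BWhh, BBhi, BBhh]:

    #include <cstdio>

    int main() {
      const char* raw[] = {"FWhi", "FWhh", "BWhi", "BWhh",
                           "FBhi", "FBhh", "BBhi", "BBhh"};
      const int num_layers = 1, direction_num = 2;      // one bidirectional layer
      const int layer_weight_size = 4 * direction_num;  // 8 tensors per layer
      const int bias_start_idx = num_layers * layer_weight_size / 2;
      for (int i = 0; i < num_layers; ++i) {
        for (int j = 0; j < layer_weight_size; ++j) {
          const int k = j % 4, section = j / 4;
          int tensor_idx = i * 2 * direction_num + section * 2 + k % 2;
          if (k >= 2) tensor_idx += bias_start_idx;  // biases live in the second half
          std::printf("params[%d][%d] = raw[%d] (%s)\n", i, j, tensor_idx,
                      raw[tensor_idx]);
        }
      }
      return 0;
    }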
template <typename T>
void DropoutHelper(const CPUContext& dev_ctx,
DenseTensor* x,
DenseTensor* y,
const DenseTensor* mask,
float dropout_prob) {
auto& place = *dev_ctx.eigen_device();
auto dropout_mask = EigenVector<uint8_t>::Flatten(*mask);
auto in = EigenVector<T>::Flatten(*x);
auto out = EigenVector<T>::Flatten(*y);
if (dropout_prob == 1.0f) {
out.device(place) = static_cast<T>(0) * in;
} else {
out.device(place) =
in * dropout_mask.cast<T>() / static_cast<T>(1.0f - dropout_prob);
}
}
template <typename T>
void DropoutCpuFunctionInplace(const CPUContext& dev_ctx,
DenseTensor* x,
DenseTensor* y,
DenseTensor* mask,
const float& dropout_prob,
const int& seed_number,
bool is_test,
bool* is_has_reset) {
if (is_test) {
return;
}
size_t size = phi::product(x->dims());
auto* mask_data = mask->data<uint8_t>();
if (!(*is_has_reset)) {
// Special case when dropout_prob is 1.0
if (dropout_prob == 1.0f) {
std::fill(mask_data, mask_data + size, static_cast<uint8_t>(0));
} else {
auto engine = paddle::framework::GetCPURandomEngine(seed_number);
std::uniform_real_distribution<float> dist(0, 1);
for (size_t i = 0; i < size; ++i) {
if (dist(*engine) < dropout_prob) {
mask_data[i] = 0;
} else {
mask_data[i] = 1;
}
}
}
*is_has_reset = true;
}
DropoutHelper<T>(dev_ctx, x, y, mask, dropout_prob);
}
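The two dropout helpers above implement inverted dropout between stacked RNN layers: units kept by the mask are rescaled by 1/(1 - dropout_prob) so the expected activation is preserved, dropped units become zero, and dropout_prob == 1.0 zeroes everything. A minimal numeric check of the scaling rule (illustrative only, not part of the committed file):

    #include <cassert>
    #include <cmath>

    int main() {
      const float p = 0.25f;  // hypothetical dropout probability
      const float x = 2.0f;   // hypothetical activation
      const float kept = x * 1.0f / (1.0f - p);     // mask == 1
      const float dropped = x * 0.0f / (1.0f - p);  // mask == 0
      assert(std::fabs(kept * (1.0f - p) - x) < 1e-6f);  // rescaling undoes the (1 - p) shrinkage
      assert(dropped == 0.0f);
      return 0;
    }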
template <typename Context, typename TensorType>
void SplitReserveData(const Context& dev_ctx,
int direction_num,
int time_step,
int batch_size,
int hidden_size,
int gate_num,
int num_layers,
const std::string& mode,
TensorType* reserve_data,
DenseTensor* gate_data,
DenseTensor* cell_data,
DenseTensor* cell_act_data,
DenseTensor* hidden_data) {
int gate_data_idx = gate_num * num_layers;
int cell_data_idx = (gate_num + 1) * num_layers;
int cell_act_data_idx = (gate_num + 2) * num_layers;
// simple rnn
int hidden_data_start_idx = gate_data_idx;
*gate_data = reserve_data->Slice(0, gate_data_idx);
if (is_lstm(mode)) {
*cell_data = reserve_data->Slice(gate_data_idx, cell_data_idx);
*cell_act_data = reserve_data->Slice(cell_data_idx, cell_act_data_idx);
hidden_data_start_idx = cell_act_data_idx;
} else if (is_gru(mode)) {
*cell_data = reserve_data->Slice(gate_data_idx, cell_data_idx);
hidden_data_start_idx = cell_data_idx;
}
int hidden_data_idx = hidden_data_start_idx + (num_layers - 1);
if (num_layers > 1) {
*hidden_data = reserve_data->Slice(hidden_data_start_idx, hidden_data_idx);
}
}
template <typename CellType, typename T, typename Context>
void AllocateReserveData(const Context& dev_ctx,
bool is_bidirec,
int num_layers,
int gate_num,
int hidden_size,
const std::string& mode,
DenseTensor* reserve_data,
DenseTensor* gate_data,
DenseTensor* cell_data,
DenseTensor* cell_act_data,
DenseTensor* hidden_data,
const DenseTensor* input) {
int direction_num = is_bidirec ? 2 : 1;
int time_step = input->dims()[0];
int batch_size = input->dims()[1];
int block_size = direction_num * time_step * batch_size * hidden_size;
int hidden_data_idx = (num_layers - 1);
if (is_lstm(mode)) {
hidden_data_idx += (gate_num + 2) * num_layers;
} else if (is_gru(mode)) {
hidden_data_idx += (gate_num + 1) * num_layers;
} else {
hidden_data_idx += gate_num * num_layers;
}
reserve_data->Resize({hidden_data_idx, block_size});
dev_ctx.template Alloc<T>(reserve_data);
SplitReserveData(dev_ctx,
direction_num,
time_step,
batch_size,
hidden_size,
gate_num,
num_layers,
mode,
reserve_data,
gate_data,
cell_data,
cell_act_data,
hidden_data);
}
inline std::vector<DenseTensor> Unbind(const DenseTensor& in) {
int64_t size = in.dims()[0];
std::vector<DenseTensor> tensors;
tensors.reserve(size);
for (int64_t i = 0; i < size; ++i) {
tensors.emplace_back(in.Slice(i, i + 1));
}
return tensors;
}
template <typename CellType,
template <typename, typename> class LayerT,
template <typename, typename> class SingleLayerT,
template <typename, typename> class BidirLayerT,
typename T,
typename Context>
void RnnFunc(const Context& dev_ctx,
const DenseTensor* input,
const std::vector<const DenseTensor*>& weight_list,
const DenseTensor* init_h,
const DenseTensor* init_c,
const DenseTensor* sequence_length,
DenseTensor* last_h,
DenseTensor* last_c,
DenseTensor* output,
DenseTensor* dropout_mask,
int num_layers,
int gate_num,
int input_size,
int hidden_size,
bool is_bidirec,
const std::string& cell_type,
float dropout_prob,
bool is_test,
int seed,
DenseTensor* reserve_data) {
int direction_num = is_bidirec ? 2 : 1;
const auto& init_h_dims = init_h->dims();
PADDLE_ENFORCE_EQ(init_h_dims[0],
num_layers * direction_num,
phi::errors::InvalidArgument(
"The num_layers of in RNN layer must be the same as "
"first dim of init hidden, but received"
" num_layers:%d, dim:%d",
num_layers,
init_h_dims[0]));
if (is_lstm(cell_type)) {
const auto& init_c_dims = init_c->dims();
PADDLE_ENFORCE_EQ(init_c_dims[0],
num_layers * direction_num,
phi::errors::InvalidArgument(
"The num_layers of in RNN layer must be the same as "
"first dim of cell state hidden, but received"
" num_layers:%d, dim:%d",
num_layers,
init_h_dims[0]));
}
CellType cell;
std::vector<std::vector<DenseTensor>> parameter_lists;
parameter_lists.reserve(num_layers);
ResetParameterVector(
weight_list, num_layers, gate_num, is_bidirec, &parameter_lists);
DenseTensor gate_data, cell_data, cell_act_data, hidden_data;
if (!is_test) {
AllocateReserveData<CellType, T, Context>(dev_ctx,
is_bidirec,
num_layers,
gate_num,
hidden_size,
cell_type,
reserve_data,
&gate_data,
&cell_data,
&cell_act_data,
&hidden_data,
input);
gate_data.Resize({num_layers, gate_data.numel() / num_layers});
cell_data.Resize({num_layers, cell_data.numel() / num_layers});
cell_act_data.Resize({num_layers, cell_act_data.numel() / num_layers});
if (num_layers > 1) {
hidden_data.Resize(
{num_layers - 1, hidden_data.numel() / (num_layers - 1)});
}
}
DenseTensor* input_holder;
DenseTensor* output_holder = output;
bool has_allocate_mem = false;
auto init_h_unbind = Unbind(*init_h);
auto last_h_unbind = Unbind(*last_h);
std::vector<DenseTensor> init_c_unbind, last_c_unbind;
if (is_lstm(cell_type)) {
init_c_unbind = Unbind(*init_c);
last_c_unbind = Unbind(*last_c);
}
DenseTensor curr_gate_data, curr_cell_data, curr_cell_act_data;
DenseTensor curr_hidden_data, prev_hidden_data;
DenseTensor temp;
bool has_dropout_reset = false;
for (int i = 0; i < num_layers; i++) {
if (!is_test) {
if (cell_data.numel() > 0) /** for lstm, gru **/ {
curr_cell_data = cell_data.Slice(i, i + 1);
}
if (cell_act_data.numel() > 0) /*for lstm*/ {
curr_cell_act_data = cell_act_data.Slice(i, i + 1);
}
curr_gate_data = gate_data.Slice(i, i + 1);
output_holder = output;
if (i < num_layers - 1 && num_layers > 1) {
curr_hidden_data = hidden_data.Slice(i, i + 1);
curr_hidden_data.Resize(output->dims());
output_holder = &curr_hidden_data;
}
}
if (i > 0) {
if (!has_allocate_mem) {
temp.Resize(output->dims());
dev_ctx.template Alloc<T>(&temp);
input_holder = &temp;
has_allocate_mem = true;
}
if (!is_test) {
prev_hidden_data = hidden_data.Slice(i - 1, i);
input_holder->Resize(output->dims());
if (dropout_prob != 0) {
DropoutCpuFunctionInplace<T>(dev_ctx,
&prev_hidden_data,
input_holder,
dropout_mask,
dropout_prob,
seed,
is_test,
&has_dropout_reset);
} else {
input_holder = &prev_hidden_data;
input_holder->Resize(output->dims());
}
} else {
SwapPoniter(&output_holder, &input_holder);
}
}
const DenseTensor* input_temp_holder = input;
if (i > 0) {
input_temp_holder = input_holder;
}
LayerT<T, CellType>* layer;
SingleLayerT<T, CellType> slayer(cell);
BidirLayerT<T, CellType> blayer(cell);
if (is_bidirec) {
layer = &blayer;
} else {
layer = &slayer;
}
(*layer)(dev_ctx,
input_temp_holder,
parameter_lists[i],
init_h_unbind,
init_c_unbind,
sequence_length,
last_h_unbind,
last_c_unbind,
output_holder,
i,
gate_num,
&curr_gate_data,
&curr_cell_data,
&curr_cell_act_data,
cell_type,
is_test);
}
if (num_layers % 2 == 0) {
Copy(dev_ctx, *output_holder, dev_ctx.GetPlace(), false, output);
}
}
} // namespace phi
(Two file diffs in this commit are collapsed and not shown.)
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once
 #include <type_traits>
 #include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/operators/activation_op.h"
+#include "paddle/phi/kernels/funcs/activation_functor.h"
 #include "paddle/phi/kernels/funcs/detail/activation_functions.h"
 #include "paddle/phi/kernels/funcs/gru_compute.h"
@@ -283,11 +283,10 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
 #endif
 }
-template <typename T>
-inline void forward_reset_outputV2(
-    const paddle::platform::CPUDeviceContext &context,
-    phi::funcs::GRUMetaValue<T> value,
-    int frame_size) {
+template <typename T, typename Context>
+inline void forward_reset_outputV2(const Context &context,
+                                   phi::funcs::GRUMetaValue<T> value,
+                                   int frame_size) {
   auto &place = *context.eigen_device();
   auto value_reset_gate =
       typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
@@ -297,23 +296,20 @@ inline void forward_reset_outputV2(
       value.reset_output_value, Array1(frame_size));
   auto value_reset_bias =
       typename EigenVector<T>::ConstType(value.reset_bias, Array1(frame_size));
-  paddle::operators::SigmoidFunctor<T>()(
-      place, value_reset_gate, value_reset_gate);
-  paddle::operators::SigmoidFunctor<T>()(
-      place, value_update_gate, value_update_gate);
+  SigmoidFunctor<T>()(place, value_reset_gate, value_reset_gate);
+  SigmoidFunctor<T>()(place, value_update_gate, value_update_gate);
   value_reset_output.device(place) =
       (value_reset_output + value_reset_bias) * value_reset_gate;
 }
-template <class OpResetOutput, typename T>
-inline void forward_reset_output(
-    OpResetOutput op_reset_output,
-    phi::funcs::GRUMetaValue<T> value,
-    int frame_size,
-    int batch_size,
-    ActivationType active_gate,
-    bool old_version = true,
-    const paddle::platform::CPUDeviceContext *context = nullptr) {
+template <typename Context, class OpResetOutput, typename T>
+inline void forward_reset_output(OpResetOutput op_reset_output,
+                                 phi::funcs::GRUMetaValue<T> value,
+                                 int frame_size,
+                                 int batch_size,
+                                 ActivationType active_gate,
+                                 bool old_version = true,
+                                 const Context *context = nullptr) {
   for (int b = 0; b < batch_size; b++) {
     if (!old_version) {
       // use eigen
@@ -348,11 +344,10 @@ inline void forward_reset_output(
   }
 }
-template <typename T>
-inline void forward_final_outputV2(
-    const paddle::platform::CPUDeviceContext &context,
-    phi::funcs::GRUMetaValue<T> value,
-    int frame_size) {
+template <typename T, typename Context>
+inline void forward_final_outputV2(const Context &context,
+                                   phi::funcs::GRUMetaValue<T> value,
+                                   int frame_size) {
   auto &place = *context.eigen_device();
   auto value_update_gate = typename EigenVector<T>::Type(
       value.gate_value + frame_size, Array1(frame_size));
@@ -360,8 +355,7 @@ inline void forward_final_outputV2(
       value.gate_value + 2 * frame_size, Array1(frame_size));
   auto value_output =
       typename EigenVector<T>::Type(value.output_value, Array1(frame_size));
-  paddle::operators::TanhFunctor<T>()(
-      place, value_frame_state, value_frame_state);
+  TanhFunctor<T>()(place, value_frame_state, value_frame_state);
   value_output.device(place) =
       (static_cast<T>(1.0) - value_update_gate) * value_frame_state;
   if (value.prev_out_value) {
@@ -372,16 +366,15 @@ inline void forward_final_outputV2(
   }
 }
-template <class OpFinalOutput, typename T>
-inline void forward_final_output(
-    OpFinalOutput op_final_output,
-    phi::funcs::GRUMetaValue<T> value,
-    int frame_size,
-    int batch_size,
-    ActivationType active_node,
-    bool origin_mode,
-    bool old_version = true,
-    const paddle::platform::CPUDeviceContext *context = nullptr) {
+template <typename Context, class OpFinalOutput, typename T>
+inline void forward_final_output(OpFinalOutput op_final_output,
+                                 phi::funcs::GRUMetaValue<T> value,
+                                 int frame_size,
+                                 int batch_size,
+                                 ActivationType active_node,
+                                 bool origin_mode,
+                                 bool old_version = true,
+                                 const Context *context = nullptr) {
   for (int b = 0; b < batch_size; b++) {
     if (!old_version) {
       // eigen
@@ -871,8 +864,8 @@ inline void backward_reset_grad(OpResetGrad op_reset_grad,
   }
 }
-template <typename T>
-inline void gru_backward(const paddle::platform::CPUDeviceContext &context,
+template <typename T, typename Context>
+inline void gru_backward(const Context &context,
                          phi::funcs::GRUMetaValue<T> value,
                          phi::funcs::GRUMetaGrad<T> grad,
                          int frame_size) {
@@ -901,14 +894,13 @@ inline void gru_backward(const paddle::platform::CPUDeviceContext &context,
   if (value.prev_out_value) {
     auto value_prev_out = typename EigenVector<T>::ConstType(
         value.prev_out_value, Array1(frame_size));
-    paddle::operators::SigmoidGradFunctor<T>()(
-        place,
-        1 /*useless*/,
-        value_update_gate,
-        (value_prev_out - value_frame_state) * grad_output,
-        grad_update_gate);
+    SigmoidGradFunctor<T>()(place,
+                            1 /*useless*/,
+                            value_update_gate,
+                            (value_prev_out - value_frame_state) * grad_output,
+                            grad_update_gate);
   } else {
-    paddle::operators::SigmoidGradFunctor<T>()(
+    SigmoidGradFunctor<T>()(
         place,
         1 /*useless*/,
        value_update_gate,
@@ -921,13 +913,12 @@ inline void gru_backward(const paddle::platform::CPUDeviceContext &context,
     grad_prev_out.device(place) =
         grad_prev_out + grad_output * value_update_gate;
   }
-  paddle::operators::TanhGradFunctor<T>()(
-      place,
-      1 /*useless*/,
-      value_frame_state,
-      grad_output * (static_cast<T>(1.0) - value_update_gate),
-      grad_frame_state);
-  paddle::operators::SigmoidGradFunctor<T>()(
+  TanhGradFunctor<T>()(place,
+                       1 /*useless*/,
+                       value_frame_state,
+                       grad_output * (static_cast<T>(1.0) - value_update_gate),
+                       grad_frame_state);
+  SigmoidGradFunctor<T>()(
       place,
       1 /*useless*/,
       value_reset_gate,
@@ -938,8 +929,8 @@ inline void gru_backward(const paddle::platform::CPUDeviceContext &context,
   }
 }
-template <class OpGruGrad, typename T>
-inline void cpu_gru_backward(const paddle::platform::CPUDeviceContext &context,
+template <class OpGruGrad, typename T, typename Context>
+inline void cpu_gru_backward(const Context &context,
                              OpGruGrad op_gru_grad,
                              phi::funcs::GRUMetaValue<T> value,
                              phi::funcs::GRUMetaGrad<T> grad,
......
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once
 #include <type_traits>
 #include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/operators/activation_op.h"
+#include "paddle/phi/kernels/funcs/activation_functor.h"
 #include "paddle/phi/kernels/funcs/detail/activation_functions.h"
 #include "paddle/phi/kernels/funcs/lstm_compute.h"
@@ -409,11 +409,10 @@ void avx_lstm_backward_one_sequence(Op op,
 #endif
 }
-template <class T>
-void eigen_lstm_forward_one_sequence(
-    const paddle::platform::CPUDeviceContext &context,
-    phi::funcs::LstmMetaValue<T> value,
-    int frame_size) {
+template <class T, class Context>
+void eigen_lstm_forward_one_sequence(const Context &context,
+                                     phi::funcs::LstmMetaValue<T> value,
+                                     int frame_size) {
   auto eigen_value_ig =
       typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
   auto eigen_value_fg = typename EigenVector<T>::Type(
@@ -430,10 +429,10 @@ void eigen_lstm_forward_one_sequence(
       typename EigenVector<T>::Type(value.output_value, Array1(frame_size));
   auto &place = *context.eigen_device();
-  paddle::operators::TanhFunctor<T>()(place, eigen_value_in, eigen_value_in);
-  paddle::operators::SigmoidFunctor<T>()(place, eigen_value_ig, eigen_value_ig);
-  paddle::operators::SigmoidFunctor<T>()(place, eigen_value_fg, eigen_value_fg);
-  paddle::operators::SigmoidFunctor<T>()(place, eigen_value_og, eigen_value_og);
+  TanhFunctor<T>()(place, eigen_value_in, eigen_value_in);
+  SigmoidFunctor<T>()(place, eigen_value_ig, eigen_value_ig);
+  SigmoidFunctor<T>()(place, eigen_value_fg, eigen_value_fg);
+  SigmoidFunctor<T>()(place, eigen_value_og, eigen_value_og);
   eigen_state.device(place) = eigen_value_in * eigen_value_ig;
   if (value.prev_state_value) {
@@ -442,16 +441,15 @@ void eigen_lstm_forward_one_sequence(
     eigen_state.device(place) = eigen_state + eigen_prev_state * eigen_value_fg;
   }
-  paddle::operators::TanhFunctor<T>()(place, eigen_state, eigen_state_act);
+  TanhFunctor<T>()(place, eigen_state, eigen_state_act);
   eigen_output.device(place) = eigen_value_og * eigen_state_act;
 }
-template <class T>
-void eigen_lstm_backward_one_sequence(
-    const paddle::platform::CPUDeviceContext &context,
-    phi::funcs::LstmMetaValue<T> value,
-    phi::funcs::LstmMetaGrad<T> grad,
-    int frame_size) {
+template <class T, class Context>
+void eigen_lstm_backward_one_sequence(const Context &context,
+                                      phi::funcs::LstmMetaValue<T> value,
+                                      phi::funcs::LstmMetaGrad<T> grad,
+                                      int frame_size) {
   auto eigen_value_ig =
       typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
   auto eigen_value_fg = typename EigenVector<T>::Type(
@@ -477,38 +475,35 @@ void eigen_lstm_backward_one_sequence(
       typename EigenVector<T>::Type(grad.state_grad, Array1(frame_size));
   auto &place = *context.eigen_device();
-  paddle::operators::SigmoidGradFunctor<T>()(
-      place,
-      1 /*useless*/,
-      eigen_value_og,
-      eigen_grad_output * eigen_state_act,
-      eigen_grad_og);
+  SigmoidGradFunctor<T>()(place,
+                          1 /*useless*/,
+                          eigen_value_og,
+                          eigen_grad_output * eigen_state_act,
+                          eigen_grad_og);
   eigen_grad_state.device(place) =
       eigen_grad_state +
      eigen_grad_output * eigen_value_og *
          (static_cast<T>(1) - eigen_state_act * eigen_state_act);
-  paddle::operators::TanhGradFunctor<T>()(place,
-                                          1,
-                                          eigen_value_in,
-                                          eigen_grad_state * eigen_value_ig,
-                                          eigen_grad_in);
-  paddle::operators::SigmoidGradFunctor<T>()(place,
-                                             1,
-                                             eigen_value_ig,
-                                             eigen_grad_state * eigen_value_in,
-                                             eigen_grad_ig);
+  TanhGradFunctor<T>()(place,
+                       1,
+                       eigen_value_in,
+                       eigen_grad_state * eigen_value_ig,
+                       eigen_grad_in);
+  SigmoidGradFunctor<T>()(place,
+                          1,
+                          eigen_value_ig,
+                          eigen_grad_state * eigen_value_in,
+                          eigen_grad_ig);
   if (value.prev_state_value) {
     auto eigen_prev_state = typename EigenVector<T>::ConstType(
         value.prev_state_value, Array1(frame_size));
-    paddle::operators::SigmoidGradFunctor<T>()(
-        place,
-        1,
-        eigen_value_fg,
-        eigen_grad_state * eigen_prev_state,
-        eigen_grad_fg);
+    SigmoidGradFunctor<T>()(place,
+                            1,
+                            eigen_value_fg,
+                            eigen_grad_state * eigen_prev_state,
+                            eigen_grad_fg);
   } else {
-    paddle::operators::SigmoidGradFunctor<T>()(
-        place, 1, eigen_value_fg, 0, eigen_grad_fg);
+    SigmoidGradFunctor<T>()(place, 1, eigen_value_fg, 0, eigen_grad_fg);
   }
   if (grad.prev_state_grad) {
     auto eigen_grad_pre_state =
@@ -517,8 +512,8 @@ void eigen_lstm_backward_one_sequence(
   }
 }
-template <class T, class Op>
-void cpu_lstm_forward(const paddle::platform::CPUDeviceContext &context,
+template <class T, class Op, class Context>
+void cpu_lstm_forward(const Context &context,
                       Op op,
                       phi::funcs::LstmMetaValue<T> value,
                       int frame_size,
@@ -552,8 +547,8 @@ void cpu_lstm_forward(const paddle::platform::CPUDeviceContext &context,
   }
 }
-template <class T, class Op>
-void cpu_lstm_backward(const paddle::platform::CPUDeviceContext &context,
+template <class T, class Op, class Context>
+void cpu_lstm_backward(const Context &context,
                        Op op,
                        phi::funcs::LstmMetaValue<T> value,
                        phi::funcs::LstmMetaGrad<T> grad,
......
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/phi/kernels/funcs/lstm_compute.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/kernels/funcs/detail/lstm_cpu_kernel.h"
 #include "paddle/phi/kernels/funcs/detail/lstm_kernel.h"
@@ -51,6 +53,38 @@ struct LstmUnitFunctor<paddle::platform::CPUDeviceContext, T> {
   }
 };
template <class T>
struct LstmUnitFunctor<CPUContext, T> {
static void compute(const CPUContext& context,
LstmMetaValue<T> value,
int frame_size,
int batch_size,
T cell_clip,
const phi::funcs::detail::ActivationType& gate_act,
const phi::funcs::detail::ActivationType& cell_act,
const phi::funcs::detail::ActivationType& cand_act,
bool old_api_version = true) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(context,
phi::funcs::detail::forward::lstm<T>(),
value,
frame_size,
cell_clip,
cand_act,
gate_act,
cell_act,
old_api_version);
value.gate_value += frame_size * 4;
value.state_value += frame_size;
value.state_active_value += frame_size;
value.output_value += frame_size;
if (value.prev_state_value) {
value.prev_state_value += frame_size;
}
}
}
};
 template <class T>
 struct LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, T> {
   static void compute(const paddle::platform::CPUDeviceContext& context,
@@ -94,10 +128,58 @@ struct LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, T> {
   }
 };
template <class T>
struct LstmUnitGradFunctor<CPUContext, T> {
static void compute(const CPUContext& context,
LstmMetaValue<T> value,
LstmMetaGrad<T> grad,
int frame_size,
int batch_size,
T cell_clip,
const phi::funcs::detail::ActivationType& gate_act,
const phi::funcs::detail::ActivationType& cell_act,
const phi::funcs::detail::ActivationType& cand_act,
bool old_api_version = true) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_backward(context,
phi::funcs::detail::backward::lstm<T>(),
value,
grad,
frame_size,
cell_clip,
cand_act,
gate_act,
cell_act,
old_api_version);
value.gate_value += frame_size * 4;
value.state_value += frame_size;
value.state_active_value += frame_size;
value.output_value += frame_size;
if (value.prev_state_value) {
value.prev_state_value += frame_size;
}
grad.gate_grad += frame_size * 4;
grad.state_grad += frame_size;
grad.state_active_grad += frame_size;
grad.output_grad += frame_size;
if (grad.prev_state_grad) {
grad.prev_state_grad += frame_size;
}
}
}
};
 template class LstmUnitFunctor<paddle::platform::CPUDeviceContext, float>;
 template class LstmUnitFunctor<paddle::platform::CPUDeviceContext, double>;
 template class LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, float>;
 template class LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, double>;
template class LstmUnitFunctor<CPUContext, float>;
template class LstmUnitFunctor<CPUContext, double>;
template class LstmUnitGradFunctor<CPUContext, float>;
template class LstmUnitGradFunctor<CPUContext, double>;
 }  // namespace funcs
 }  // namespace phi
(Six more file diffs in this commit are collapsed and not shown.)