Unverified commit 66cf8b08, authored by zyfncg, committed by GitHub

[Phi] Move Rnn Op from fluid to phi (#41007)

* move rnn kernel to phi

* move infershape of rnn to phi

* fix HIP bug

* rename function

* fix HIP bug

* fix hip bug
Parent 59c4fdac
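In outline, this commit replaces the fluid RNNOp's InferShape and its REGISTER_OP_CPU_KERNEL registrations with a phi::RnnInferMeta function (wired in via DECLARE_INFER_SHAPE_FUNCTOR, visible below) plus phi kernels registered in new kernel files. Those kernel files are collapsed in this view, so the following is only a minimal sketch, assuming the standard PD_REGISTER_KERNEL macro and a phi::RnnKernel entry point:

// Sketch only -- assumed shape of the phi-side CPU registration.
PD_REGISTER_KERNEL(
    rnn, CPU, ALL_LAYOUT, phi::RnnKernel, float, double) {}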
......@@ -396,7 +396,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
frame_size * 2, T(1), gru_value.gate_value, frame_size * 3);
}
phi::funcs::detail::forward_reset_output(
phi::funcs::detail::forward_reset_output<DeviceContext>(
phi::funcs::detail::forward::gru_resetOutput<T>(), gru_value,
frame_size, cur_batch_size, active_gate);
......@@ -408,7 +408,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
frame_size * 3);
}
phi::funcs::detail::forward_final_output(
phi::funcs::detail::forward_final_output<DeviceContext>(
phi::funcs::detail::forward::gru_finalOutput<T>(), gru_value,
frame_size, cur_batch_size, active_node, origin_mode);
......
......@@ -12,11 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/rnn_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
......@@ -25,69 +27,6 @@ class RNNOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "RNN");
OP_INOUT_CHECK(ctx->HasInputs("PreState"), "Input", "PreState", "RNN");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "RNN");
OP_INOUT_CHECK(ctx->HasOutputs("State"), "Output", "State", "RNN");
auto in_dims = ctx->GetInputDim("Input");
auto pre_state_dims = ctx->GetInputsDim("PreState");
PADDLE_ENFORCE_EQ(in_dims.size(), 3,
platform::errors::InvalidArgument(
"The rank of Input in RNN must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
if (ctx->HasInput("SequenceLength")) {
auto seq_dims = ctx->GetInputDim("SequenceLength");
PADDLE_ENFORCE_EQ(
in_dims[1], seq_dims[0],
platform::errors::InvalidArgument(
"The size of SequenceLength has to equal the batch_size. But "
"received batch_size is %d and the size of SequenceLength is %d.",
in_dims[1], seq_dims[0]));
}
PADDLE_ENFORCE_EQ(pre_state_dims[0].size(), 3,
platform::errors::InvalidArgument(
"The rank of PreState in RNN must be 3. But "
"the received rank is %d.",
pre_state_dims[0].size()));
size_t i = 0;
for (; i < pre_state_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(
in_dims[1], pre_state_dims[i][1],
platform::errors::InvalidArgument(
"The second dimension size (representing for batch size) of "
"Input and PreState should be equal. But received %d and %d.",
in_dims[1], pre_state_dims[i][1]));
PADDLE_ENFORCE_EQ(
pre_state_dims[0], pre_state_dims[i],
platform::errors::InvalidArgument(
"The dims of all tensors in PreState should be same. But "
"received PreState[0] is %s and PreState[%d] is %s.",
pre_state_dims[0], i, pre_state_dims[i]));
}
auto mode = ctx->Attrs().Get<std::string>("mode");
size_t num_state = mode == "LSTM" ? 2 : 1;
PADDLE_ENFORCE_EQ(
i, num_state,
platform::errors::InvalidArgument(
"The number of tensors in PreState of %s should be %d, "
"but received %d.",
mode, 2, i));
auto out_dims = in_dims;
auto hidden_size = ctx->Attrs().Get<int>("hidden_size");
bool is_bidirec = ctx->Attrs().Get<bool>("is_bidirec");
out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputsDim("State", pre_state_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -249,15 +188,11 @@ class NotImpleKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(rnn, RnnInferShapeFunctor,
PD_INFER_META(phi::RnnInferMeta));
REGISTER_OPERATOR(rnn, ops::RNNOp, ops::RNNOpMaker,
ops::RNNGradOpMaker<paddle::framework::OpDesc>,
ops::RNNGradOpMaker<paddle::imperative::OpBase>);
ops::RNNGradOpMaker<paddle::imperative::OpBase>,
RnnInferShapeFunctor);
REGISTER_OPERATOR(rnn_grad, ops::RNNGradOp);
REGISTER_OP_CPU_KERNEL(
rnn, ops::RNNCPUKernel<paddle::platform::CPUDeviceContext, float>,
ops::RNNCPUKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
rnn_grad, ops::RNNCPUGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::RNNCPUGradKernel<paddle::platform::CPUDeviceContext, double>);
This diff has been collapsed.
This diff has been collapsed.
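The two collapsed files above presumably contain the new phi rnn / rnn_grad CPU kernels. Judging from the RnnInferMeta signature added later in this diff, the forward kernel declaration likely looks roughly like the sketch below (inferred, not copied from the collapsed files):

// Inferred sketch of the phi forward kernel declaration.
template <typename T, typename Context>
void RnnKernel(const Context& dev_ctx,
               const DenseTensor& x,
               const std::vector<const DenseTensor*>& pre_state,
               const std::vector<const DenseTensor*>& weight_list,
               paddle::optional<const DenseTensor&> sequence_length,
               float dropout_prob,
               bool is_bidirec,
               int input_size,
               int hidden_size,
               int num_layers,
               const std::string& mode,
               int seed,
               bool is_test,
               DenseTensor* out,
               DenseTensor* dropout_state,
               std::vector<DenseTensor*> state,
               DenseTensor* reserve);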
......@@ -647,7 +647,6 @@ void BindImperative(py::module *m_ptr) {
} else {
act_name = name.cast<std::string>();
}
VLOG(4) << "Init VarBase :" << act_name;
new (&self) imperative::VarBase(act_name);
self.SetPersistable(persistable);
self.SetType(type);
......
......@@ -1082,6 +1082,91 @@ void PsroiPoolInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list,
paddle::optional<const MetaTensor&> sequence_length,
float dropout_prob,
bool is_bidirec,
int input_size,
int hidden_size,
int num_layers,
const std::string& mode,
int seed,
bool is_test,
MetaTensor* out,
MetaTensor* dropout_state,
std::vector<MetaTensor*> state,
MetaTensor* reserve) {
auto in_dims = x.dims();
PADDLE_ENFORCE_EQ(
in_dims.size(),
3,
phi::errors::InvalidArgument("The rank of Input in RNN must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
if (sequence_length) {
auto seq_dims = sequence_length->dims();
PADDLE_ENFORCE_EQ(
in_dims[1],
seq_dims[0],
phi::errors::InvalidArgument(
"The size of SequenceLength has to equal the batch_size. But "
"received batch_size is %d and the size of SequenceLength is %d.",
in_dims[1],
seq_dims[0]));
}
PADDLE_ENFORCE_EQ(pre_state[0]->dims().size(),
3,
phi::errors::InvalidArgument(
"The rank of PreState in RNN must be 3. But "
"the received rank is %d.",
pre_state[0]->dims().size()));
size_t i = 0;
for (; i < pre_state.size(); ++i) {
PADDLE_ENFORCE_EQ(
in_dims[1],
pre_state[i]->dims()[1],
phi::errors::InvalidArgument(
"The second dimension size (representing for batch size) of "
"Input and PreState should be equal. But received %d and %d.",
in_dims[1],
pre_state[i]->dims()[1]));
PADDLE_ENFORCE_EQ(
pre_state[0]->dims(),
pre_state[i]->dims(),
phi::errors::InvalidArgument(
"The dims of all tensors in PreState should be same. But "
"received PreState[0] is %s and PreState[%d] is %s.",
pre_state[0]->dims(),
i,
pre_state[i]->dims()));
}
size_t num_state = mode == "LSTM" ? 2 : 1;
PADDLE_ENFORCE_EQ(i,
num_state,
phi::errors::InvalidArgument(
"The number of tensors in PreState of %s should be %d, "
"but received %d.",
mode,
num_state,
i));
auto out_dims = in_dims;
out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
out->set_dims(out_dims);
out->set_dtype(x.dtype());
int state_num = pre_state.size();
for (int i = 0; i < state_num; ++i) {
state[i]->set_dims(pre_state[i]->dims());
state[i]->set_dtype(x.dtype());
}
}
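// Worked example (illustrative values): for a bidirectional LSTM with
// hidden_size = 64 and an input x of shape [seq_len = 10, batch = 4, 32]:
//   out   -> [10, 4, 128]  (last dim doubled because is_bidirec == true)
//   state -> one tensor per PreState entry, each keeping its PreState dims,
//            e.g. [2 * num_layers, 4, 64], with dtype copied from x.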
void WarpctcInferMeta(const MetaTensor& logits,
const MetaTensor& label,
const paddle::optional<const MetaTensor&> logits_length,
......
......@@ -214,6 +214,23 @@ void PsroiPoolInferMeta(const MetaTensor& x,
float spatial_scale,
MetaTensor* out);
void RnnInferMeta(const MetaTensor& x,
const std::vector<MetaTensor*>& pre_state,
const std::vector<MetaTensor*>& weight_list,
paddle::optional<const MetaTensor&> sequence_length,
float dropout_prob,
bool is_bidirec,
int input_size,
int hidden_size,
int num_layers,
const std::string& mode,
int seed,
bool is_test,
MetaTensor* out,
MetaTensor* dropout_state,
std::vector<MetaTensor*> state,
MetaTensor* reserve);
void WarpctcInferMeta(const MetaTensor& logits,
const MetaTensor& label,
const paddle::optional<const MetaTensor&> logits_length,
......
......@@ -32,7 +32,7 @@ set(MANUAL_BUILD_KERNELS adam_kernel adamw_kernel deformable_conv_kernel deforma
matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel
put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel
softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel
triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel warpctc_kernel warpctc_grad_kernel)
triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel rnn_kernel rnn_grad_kernel warpctc_kernel warpctc_grad_kernel)
kernel_library(adam_kernel DEPS gflags glog flags ${COMMON_KERNEL_DEPS} selected_rows_functor threadpool jit_kernel_helper)
kernel_library(adamw_kernel DEPS ${COMMON_KERNEL_DEPS} adam_kernel)
kernel_library(deformable_conv_kernel DEPS ${COMMON_KERNEL_DEPS} deformable_conv_functor)
......@@ -58,6 +58,8 @@ kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_
kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce)
kernel_library(determinant_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_inverse)
kernel_library(rnn_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
kernel_library(rnn_grad_kernel DEPS ${COMMON_KERNEL_DEPS} concat_and_split_functor lstm_compute gru_compute)
kernel_library(warpctc_kernel DEPS ${COMMON_KERNEL_DEPS} phi_dynload_warpctc sequence_padding sequence_scale)
kernel_library(warpctc_grad_kernel DEPS ${COMMON_KERNEL_DEPS} sequence_padding sequence_scale)
......@@ -73,5 +75,5 @@ copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
# For strings kernels
add_subdirectory(strings)
# 5. kernel autotune
# 5. kernel autotune
add_subdirectory(autotune)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/utils.h"
namespace phi {
#define DEFINE_MODE_DETECTOR(MODE_NAME, MODE_STR) \
inline bool is_##MODE_NAME(const std::string& mode) { \
return mode == #MODE_STR; \
}
DEFINE_MODE_DETECTOR(lstm, LSTM);
DEFINE_MODE_DETECTOR(gru, GRU);
DEFINE_MODE_DETECTOR(rnn_relu, RNN_RELU);
DEFINE_MODE_DETECTOR(rnn_tanh, RNN_TANH);
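// For example, DEFINE_MODE_DETECTOR(lstm, LSTM) expands to:
//   inline bool is_lstm(const std::string& mode) { return mode == "LSTM"; }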
inline void SwapPoniter(DenseTensor** a, DenseTensor** b) {
DenseTensor* c = *a;
*a = *b;
*b = c;
}
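// SwapPoniter is used below in RnnFunc (inference path) to ping-pong the
// input/output buffers between stacked layers, so two buffers are reused
// instead of allocating one per layer.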
template <typename T>
void CreateMaskMatrix(const CPUContext& dev_ctx,
const DenseTensor* sequence_length,
DenseTensor* mask_matrix,
const bool& is_reverse,
int* min_seq_len) {
const auto& seq_len_vec =
paddle::operators::GetDataFromTensor<int>(sequence_length);
const int table_width = mask_matrix->dims()[0];
DenseTensor temp =
Empty<T>(dev_ctx, {mask_matrix->dims()[1], mask_matrix->dims()[0]});
T* data_temp = temp.data<T>();
std::fill(data_temp, data_temp + mask_matrix->numel(), static_cast<T>(1.0));
*min_seq_len = table_width;
for (unsigned int i = 0; i < seq_len_vec.size(); i++) {
// reset the mask matrix
*min_seq_len = std::min(seq_len_vec[i], *min_seq_len);
if (seq_len_vec[i] == table_width) {
continue;
}
if (is_reverse) {
std::fill(data_temp + i * table_width,
data_temp + (i + 1) * table_width - seq_len_vec[i],
static_cast<T>(0));
} else {
std::fill(data_temp + i * table_width + seq_len_vec[i],
data_temp + (i + 1) * table_width,
static_cast<T>(0));
}
}
dev_ctx.Alloc<T>(mask_matrix);
std::vector<int> trans_vec;
trans_vec.emplace_back(1);
trans_vec.emplace_back(0);
funcs::TransCompute<CPUContext, T>(2, dev_ctx, temp, mask_matrix, trans_vec);
}
template <typename TensorType>
void ResetParameterVector(const std::vector<TensorType>& raw_params_vec,
int num_layers,
int gate_num,
bool is_bidirec,
std::vector<std::vector<DenseTensor>>* params_vec) {
// The raw parameter sequence is [FWhi, FWhh, BWhi, BWhh] * num_layers
// + [FBhi, FBhh, BBhi, BBhh] * num_layers; it is regrouped here into
// ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers
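// Example (illustrative): with num_layers = 1 and is_bidirec = true,
// direction_num = 2, layer_weight_size = 8 and bias_start_idx = 4, so the
// raw order [FWhi, FWhh, BWhi, BWhh, FBhi, FBhh, BBhi, BBhh] is regrouped
// into [FWhi, FWhh, FBhi, FBhh, BWhi, BWhh, BBhi, BBhh] for layer 0.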
const int& direction_num = is_bidirec ? 2 : 1;
const int& layer_weight_size = 4 * direction_num;
const int& all_weight_size = num_layers * layer_weight_size;
const int& bias_start_idx = all_weight_size / 2;
for (int i = 0; i < num_layers; i++) {
std::vector<DenseTensor> tensor_list;
tensor_list.reserve(layer_weight_size);
for (int j = 0; j < layer_weight_size; j++) {
DenseTensor tensor_holder;
tensor_list.emplace_back(tensor_holder);
}
for (int j = 0; j < layer_weight_size; j++) {
int k = j % 4;
const int& section = j / 4;
int tensor_idx = i * 2 * direction_num + section * 2 + k % 2;
if (k >= 2) {
tensor_idx += bias_start_idx;
}
tensor_list[j].ShareDataWith(*raw_params_vec[tensor_idx]);
}
params_vec->emplace_back(tensor_list);
}
}
template <typename T>
void DropoutHelper(const CPUContext& dev_ctx,
DenseTensor* x,
DenseTensor* y,
const DenseTensor* mask,
float dropout_prob) {
auto& place = *dev_ctx.eigen_device();
auto dropout_mask = EigenVector<uint8_t>::Flatten(*mask);
auto in = EigenVector<T>::Flatten(*x);
auto out = EigenVector<T>::Flatten(*y);
if (dropout_prob == 1.0f) {
out.device(place) = static_cast<T>(0) * in;
} else {
out.device(place) =
in * dropout_mask.cast<T>() / static_cast<T>(1.0f - dropout_prob);
}
}
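// Note: this is inverted dropout -- surviving activations are rescaled by
// 1 / (1 - dropout_prob) so the expected value is unchanged; e.g. with
// dropout_prob = 0.5 every kept value is doubled.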
template <typename T>
void DropoutCpuFunctionInplace(const CPUContext& dev_ctx,
DenseTensor* x,
DenseTensor* y,
DenseTensor* mask,
const float& dropout_prob,
const int& seed_number,
bool is_test,
bool* is_has_reset) {
if (is_test) {
return;
}
size_t size = phi::product(x->dims());
auto* mask_data = mask->data<uint8_t>();
if (!(*is_has_reset)) {
// Special case when dropout_prob is 1.0
if (dropout_prob == 1.0f) {
std::fill(mask_data, mask_data + size, static_cast<uint8_t>(0));
} else {
auto engine = paddle::framework::GetCPURandomEngine(seed_number);
std::uniform_real_distribution<float> dist(0, 1);
for (size_t i = 0; i < size; ++i) {
if (dist(*engine) < dropout_prob) {
mask_data[i] = 0;
} else {
mask_data[i] = 1;
}
}
}
*is_has_reset = true;
}
DropoutHelper<T>(dev_ctx, x, y, mask, dropout_prob);
}
template <typename Context, typename TensorType>
void SplitReserveData(const Context& dev_ctx,
int direction_num,
int time_step,
int batch_size,
int hidden_size,
int gate_num,
int num_layers,
const std::string& mode,
TensorType* reserve_data,
DenseTensor* gate_data,
DenseTensor* cell_data,
DenseTensor* cell_act_data,
DenseTensor* hidden_data) {
int gate_data_idx = gate_num * num_layers;
int cell_data_idx = (gate_num + 1) * num_layers;
int cell_act_data_idx = (gate_num + 2) * num_layers;
// simple rnn
int hidden_data_start_idx = gate_data_idx;
*gate_data = reserve_data->Slice(0, gate_data_idx);
if (is_lstm(mode)) {
*cell_data = reserve_data->Slice(gate_data_idx, cell_data_idx);
*cell_act_data = reserve_data->Slice(cell_data_idx, cell_act_data_idx);
hidden_data_start_idx = cell_act_data_idx;
} else if (is_gru(mode)) {
*cell_data = reserve_data->Slice(gate_data_idx, cell_data_idx);
hidden_data_start_idx = cell_data_idx;
}
int hidden_data_idx = hidden_data_start_idx + (num_layers - 1);
if (num_layers > 1) {
*hidden_data = reserve_data->Slice(hidden_data_start_idx, hidden_data_idx);
}
}
template <typename CellType, typename T, typename Context>
void AllocateReserveData(const Context& dev_ctx,
bool is_bidirec,
int num_layers,
int gate_num,
int hidden_size,
const std::string& mode,
DenseTensor* reserve_data,
DenseTensor* gate_data,
DenseTensor* cell_data,
DenseTensor* cell_act_data,
DenseTensor* hidden_data,
const DenseTensor* input) {
int direction_num = is_bidirec ? 2 : 1;
int time_step = input->dims()[0];
int batch_size = input->dims()[1];
int block_size = direction_num * time_step * batch_size * hidden_size;
int hidden_data_idx = (num_layers - 1);
if (is_lstm(mode)) {
hidden_data_idx += (gate_num + 2) * num_layers;
} else if (is_gru(mode)) {
hidden_data_idx += (gate_num + 1) * num_layers;
} else {
hidden_data_idx += gate_num * num_layers;
}
reserve_data->Resize({hidden_data_idx, block_size});
dev_ctx.template Alloc<T>(reserve_data);
SplitReserveData(dev_ctx,
direction_num,
time_step,
batch_size,
hidden_size,
gate_num,
num_layers,
mode,
reserve_data,
gate_data,
cell_data,
cell_act_data,
hidden_data);
}
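// Layout note (derived from the two functions above): reserve_data has shape
// [hidden_data_idx, block_size], where
//   block_size = direction_num * time_step * batch_size * hidden_size,
// and its rows are, in order:
//   gate_num * num_layers   rows of gate data,
//   num_layers              rows of cell data        (LSTM and GRU only),
//   num_layers              rows of cell activations (LSTM only),
//   num_layers - 1          rows of inter-layer hidden outputs.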
inline std::vector<DenseTensor> Unbind(const DenseTensor& in) {
int64_t size = in.dims()[0];
std::vector<DenseTensor> tensors;
tensors.reserve(size);
for (int64_t i = 0; i < size; ++i) {
tensors.emplace_back(in.Slice(i, i + 1));
}
return tensors;
}
template <typename CellType,
template <typename, typename> class LayerT,
template <typename, typename> class SingleLayerT,
template <typename, typename> class BidirLayerT,
typename T,
typename Context>
void RnnFunc(const Context& dev_ctx,
const DenseTensor* input,
const std::vector<const DenseTensor*>& weight_list,
const DenseTensor* init_h,
const DenseTensor* init_c,
const DenseTensor* sequence_length,
DenseTensor* last_h,
DenseTensor* last_c,
DenseTensor* output,
DenseTensor* dropout_mask,
int num_layers,
int gate_num,
int input_size,
int hidden_size,
bool is_bidirec,
const std::string& cell_type,
float dropout_prob,
bool is_test,
int seed,
DenseTensor* reserve_data) {
int direction_num = is_bidirec ? 2 : 1;
const auto& init_h_dims = init_h->dims();
PADDLE_ENFORCE_EQ(init_h_dims[0],
num_layers * direction_num,
phi::errors::InvalidArgument(
"The num_layers of in RNN layer must be the same as "
"first dim of init hidden, but received"
" num_layers:%d, dim:%d",
num_layers,
init_h_dims[0]));
if (is_lstm(cell_type)) {
const auto& init_c_dims = init_c->dims();
PADDLE_ENFORCE_EQ(init_c_dims[0],
num_layers * direction_num,
phi::errors::InvalidArgument(
"The num_layers of in RNN layer must be the same as "
"first dim of cell state hidden, but received"
" num_layers:%d, dim:%d",
num_layers,
init_c_dims[0]));
}
CellType cell;
std::vector<std::vector<DenseTensor>> parameter_lists;
parameter_lists.reserve(num_layers);
ResetParameterVector(
weight_list, num_layers, gate_num, is_bidirec, &parameter_lists);
DenseTensor gate_data, cell_data, cell_act_data, hidden_data;
if (!is_test) {
AllocateReserveData<CellType, T, Context>(dev_ctx,
is_bidirec,
num_layers,
gate_num,
hidden_size,
cell_type,
reserve_data,
&gate_data,
&cell_data,
&cell_act_data,
&hidden_data,
input);
gate_data.Resize({num_layers, gate_data.numel() / num_layers});
cell_data.Resize({num_layers, cell_data.numel() / num_layers});
cell_act_data.Resize({num_layers, cell_act_data.numel() / num_layers});
if (num_layers > 1) {
hidden_data.Resize(
{num_layers - 1, hidden_data.numel() / (num_layers - 1)});
}
}
DenseTensor* input_holder;
DenseTensor* output_holder = output;
bool has_allocate_mem = false;
auto init_h_unbind = Unbind(*init_h);
auto last_h_unbind = Unbind(*last_h);
std::vector<DenseTensor> init_c_unbind, last_c_unbind;
if (is_lstm(cell_type)) {
init_c_unbind = Unbind(*init_c);
last_c_unbind = Unbind(*last_c);
}
DenseTensor curr_gate_data, curr_cell_data, curr_cell_act_data;
DenseTensor curr_hidden_data, prev_hidden_data;
DenseTensor temp;
bool has_dropout_reset = false;
for (int i = 0; i < num_layers; i++) {
if (!is_test) {
if (cell_data.numel() > 0) /** for lstm, gru **/ {
curr_cell_data = cell_data.Slice(i, i + 1);
}
if (cell_act_data.numel() > 0) /*for lstm*/ {
curr_cell_act_data = cell_act_data.Slice(i, i + 1);
}
curr_gate_data = gate_data.Slice(i, i + 1);
output_holder = output;
if (i < num_layers - 1 && num_layers > 1) {
curr_hidden_data = hidden_data.Slice(i, i + 1);
curr_hidden_data.Resize(output->dims());
output_holder = &curr_hidden_data;
}
}
if (i > 0) {
if (!has_allocate_mem) {
temp.Resize(output->dims());
dev_ctx.template Alloc<T>(&temp);
input_holder = &temp;
has_allocate_mem = true;
}
if (!is_test) {
prev_hidden_data = hidden_data.Slice(i - 1, i);
input_holder->Resize(output->dims());
if (dropout_prob != 0) {
DropoutCpuFunctionInplace<T>(dev_ctx,
&prev_hidden_data,
input_holder,
dropout_mask,
dropout_prob,
seed,
is_test,
&has_dropout_reset);
} else {
input_holder = &prev_hidden_data;
input_holder->Resize(output->dims());
}
} else {
SwapPoniter(&output_holder, &input_holder);
}
}
const DenseTensor* input_temp_holder = input;
if (i > 0) {
input_temp_holder = input_holder;
}
LayerT<T, CellType>* layer;
SingleLayerT<T, CellType> slayer(cell);
BidirLayerT<T, CellType> blayer(cell);
if (is_bidirec) {
layer = &blayer;
} else {
layer = &slayer;
}
(*layer)(dev_ctx,
input_temp_holder,
parameter_lists[i],
init_h_unbind,
init_c_unbind,
sequence_length,
last_h_unbind,
last_c_unbind,
output_holder,
i,
gate_num,
&curr_gate_data,
&curr_cell_data,
&curr_cell_act_data,
cell_type,
is_test);
}
if (num_layers % 2 == 0) {
Copy(dev_ctx, *output_holder, dev_ctx.GetPlace(), false, output);
}
}
} // namespace phi
This diff has been collapsed.
This diff has been collapsed.
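The collapsed kernel files presumably dispatch on `mode` and instantiate RnnFunc with a mode-specific cell plus single-/bi-directional layer templates. An illustrative LSTM-path call is sketched below; LSTMCell, Layer, SingleLayer and BidirLayer are hypothetical template names, and gate_num = 4 is an assumption for LSTM:

// Hypothetical caller sketch -- not taken from the collapsed files.
RnnFunc<LSTMCell<T>, Layer, SingleLayer, BidirLayer, T, Context>(
    dev_ctx,
    &x,
    weight_list,
    pre_state[0],            // init hidden state
    pre_state[1],            // init cell state (LSTM only)
    sequence_length_ptr,
    state[0],                // last hidden state
    state[1],                // last cell state
    out,
    dropout_state,
    num_layers,
    4,                       // gate_num, assumed 4 for LSTM
    input_size,
    hidden_size,
    is_bidirec,
    mode,                    // "LSTM"
    dropout_prob,
    is_test,
    seed,
    reserve);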
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
#include "paddle/phi/kernels/funcs/gru_compute.h"
......@@ -283,11 +283,10 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
#endif
}
template <typename T>
inline void forward_reset_outputV2(
const paddle::platform::CPUDeviceContext &context,
phi::funcs::GRUMetaValue<T> value,
int frame_size) {
template <typename T, typename Context>
inline void forward_reset_outputV2(const Context &context,
phi::funcs::GRUMetaValue<T> value,
int frame_size) {
auto &place = *context.eigen_device();
auto value_reset_gate =
typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
......@@ -297,23 +296,20 @@ inline void forward_reset_outputV2(
value.reset_output_value, Array1(frame_size));
auto value_reset_bias =
typename EigenVector<T>::ConstType(value.reset_bias, Array1(frame_size));
paddle::operators::SigmoidFunctor<T>()(
place, value_reset_gate, value_reset_gate);
paddle::operators::SigmoidFunctor<T>()(
place, value_update_gate, value_update_gate);
SigmoidFunctor<T>()(place, value_reset_gate, value_reset_gate);
SigmoidFunctor<T>()(place, value_update_gate, value_update_gate);
value_reset_output.device(place) =
(value_reset_output + value_reset_bias) * value_reset_gate;
}
template <class OpResetOutput, typename T>
inline void forward_reset_output(
OpResetOutput op_reset_output,
phi::funcs::GRUMetaValue<T> value,
int frame_size,
int batch_size,
ActivationType active_gate,
bool old_version = true,
const paddle::platform::CPUDeviceContext *context = nullptr) {
template <typename Context, class OpResetOutput, typename T>
inline void forward_reset_output(OpResetOutput op_reset_output,
phi::funcs::GRUMetaValue<T> value,
int frame_size,
int batch_size,
ActivationType active_gate,
bool old_version = true,
const Context *context = nullptr) {
for (int b = 0; b < batch_size; b++) {
if (!old_version) {
// use eigen
......@@ -348,11 +344,10 @@ inline void forward_reset_output(
}
}
template <typename T>
inline void forward_final_outputV2(
const paddle::platform::CPUDeviceContext &context,
phi::funcs::GRUMetaValue<T> value,
int frame_size) {
template <typename T, typename Context>
inline void forward_final_outputV2(const Context &context,
phi::funcs::GRUMetaValue<T> value,
int frame_size) {
auto &place = *context.eigen_device();
auto value_update_gate = typename EigenVector<T>::Type(
value.gate_value + frame_size, Array1(frame_size));
......@@ -360,8 +355,7 @@ inline void forward_final_outputV2(
value.gate_value + 2 * frame_size, Array1(frame_size));
auto value_output =
typename EigenVector<T>::Type(value.output_value, Array1(frame_size));
paddle::operators::TanhFunctor<T>()(
place, value_frame_state, value_frame_state);
TanhFunctor<T>()(place, value_frame_state, value_frame_state);
value_output.device(place) =
(static_cast<T>(1.0) - value_update_gate) * value_frame_state;
if (value.prev_out_value) {
......@@ -372,16 +366,15 @@ inline void forward_final_outputV2(
}
}
template <class OpFinalOutput, typename T>
inline void forward_final_output(
OpFinalOutput op_final_output,
phi::funcs::GRUMetaValue<T> value,
int frame_size,
int batch_size,
ActivationType active_node,
bool origin_mode,
bool old_version = true,
const paddle::platform::CPUDeviceContext *context = nullptr) {
template <typename Context, class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput op_final_output,
phi::funcs::GRUMetaValue<T> value,
int frame_size,
int batch_size,
ActivationType active_node,
bool origin_mode,
bool old_version = true,
const Context *context = nullptr) {
for (int b = 0; b < batch_size; b++) {
if (!old_version) {
// eigen
......@@ -871,8 +864,8 @@ inline void backward_reset_grad(OpResetGrad op_reset_grad,
}
}
template <typename T>
inline void gru_backward(const paddle::platform::CPUDeviceContext &context,
template <typename T, typename Context>
inline void gru_backward(const Context &context,
phi::funcs::GRUMetaValue<T> value,
phi::funcs::GRUMetaGrad<T> grad,
int frame_size) {
......@@ -901,14 +894,13 @@ inline void gru_backward(const paddle::platform::CPUDeviceContext &context,
if (value.prev_out_value) {
auto value_prev_out = typename EigenVector<T>::ConstType(
value.prev_out_value, Array1(frame_size));
paddle::operators::SigmoidGradFunctor<T>()(
place,
1 /*useless*/,
value_update_gate,
(value_prev_out - value_frame_state) * grad_output,
grad_update_gate);
SigmoidGradFunctor<T>()(place,
1 /*useless*/,
value_update_gate,
(value_prev_out - value_frame_state) * grad_output,
grad_update_gate);
} else {
paddle::operators::SigmoidGradFunctor<T>()(
SigmoidGradFunctor<T>()(
place,
1 /*useless*/,
value_update_gate,
......@@ -921,13 +913,12 @@ inline void gru_backward(const paddle::platform::CPUDeviceContext &context,
grad_prev_out.device(place) =
grad_prev_out + grad_output * value_update_gate;
}
paddle::operators::TanhGradFunctor<T>()(
place,
1 /*useless*/,
value_frame_state,
grad_output * (static_cast<T>(1.0) - value_update_gate),
grad_frame_state);
paddle::operators::SigmoidGradFunctor<T>()(
TanhGradFunctor<T>()(place,
1 /*useless*/,
value_frame_state,
grad_output * (static_cast<T>(1.0) - value_update_gate),
grad_frame_state);
SigmoidGradFunctor<T>()(
place,
1 /*useless*/,
value_reset_gate,
......@@ -938,8 +929,8 @@ inline void gru_backward(const paddle::platform::CPUDeviceContext &context,
}
}
template <class OpGruGrad, typename T>
inline void cpu_gru_backward(const paddle::platform::CPUDeviceContext &context,
template <class OpGruGrad, typename T, typename Context>
inline void cpu_gru_backward(const Context &context,
OpGruGrad op_gru_grad,
phi::funcs::GRUMetaValue<T> value,
phi::funcs::GRUMetaGrad<T> grad,
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/detail/activation_functions.h"
#include "paddle/phi/kernels/funcs/lstm_compute.h"
......@@ -409,11 +409,10 @@ void avx_lstm_backward_one_sequence(Op op,
#endif
}
template <class T>
void eigen_lstm_forward_one_sequence(
const paddle::platform::CPUDeviceContext &context,
phi::funcs::LstmMetaValue<T> value,
int frame_size) {
template <class T, class Context>
void eigen_lstm_forward_one_sequence(const Context &context,
phi::funcs::LstmMetaValue<T> value,
int frame_size) {
auto eigen_value_ig =
typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
auto eigen_value_fg = typename EigenVector<T>::Type(
......@@ -430,10 +429,10 @@ void eigen_lstm_forward_one_sequence(
typename EigenVector<T>::Type(value.output_value, Array1(frame_size));
auto &place = *context.eigen_device();
paddle::operators::TanhFunctor<T>()(place, eigen_value_in, eigen_value_in);
paddle::operators::SigmoidFunctor<T>()(place, eigen_value_ig, eigen_value_ig);
paddle::operators::SigmoidFunctor<T>()(place, eigen_value_fg, eigen_value_fg);
paddle::operators::SigmoidFunctor<T>()(place, eigen_value_og, eigen_value_og);
TanhFunctor<T>()(place, eigen_value_in, eigen_value_in);
SigmoidFunctor<T>()(place, eigen_value_ig, eigen_value_ig);
SigmoidFunctor<T>()(place, eigen_value_fg, eigen_value_fg);
SigmoidFunctor<T>()(place, eigen_value_og, eigen_value_og);
eigen_state.device(place) = eigen_value_in * eigen_value_ig;
if (value.prev_state_value) {
......@@ -442,16 +441,15 @@ void eigen_lstm_forward_one_sequence(
eigen_state.device(place) = eigen_state + eigen_prev_state * eigen_value_fg;
}
paddle::operators::TanhFunctor<T>()(place, eigen_state, eigen_state_act);
TanhFunctor<T>()(place, eigen_state, eigen_state_act);
eigen_output.device(place) = eigen_value_og * eigen_state_act;
}
template <class T>
void eigen_lstm_backward_one_sequence(
const paddle::platform::CPUDeviceContext &context,
phi::funcs::LstmMetaValue<T> value,
phi::funcs::LstmMetaGrad<T> grad,
int frame_size) {
template <class T, class Context>
void eigen_lstm_backward_one_sequence(const Context &context,
phi::funcs::LstmMetaValue<T> value,
phi::funcs::LstmMetaGrad<T> grad,
int frame_size) {
auto eigen_value_ig =
typename EigenVector<T>::Type(value.gate_value, Array1(frame_size));
auto eigen_value_fg = typename EigenVector<T>::Type(
......@@ -477,38 +475,35 @@ void eigen_lstm_backward_one_sequence(
typename EigenVector<T>::Type(grad.state_grad, Array1(frame_size));
auto &place = *context.eigen_device();
paddle::operators::SigmoidGradFunctor<T>()(
place,
1 /*useless*/,
eigen_value_og,
eigen_grad_output * eigen_state_act,
eigen_grad_og);
SigmoidGradFunctor<T>()(place,
1 /*useless*/,
eigen_value_og,
eigen_grad_output * eigen_state_act,
eigen_grad_og);
eigen_grad_state.device(place) =
eigen_grad_state +
eigen_grad_output * eigen_value_og *
(static_cast<T>(1) - eigen_state_act * eigen_state_act);
paddle::operators::TanhGradFunctor<T>()(place,
1,
eigen_value_in,
eigen_grad_state * eigen_value_ig,
eigen_grad_in);
paddle::operators::SigmoidGradFunctor<T>()(place,
1,
eigen_value_ig,
eigen_grad_state * eigen_value_in,
eigen_grad_ig);
TanhGradFunctor<T>()(place,
1,
eigen_value_in,
eigen_grad_state * eigen_value_ig,
eigen_grad_in);
SigmoidGradFunctor<T>()(place,
1,
eigen_value_ig,
eigen_grad_state * eigen_value_in,
eigen_grad_ig);
if (value.prev_state_value) {
auto eigen_prev_state = typename EigenVector<T>::ConstType(
value.prev_state_value, Array1(frame_size));
paddle::operators::SigmoidGradFunctor<T>()(
place,
1,
eigen_value_fg,
eigen_grad_state * eigen_prev_state,
eigen_grad_fg);
SigmoidGradFunctor<T>()(place,
1,
eigen_value_fg,
eigen_grad_state * eigen_prev_state,
eigen_grad_fg);
} else {
paddle::operators::SigmoidGradFunctor<T>()(
place, 1, eigen_value_fg, 0, eigen_grad_fg);
SigmoidGradFunctor<T>()(place, 1, eigen_value_fg, 0, eigen_grad_fg);
}
if (grad.prev_state_grad) {
auto eigen_grad_pre_state =
......@@ -517,8 +512,8 @@ void eigen_lstm_backward_one_sequence(
}
}
template <class T, class Op>
void cpu_lstm_forward(const paddle::platform::CPUDeviceContext &context,
template <class T, class Op, class Context>
void cpu_lstm_forward(const Context &context,
Op op,
phi::funcs::LstmMetaValue<T> value,
int frame_size,
......@@ -552,8 +547,8 @@ void cpu_lstm_forward(const paddle::platform::CPUDeviceContext &context,
}
}
template <class T, class Op>
void cpu_lstm_backward(const paddle::platform::CPUDeviceContext &context,
template <class T, class Op, class Context>
void cpu_lstm_backward(const Context &context,
Op op,
phi::funcs::LstmMetaValue<T> value,
phi::funcs::LstmMetaGrad<T> grad,
......
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/funcs/lstm_compute.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/kernels/funcs/detail/lstm_cpu_kernel.h"
#include "paddle/phi/kernels/funcs/detail/lstm_kernel.h"
......@@ -51,6 +53,38 @@ struct LstmUnitFunctor<paddle::platform::CPUDeviceContext, T> {
}
};
template <class T>
struct LstmUnitFunctor<CPUContext, T> {
static void compute(const CPUContext& context,
LstmMetaValue<T> value,
int frame_size,
int batch_size,
T cell_clip,
const phi::funcs::detail::ActivationType& gate_act,
const phi::funcs::detail::ActivationType& cell_act,
const phi::funcs::detail::ActivationType& cand_act,
bool old_api_version = true) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(context,
phi::funcs::detail::forward::lstm<T>(),
value,
frame_size,
cell_clip,
cand_act,
gate_act,
cell_act,
old_api_version);
value.gate_value += frame_size * 4;
value.state_value += frame_size;
value.state_active_value += frame_size;
value.output_value += frame_size;
if (value.prev_state_value) {
value.prev_state_value += frame_size;
}
}
}
};
template <class T>
struct LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, T> {
static void compute(const paddle::platform::CPUDeviceContext& context,
......@@ -94,10 +128,58 @@ struct LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, T> {
}
};
template <class T>
struct LstmUnitGradFunctor<CPUContext, T> {
static void compute(const CPUContext& context,
LstmMetaValue<T> value,
LstmMetaGrad<T> grad,
int frame_size,
int batch_size,
T cell_clip,
const phi::funcs::detail::ActivationType& gate_act,
const phi::funcs::detail::ActivationType& cell_act,
const phi::funcs::detail::ActivationType& cand_act,
bool old_api_version = true) {
for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_backward(context,
phi::funcs::detail::backward::lstm<T>(),
value,
grad,
frame_size,
cell_clip,
cand_act,
gate_act,
cell_act,
old_api_version);
value.gate_value += frame_size * 4;
value.state_value += frame_size;
value.state_active_value += frame_size;
value.output_value += frame_size;
if (value.prev_state_value) {
value.prev_state_value += frame_size;
}
grad.gate_grad += frame_size * 4;
grad.state_grad += frame_size;
grad.state_active_grad += frame_size;
grad.output_grad += frame_size;
if (grad.prev_state_grad) {
grad.prev_state_grad += frame_size;
}
}
}
};
template class LstmUnitFunctor<paddle::platform::CPUDeviceContext, float>;
template class LstmUnitFunctor<paddle::platform::CPUDeviceContext, double>;
template class LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, float>;
template class LstmUnitGradFunctor<paddle::platform::CPUDeviceContext, double>;
template class LstmUnitFunctor<CPUContext, float>;
template class LstmUnitFunctor<CPUContext, double>;
template class LstmUnitGradFunctor<CPUContext, float>;
template class LstmUnitGradFunctor<CPUContext, double>;
} // namespace funcs
} // namespace phi
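The collapsed files below presumably switch their callers over to these new CPUContext specializations; an illustrative call (the surrounding variables are assumptions):

// Illustrative use of the new specialization; variable names are assumed.
phi::funcs::LstmUnitFunctor<phi::CPUContext, float>::compute(
    dev_ctx, lstm_value, frame_size, batch_size, cell_clip,
    gate_act, cell_act, cand_act, /*old_api_version=*/false);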
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.