未验证 提交 00d23897 编写于 作者: z8hanghuan's avatar z8hanghuan 提交者: GitHub

support multi_layer of bilstm,*test=kunlun (#41151)

* support multi_layer of bilstm,*test=kunlun

* support multi_layer of bilstm, *test=kunlun

* support multi_layer of bilstm, *test=kunlun

* support multi_layer of bilstm, *test=kunlun
上级 0d28edfa
......@@ -36,7 +36,7 @@ ENDIF()
if(NOT DEFINED XPU_BASE_URL)
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220327")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220331")
else()
SET(XPU_BASE_URL "${XPU_BASE_URL}")
endif()
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
......@@ -13,6 +13,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -21,9 +22,7 @@ namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
using TensorList = std::vector<framework::Tensor>;
template <typename TensorType, typename T>
void reset_parameter_vector(const std::vector<TensorType>& raw_params_vec,
const int& num_layers, const bool& is_bidirec,
......@@ -51,54 +50,89 @@ void reset_parameter_vector(const std::vector<TensorType>& raw_params_vec,
}
}
template <typename DeviceContext, typename T>
void RunLSTMLayer(const framework::ExecutionContext& ctx, int seq_len,
int batch_size, int xdim, int hidden_size, const T* x, T* y,
const T* init_h, const T* init_c, T* last_h, T* last_c,
int state_offset, const std::vector<int>& seq_len_tensor,
const std::vector<const T*>& param_list, T* i_f_g_o, T* c,
bool is_bidirect, int layer_idx, int offset) {
bool is_reverse = false;
if (is_bidirect) {
layer_idx = 2 * layer_idx + offset;
if (offset > 0) {
is_reverse = true;
}
}
auto w_x = param_list[0 + offset * 4];
auto w_h = param_list[1 + offset * 4];
auto b_x = param_list[2 + offset * 4];
auto b_h = param_list[3 + offset * 4];
auto h_0 = init_h + layer_idx * state_offset;
auto c_0 = init_c + layer_idx * state_offset;
auto last_h_ptr = last_h + layer_idx * state_offset;
auto last_c_ptr = last_c + layer_idx * state_offset;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::lstm_train<T, T, int16_t>(
dev_ctx.x_context(), (const T*)x, (const T*)h_0, (const T*)c_0,
(const T*)w_x, (const T*)w_h, (const T*)b_x, (const T*)b_h,
reinterpret_cast<T*>(y), reinterpret_cast<T*>(last_h_ptr),
reinterpret_cast<T*>(last_c_ptr), batch_size, xdim, hidden_size, seq_len,
seq_len_tensor, is_reverse, nullptr, nullptr, nullptr, nullptr,
reinterpret_cast<T*>(i_f_g_o), reinterpret_cast<T*>(c),
xpu::Activation_t::TANH, xpu::Activation_t::SIGMOID);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "lstm_train");
}
template <typename DeviceContext, typename T>
class RnnXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Input
auto* input = ctx.Input<Tensor>("Input");
auto pre_state = ctx.MultiInput<Tensor>("PreState");
auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
bool has_seq_length = ctx.HasInput("SequenceLength");
// Output
auto state = ctx.MultiOutput<Tensor>("State");
auto* output = ctx.Output<Tensor>("Out");
auto* dropout_mask = ctx.Output<Tensor>("DropoutState");
auto* reserve_data = ctx.Output<Tensor>("Reserve");
// Attrbutes
const int& num_layers = ctx.Attr<int>("num_layers");
const bool& is_bidirec = ctx.Attr<bool>("is_bidirec");
const int& hidden_size = ctx.Attr<int>("hidden_size");
const std::string& mode = ctx.Attr<std::string>("mode");
bool has_seq_length = ctx.HasInput("SequenceLength");
const Tensor* sequence_length = nullptr;
if (has_seq_length) {
sequence_length = ctx.Input<Tensor>("SequenceLength");
}
if (dropout_mask->IsInitialized()) {
if (dropout_mask->numel() != output->numel()) dropout_mask->clear();
}
dropout_mask->mutable_data<uint8_t>(output->dims(), ctx.GetPlace());
PADDLE_ENFORCE_EQ(
mode, "LSTM",
platform::errors::InvalidArgument(
"XPU only support LSTM mode now, current mode is %s", mode));
PADDLE_ENFORCE_EQ(is_bidirec, false,
platform::errors::InvalidArgument(
"XPU only support unidirectional LSTM now"));
PADDLE_ENFORCE_EQ(
num_layers, 1,
platform::errors::InvalidArgument(
"XPU only support 1 layer LSTM now, current layer num is %s",
num_layers));
auto init_h = pre_state[0];
auto init_c = pre_state[1];
auto last_h = state[0];
auto last_c = state[1];
// check shape
int seq_len = input->dims()[0];
int batch_size = input->dims()[1];
int input_dim = input->dims()[2];
const int& seq_len = input->dims()[0]; // time_step
const int& batch_size = input->dims()[1];
const int& input_dim = input->dims()[2];
const int& direction_num = is_bidirec ? 2 : 1;
PADDLE_ENFORCE_EQ(
init_h->dims()[0], num_layers,
init_h->dims()[0], num_layers * direction_num,
platform::errors::InvalidArgument("The num_layers of in RNN layer must"
" be the same as first dim of init "
"hidden, but received num_layers:%d,"
......@@ -106,13 +140,13 @@ class RnnXPUKernel : public framework::OpKernel<T> {
num_layers, init_h->dims()[0]));
PADDLE_ENFORCE_EQ(
init_c->dims()[0], num_layers,
init_c->dims()[0], num_layers * direction_num,
platform::errors::InvalidArgument(
"The num_layers of in RNN layer must"
" be the same as first dim of cell state hidden, but received"
" num_layers:%d, dim:%d",
num_layers, init_c->dims()[0]));
// weightlist
std::vector<std::vector<const T*>> parameter_lists;
parameter_lists.resize(num_layers);
reset_parameter_vector(weight_list, num_layers, is_bidirec,
......@@ -122,41 +156,106 @@ class RnnXPUKernel : public framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace());
last_h->mutable_data<T>(ctx.GetPlace());
last_c->mutable_data<T>(ctx.GetPlace());
reserve_data->Resize({seq_len * batch_size * hidden_size * 5});
reserve_data->mutable_data<T>(ctx.GetPlace());
reserve_data->Resize(
{num_layers * direction_num * seq_len * batch_size * hidden_size * 5});
reserve_data->mutable_data<T>(ctx.GetPlace());
Tensor internal_output_1_tensor, internal_output_2_tensor;
T* internal_output_1_ptr = nullptr;
T* internal_output_2_ptr = nullptr;
if (num_layers >= 2) {
internal_output_1_tensor.Resize(output->dims());
internal_output_1_ptr =
internal_output_1_tensor.mutable_data<T>(ctx.GetPlace());
}
if (num_layers >= 3) {
internal_output_2_tensor.Resize(output->dims());
internal_output_2_ptr =
internal_output_2_tensor.mutable_data<T>(ctx.GetPlace());
}
// get ptr from tensor
auto x = input->data<T>();
auto h_0 = init_h->data<T>();
auto c_0 = init_c->data<T>();
auto w_x = parameter_lists[0][0];
auto w_h = parameter_lists[0][1];
auto b_x = parameter_lists[0][2];
auto b_h = parameter_lists[0][3];
auto init_h_ptr = init_h->data<T>();
auto init_c_ptr = init_c->data<T>();
auto y = output->data<T>();
auto last_h_ptr = last_h->data<T>();
auto last_c_ptr = last_c->data<T>();
auto i_f_g_o = reserve_data->data<T>();
auto c = i_f_g_o + seq_len * batch_size * hidden_size * 4;
auto c =
i_f_g_o +
num_layers * direction_num * seq_len * batch_size * hidden_size * 4;
std::vector<int> seq_len_tensor(batch_size, seq_len);
if (has_seq_length) {
seq_len_tensor = operators::GetDataFromTensor(sequence_length);
}
// run kernel
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::lstm_train<T, T, int16_t>(
dev_ctx.x_context(), (const T*)x, (const T*)h_0, (const T*)c_0,
(const T*)w_x, (const T*)w_h, (const T*)b_x, (const T*)b_h,
reinterpret_cast<T*>(y), reinterpret_cast<T*>(last_h_ptr),
reinterpret_cast<T*>(last_c_ptr), batch_size, input_dim, hidden_size,
seq_len, seq_len_tensor, nullptr, nullptr, nullptr, nullptr,
reinterpret_cast<T*>(i_f_g_o), reinterpret_cast<T*>(c));
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External("RnnXPU(lstm) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
int state_offset = pre_state[0]->dims()[1] * pre_state[0]->dims()[2];
for (int i = 0; i < num_layers; i++) {
const T* cur_input_ptr = nullptr;
int cur_xdim = -1;
i_f_g_o += i * direction_num * seq_len * batch_size * hidden_size * 4;
c += i * direction_num * seq_len * batch_size * hidden_size;
if (i == 0) {
cur_input_ptr = x;
cur_xdim = input_dim;
} else if (i % 2 != 0) {
cur_input_ptr = internal_output_1_ptr;
cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size;
} else {
cur_input_ptr = internal_output_2_ptr;
cur_xdim = is_bidirec ? 2 * hidden_size : hidden_size;
}
T* cur_output_ptr = nullptr;
if (i == num_layers - 1) {
cur_output_ptr = y;
} else if (i % 2 != 0) {
cur_output_ptr = internal_output_2_ptr;
} else {
cur_output_ptr = internal_output_1_ptr;
}
if (is_bidirec) {
std::vector<Tensor> output_vec(2);
std::vector<T*> output_ptr_vec(2);
for (int k = 0; k < 2; ++k) {
output_vec[k].Resize({seq_len, batch_size, output->dims()[2] / 2});
output_ptr_vec[k] = output_vec[k].mutable_data<T>(ctx.GetPlace());
}
RunLSTMLayer<DeviceContext, T>(
ctx, seq_len, batch_size, cur_xdim, hidden_size, cur_input_ptr,
output_ptr_vec[0], init_h_ptr, init_c_ptr, last_h_ptr, last_c_ptr,
state_offset, seq_len_tensor, parameter_lists[i], i_f_g_o, c,
is_bidirec, i, 0);
T* bw_i_f_g_o = i_f_g_o + seq_len * batch_size * hidden_size * 4;
T* bw_c = c + seq_len * batch_size * hidden_size;
RunLSTMLayer<DeviceContext, T>(
ctx, seq_len, batch_size, cur_xdim, hidden_size, cur_input_ptr,
output_ptr_vec[1], init_h_ptr, init_c_ptr, last_h_ptr, last_c_ptr,
state_offset, seq_len_tensor, parameter_lists[i], bw_i_f_g_o, bw_c,
is_bidirec, i, 1);
// concat
int r = xpu::concat<T>(
dev_ctx.x_context(), {output_ptr_vec[0], output_ptr_vec[1]},
cur_output_ptr, {{seq_len, batch_size, hidden_size},
{seq_len, batch_size, hidden_size}},
2);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "concat");
xpu_wait(dev_ctx.x_context()->xpu_stream);
} else {
RunLSTMLayer<DeviceContext, T>(
ctx, seq_len, batch_size, cur_xdim, hidden_size, cur_input_ptr,
cur_output_ptr, init_h_ptr, init_c_ptr, last_h_ptr, last_c_ptr,
state_offset, seq_len_tensor, parameter_lists[i], i_f_g_o, c,
is_bidirec, i, 0);
}
}
}
};
......@@ -221,7 +320,6 @@ class RnnXPUGradKernel : public framework::OpKernel<T> {
int seq_len = input->dims()[0];
int batch_size = input->dims()[1];
int input_dim = input->dims()[2];
PADDLE_ENFORCE_EQ(
init_h->dims()[0], num_layers,
platform::errors::InvalidArgument("The num_layers of in RNN layer must"
......
......@@ -295,6 +295,7 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"rnn", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"roi_align_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# You may obtain a copy of the License at #
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
......@@ -14,6 +12,8 @@
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import numpy as np
import math
......@@ -22,152 +22,180 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import random
import sys
sys.path.append("..")
from op_test import OpTest
from op_test_xpu import XPUOpTest
sys.path.append("../rnn")
from rnn_numpy import SimpleRNN, LSTM, GRU
from convert import get_params_for_net
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
random.seed(2)
np.set_printoptions(threshold=np.inf)
paddle.enable_static()
class TestRNNOp(XPUOpTest):
def init_size(self):
self.seq_length = 1
self.batch_size = 1
self.input_size = 5
self.hidden_size = 16
def get_weight_names(self):
weight_names = []
for i in range(self.num_layers):
for j in range(0, 2 * self.direction_num):
weight_names.append("{}.weight_{}".format(i, j))
for i in range(self.num_layers):
for j in range(0, 2 * self.direction_num):
weight_names.append("{}.bias_{}".format(i, j))
return weight_names
def setUp(self):
self.init_size()
self.op_type = "rnn"
self.dtype = np.float32
self.sequence_length = np.ones(
(self.batch_size, ), dtype=np.int32) * self.seq_length
self.num_layers = 1
self.is_bidirec = False
self.mode = "LSTM"
self.is_test = False
self.dropout = 0.0
self.set_attrs()
self.direction_num = 2 if self.is_bidirec else 1
direction = "bidirectional" if self.is_bidirec else "forward"
input = np.random.uniform(
low=-0.1,
high=0.1,
size=(self.seq_length, self.batch_size,
self.input_size)).astype(self.dtype)
rnn1 = LSTM(
self.input_size,
self.hidden_size,
num_layers=self.num_layers,
time_major=True,
direction=direction,
dropout=self.dropout,
dtype="float32")
flat_w = get_params_for_net(rnn1)
output, (last_hidden, last_cell) = rnn1(
input, sequence_length=self.sequence_length)
init_h = np.zeros(
(self.num_layers * self.direction_num, self.batch_size,
self.hidden_size)).astype(self.dtype)
init_c = np.zeros(
(self.num_layers * self.direction_num, self.batch_size,
self.hidden_size)).astype(self.dtype)
state_out = np.ndarray((300)).astype("uint8")
self.inputs = {
'Input': input,
'WeightList': flat_w,
'PreState': [('init_h', init_h), ('init_c', init_c)],
'SequenceLength': self.sequence_length
}
if self.sequence_length is None:
class XPUTestRNNOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'rnn'
self.use_dynamic_create_class = False
class TestRNNOp(XPUOpTest):
def setUp(self):
self.init_size()
self.init_dtype()
self.op_type = "rnn"
self.place = paddle.XPUPlace(0)
self.sequence_length = np.ones(
(self.batch_size, ), dtype=np.int32) * self.seq_length
self.set_attrs()
self.mode = "LSTM"
self.is_test = False
self.dropout = 0.0
self.direction_num = 2 if self.is_bidirec else 1
direction = "bidirectional" if self.is_bidirec else "forward"
input = np.random.uniform(
low=-0.1,
high=0.1,
size=(self.seq_length, self.batch_size,
self.input_size)).astype(self.dtype)
rnn1 = LSTM(
self.input_size,
self.hidden_size,
num_layers=self.num_layers,
time_major=True,
direction=direction,
dropout=self.dropout,
dtype=self.dtype)
flat_w = get_params_for_net(rnn1)
output, (last_hidden, last_cell) = rnn1(
input, sequence_length=self.sequence_length)
init_h = np.zeros(
(self.num_layers * self.direction_num, self.batch_size,
self.hidden_size)).astype(self.dtype)
init_c = np.zeros(
(self.num_layers * self.direction_num, self.batch_size,
self.hidden_size)).astype(self.dtype)
state_out = np.ndarray((300)).astype("uint8")
self.inputs = {
'Input': input,
'WeightList': flat_w,
'PreState': [('init_h', init_h), ('init_c', init_c)],
'SequenceLength': self.sequence_length
}
if self.sequence_length is None:
self.inputs = {
'Input': input,
'WeightList': flat_w,
'PreState': [('init_h', init_h), ('init_c', init_c)],
}
self.attrs = {
'dropout_prob': self.dropout,
'is_bidirec': self.is_bidirec,
'input_size': self.input_size,
'hidden_size': self.hidden_size,
'num_layers': self.num_layers,
'mode': self.mode,
'is_test': self.is_test
}
self.outputs = {
'Out': output,
"State":
[('last_hidden', last_hidden), ('last_cell', last_cell)],
'Reserve': np.ndarray((400)).astype("uint8"),
'DropoutState': state_out
}
self.attrs = {
'dropout_prob': self.dropout,
'is_bidirec': self.is_bidirec,
'input_size': self.input_size,
'hidden_size': self.hidden_size,
'num_layers': self.num_layers,
'mode': self.mode,
'is_test': self.is_test
}
self.outputs = {
'Out': output,
"State": [('last_hidden', last_hidden), ('last_cell', last_cell)],
'Reserve': np.ndarray((400)).astype("uint8"),
'DropoutState': state_out
}
def test_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_output_with_place(
place, atol=0.01, no_check_set=['Reserve', 'DropoutState'])
def set_attrs(self):
pass
def test_grad(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
if not self.is_test:
var_name_list = self.get_weight_names()
grad_check_list = ['Input', 'init_h', 'init_c']
grad_check_list.extend(var_name_list)
self.check_grad_with_place(
place,
set(grad_check_list), ['Out', 'last_hidden', 'last_cell'],
max_relative_error=0.1)
class TestRNNOpCase0(TestRNNOp):
def init_size(self):
self.seq_length = 2
self.batch_size = 4
self.input_size = 10
self.hidden_size = 32
class TestRNNOpCase1(TestRNNOp):
def init_size(self):
self.seq_length = 5
self.batch_size = 16
self.input_size = 30
self.hidden_size = 64
class TestRNNOpCase2(TestRNNOp):
def init_size(self):
self.seq_length = 10
self.batch_size = 64
self.input_size = 50
self.hidden_size = 64
def init_dtype(self):
self.dtype = self.in_type
def set_xpu(self):
self.__class__.use_xpu = True
self.__class__.no_need_check_grad = True
self.__class__.op_type = self.in_type
def test_check_output(self):
self.check_output_with_place(
self.place, atol=0.01,
no_check_set=['Reserve', 'DropoutState'])
def init_size(self):
self.seq_length = 1
self.batch_size = 1
self.input_size = 5
self.hidden_size = 16
def get_weight_names(self):
weight_names = []
for i in range(self.num_layers):
for j in range(0, 2 * self.direction_num):
weight_names.append("{}.weight_{}".format(i, j))
for i in range(self.num_layers):
for j in range(0, 2 * self.direction_num):
weight_names.append("{}.bias_{}".format(i, j))
return weight_names
def set_attrs(self):
self.num_layers = 1
self.is_bidirec = False
class TestRNNOp1(TestRNNOp):
def init_size(self):
self.seq_length = 2
self.batch_size = 4
self.input_size = 10
self.hidden_size = 32
def set_attrs(self):
self.num_layers = 1
self.is_bidirec = False
class TestRNNOp2(TestRNNOp):
def init_size(self):
self.seq_length = 5
self.batch_size = 16
self.input_size = 30
self.hidden_size = 64
def set_attrs(self):
self.num_layers = 1
self.is_bidirec = True
class TestRNNOp3(TestRNNOp):
def init_size(self):
self.seq_length = 10
self.batch_size = 64
self.input_size = 50
self.hidden_size = 64
def set_attrs(self):
self.num_layers = 2
self.is_bidirec = False
class TestRNNOp4(TestRNNOp):
def set_attrs(self):
self.num_layers = 3
self.is_bidirec = False
class TestRNNOp5(TestRNNOp):
def set_attrs(self):
self.num_layers = 2
self.is_bidirec = True
support_types = get_xpu_op_support_types('rnn')
for stype in support_types:
create_test_class(
globals(),
XPUTestRNNOp,
stype,
ignore_deivce_version=[core.XPUVersion.XPU1])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册