Unverified  Commit 36bb056e authored by GaoWei8 and committed by GitHub

Add flatten weight of lstm (#27192)

* add flatten weight of lstm
Parent 7779790c
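This commit replaces the single flattened weight input `W` of the cudnn_lstm op with a dispensable `WeightList` of per-layer weight and bias tensors; the GPU kernel checks whether the list is already contiguous in memory and, if not, copies it into one flat buffer before calling cuDNN, while the grad kernel slices the flat gradient buffer back into per-parameter tensors. A minimal NumPy sketch of that flatten/split round trip (the helper names below are made up for illustration, not Paddle APIs):

import numpy as np

def flatten_weight_list(weight_list):
    # mirrors size_sum + weight_to_tensor: one flat buffer, filled piecewise
    total = sum(w.size for w in weight_list)
    flat = np.empty(total, dtype=weight_list[0].dtype)
    offset = 0
    for w in weight_list:
        flat[offset:offset + w.size] = w.ravel()
        offset += w.size
    return flat

def split_flat_weight(flat, weight_list):
    # mirrors the Slice/ShareDataWith loop in the grad kernel
    views, offset = [], 0
    for w in weight_list:
        views.append(flat[offset:offset + w.size].reshape(w.shape))
        offset += w.size
    return views

# usage with hypothetical single-layer shapes (hidden_size=4, input_size=3)
hidden, in_sz = 4, 3
wl = [np.random.rand(4 * hidden, in_sz), np.random.rand(4 * hidden, hidden),
      np.random.rand(4 * hidden), np.random.rand(4 * hidden)]
flat = flatten_weight_list(wl)
assert all((a == b).all() for a, b in zip(split_flat_weight(flat, wl), wl))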
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
......@@ -25,7 +26,6 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM");
......@@ -122,7 +122,13 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("W",
"(Tensor) the learnable hidden-hidden weights."
" The shape is (N), where N is total weight size of the LSTM. "
" cudnn concatenate all the weight to one Tensor");
" cudnn concatenate all the weight to one Tensor")
.AsDispensable();
AddInput("WeightList",
"(vector<Tensor>), stores weight and bias data when the weight "
"use the list format. ")
.AsDispensable()
.AsDuplicable();
AddInput("SequenceLength",
"(Tensor) When the input data is padding, "
"set this parameter. This parameter represents "
......@@ -216,7 +222,6 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad");
......@@ -228,7 +233,10 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel {
};
SetOutGradDim("Input");
SetOutGradDim("W");
if (ctx->HasInputs("WeightList")) {
ctx->SetOutputsDim(framework::GradVarName("WeightList"),
ctx->GetInputsDim("WeightList"));
}
SetOutGradDim("InitH");
SetOutGradDim("InitC");
}
......@@ -251,7 +259,9 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("Input", this->Input("Input"));
op->SetInput("InitH", this->Input("InitH"));
op->SetInput("InitC", this->Input("InitC"));
op->SetInput("W", this->Input("W"));
if (this->HasInput("WeightList")) {
op->SetInput("WeightList", this->Input("WeightList"));
}
if (this->HasInput("SequenceLength")) {
op->SetInput("SequenceLength", this->Input("SequenceLength"));
}
......@@ -262,8 +272,12 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC"));
op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH"));
if (this->HasInput("WeightList")) {
op->SetOutput(framework::GradVarName("WeightList"),
this->InputGrad("WeightList", false));
}
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
op->SetOutput(framework::GradVarName("W"), this->InputGrad("W"));
op->SetOutput(framework::GradVarName("InitH"), this->InputGrad("InitH"));
op->SetOutput(framework::GradVarName("InitC"), this->InputGrad("InitC"));
op->SetAttrMap(this->Attrs());
......@@ -290,3 +304,20 @@ REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);
REGISTER_OP_CPU_KERNEL(cudnn_lstm, ops::NotImpleKernel<float>);
REGISTER_OP_CPU_KERNEL(cudnn_lstm_grad, ops::NotImpleKernel<float>);
// TODO(Shixiaowei02) Add ModifyInput support
REGISTER_OP_VERSION(cudnn_lstm)
.AddCheckpoint(
R"ROC(
Upgrade cudnn_lstm: add a new input [WeightList] and make the input [W] dispensable.)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewInput(
"WeightList",
"The WeightList stores weight and bias data. WeightList is "
"dispensable.")
.NewInput("SequenceLength",
"When the input data is padding, set this parameter. "
"SequenceLength is dispensable.")
.NewOutput("StateOut", "Store the global drop state when training")
.NewOutput("Reserve",
"A temporary output Tensor to store the reserve_data"));
......@@ -30,6 +30,66 @@ namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename T, typename Type>
bool is_continuous(const Type &weight_list) {
bool continuous = true;
for (size_t i = 0; i < weight_list.size() - 1; ++i) {
auto *in_data = weight_list[i]->template data<T>();
auto *in_after_data = weight_list[i + 1]->template data<T>();
auto in_size = weight_list[i]->numel();
bool temp = in_data + in_size == in_after_data;
continuous = continuous && temp;
}
return continuous;
}
int size_sum(const std::vector<const Tensor *> &weight_list) {
int size = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
auto in_size = weight_list[i]->numel();
size += in_size;
}
return size;
}
template <typename T>
void weight_to_tensor(const platform::Place &place, cudaStream_t stream,
const std::vector<const Tensor *> &weight_list,
Tensor *weight) {
auto weight_data = weight->data<T>();
int weight_offset = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
const T *in_data = weight_list[i]->data<T>();
auto in_size = weight_list[i]->numel();
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, weight->place()),
weight_data + weight_offset,
BOOST_GET_CONST(platform::CUDAPlace, weight_list[i]->place()),
in_data, in_size * sizeof(T), stream);
weight_offset += in_size;
}
}
template <typename T>
void weight_to_tensor_list(const platform::Place &place, cudaStream_t stream,
std::vector<Tensor *> *weight_grad,
const std::vector<const Tensor *> &weight_input,
const Tensor *weight) {
int weight_offset = 0;
auto *weight_data = weight->data<T>();
for (size_t i = 0; i < weight_input.size(); ++i) {
auto in_size = weight_input[i]->numel();
T *weight_grad_data = (*weight_grad)[i]->mutable_data<T>(place);
const T *src = weight_data + weight_offset;
memory::Copy(
BOOST_GET_CONST(platform::CUDAPlace, (*weight_grad)[i]->place()),
weight_grad_data, BOOST_GET_CONST(platform::CUDAPlace, weight->place()),
src, in_size * sizeof(T), stream);
weight_offset += in_size;
}
}
template <typename T>
void LSTMInferece(const bool &has_seq_length, const cudnnHandle_t &handle,
const int &seq_length, ScopedRNNBase *rnn, const T *x_data,
......@@ -75,8 +135,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
const Tensor *init_h = ctx.Input<Tensor>("InitH");
const Tensor *init_c = ctx.Input<Tensor>("InitC");
auto w = ctx.Input<Tensor>("W");
Tensor *out = ctx.Output<Tensor>("Out");
Tensor *last_h = ctx.Output<Tensor>("LastH");
Tensor *last_c = ctx.Output<Tensor>("LastC");
......@@ -87,8 +145,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
const T *init_h_data = init_h->data<T>();
const T *init_c_data = init_c->data<T>();
const T *w_data = w->data<T>();
T *out_data = out->mutable_data<T>(ctx.GetPlace());
T *last_h_data = last_h->mutable_data<T>(ctx.GetPlace());
T *last_c_data = last_c->mutable_data<T>(ctx.GetPlace());
......@@ -113,11 +169,45 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
int seq_length = x->dims()[0];
int batch_size = x->dims()[1];
int input_size = x->dims()[2];
int weight_numel = w->numel();
bool state_initialized = state_out->IsInitialized() ? true : false;
size_t workspace_size;
size_t reserve_size;
Tensor weight_whole;
T *w_data = nullptr;
int weight_numel;
bool w_initialized = false;
auto place = ctx.GetPlace();
auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
ctx.device_context())
.stream();
if (is_test && ctx.HasInput("W")) {
auto *W = ctx.Input<Tensor>("W");
w_initialized = W->IsInitialized() ? true : false;
weight_numel = W->numel();
}
if (!w_initialized) {
auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
bool continuous =
is_continuous<T, std::vector<const Tensor *>>(weight_list);
weight_numel = size_sum(weight_list);
if (!continuous) {
LOG_FIRST_N(WARNING, 2)
<< "If the memory space of the Input WeightList is not "
"continuous, less efficient calculation will be "
"called. Please call coalesce_tensor op to make the "
"input memory continuous.";
weight_whole.mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
w_data = weight_whole.data<T>();
} else {
w_data = const_cast<T *>(weight_list[0]->data<T>());
}
} else {
auto *W = ctx.Input<Tensor>("W");
w_data = const_cast<T *>(W->data<T>());
}
ScopedRNNBase rnn(seq_length, batch_size, input_size, hidden_size,
num_layers, dropout_prob, seed, weight_numel,
......@@ -136,6 +226,12 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
LSTMInferece<T>(has_seq_length, handle, seq_length, &rnn, x_data,
init_h_data, init_c_data, w_data, out_data, last_h_data,
last_c_data, &workspace_data_, workspace_size);
if (!w_initialized && ctx.HasInput("W") && ctx.HasInput("WeightList")) {
auto *W = const_cast<Tensor *>(ctx.Input<Tensor>("W"));
auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
W->mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, W);
}
} else {
if (!has_seq_length) {
// for train
......@@ -176,11 +272,11 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *input = ctx.Input<Tensor>("Input");
auto *weight = ctx.Input<Tensor>("W");
auto *init_h = ctx.Input<Tensor>("InitH");
auto *init_c = ctx.Input<Tensor>("InitC");
auto *reserve = ctx.Input<Tensor>("Reserve");
auto *state_out = ctx.Input<Tensor>("StateOut");
auto weight_list = ctx.MultiInput<Tensor>("WeightList");
auto *out = ctx.Input<Tensor>("Out");
auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
......@@ -188,9 +284,10 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
auto *last_c_grad = ctx.Input<Tensor>(framework::GradVarName("LastC"));
auto *in_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto *weight_grad = ctx.Output<Tensor>(framework::GradVarName("W"));
auto *init_h_grad = ctx.Output<Tensor>(framework::GradVarName("InitH"));
auto *init_c_grad = ctx.Output<Tensor>(framework::GradVarName("InitC"));
auto weight_grad_list = ctx.MultiOutput<framework::Tensor>(
framework::GradVarName("WeightList"));
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
......@@ -199,7 +296,6 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
auto init_h_dims = init_h->dims();
auto init_c_dims = init_c->dims();
auto *weight_data = weight->data<T>();
auto *init_h_data = init_h->data<T>();
auto *init_c_data = init_c->data<T>();
auto *out_data = out->data<T>();
......@@ -207,18 +303,50 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
auto *last_h_grad_data = last_h_grad->data<T>();
auto *last_c_grad_data = last_c_grad->data<T>();
auto place = ctx.GetPlace();
int weight_numel = size_sum(weight_list);
bool continuous =
is_continuous<T, std::vector<const Tensor *>>(weight_list);
auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
ctx.device_context())
.stream();
Tensor weight_whole;
T *weight_data = nullptr;
if (!continuous) {
weight_whole.mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
weight_data = weight_whole.data<T>();
} else {
weight_data = const_cast<T *>(weight_list[0]->data<T>());
}
Tensor weight_grad;
math::SetConstant<paddle::platform::CUDADeviceContext, T> zero;
weight_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, weight_grad, static_cast<T>(0.0));
weight_grad.mutable_data<T>({weight_numel}, ctx.GetPlace());
zero(dev_ctx, &weight_grad, static_cast<T>(0.0));
T *weight_grad_data = weight_grad.data<T>();
int offset = 0;
for (size_t i = 0; i < weight_grad_list.size(); ++i) {
size_t len = weight_grad_list[i]->numel();
auto dim = weight_grad_list[i]->dims();
weight_grad_list[i]
->ShareDataWith(weight_grad.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len)))
.Resize(dim);
offset += len;
}
in_grad->mutable_data<T>(input_dims, ctx.GetPlace());
auto *in_grad_data = in_grad->data<T>();
init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
auto *init_h_grad_data = init_h_grad->data<T>();
if (init_h_grad) init_h_grad->mutable_data<T>(init_h_dims, ctx.GetPlace());
auto *init_h_grad_data = init_h_grad ? init_h_grad->data<T>() : nullptr;
init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
auto *init_c_grad_data = init_c_grad->data<T>();
if (init_c_grad) init_c_grad->mutable_data<T>(init_c_dims, ctx.GetPlace());
auto *init_c_grad_data = init_c_grad ? init_c_grad->data<T>() : nullptr;
float dropout_prob = ctx.Attr<float>("dropout_prob");
bool is_bidirec = ctx.Attr<bool>("is_bidirec");
......@@ -236,7 +364,6 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
int seq_length = input_dims[0];
int batch_size = input->dims()[1];
int input_size = input->dims()[2];
int weight_numel = weight->numel();
size_t workspace_size;
size_t reserve_size;
......@@ -268,8 +395,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
handle, rnn.rnn_desc(), seq_length, rnn.x_descs(), input->data<T>(),
rnn.init_h_desc(), init_h->data<T>(), rnn.y_descs(), out->data<T>(),
workspace_data_.data<uint8_t>(), workspace_size, rnn.weight_desc(),
weight_grad->data<T>(), const_cast<uint8_t *>(reserve_data),
reserve_size));
weight_grad_data, const_cast<uint8_t *>(reserve_data), reserve_size));
} else {
#if CUDNN_VERSION >= 7201
// for train
......@@ -288,7 +414,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
handle, rnn.rnn_desc(), rnn.x_seq_desc(), input->data<T>(),
rnn.init_h_desc(), init_h->data<T>(), rnn.y_seq_desc(),
out->data<T>(), workspace_data_.data<uint8_t>(), workspace_size,
rnn.weight_desc(), weight_grad->data<T>(),
rnn.weight_desc(), weight_grad_data,
const_cast<uint8_t *>(reserve_data), reserve_size));
#else
PADDLE_THROW(platform::errors::Unavailable(
......
......@@ -2443,23 +2443,17 @@ def lstm(input,
input_shape = list(input.shape)
input_size = input_shape[-1]
weight_size = 0
num_dirrection = 2 if is_bidirec == True else 1
for i in range(num_layers):
if i == 0:
input_weight_size = (input_size * hidden_size) * 4
input_weight_size = (input_size * hidden_size) * 4 * num_dirrection
else:
if is_bidirec:
input_weight_size = (hidden_size * 2 * hidden_size) * 4
else:
input_weight_size = (hidden_size * hidden_size) * 4
input_weight_size = (hidden_size * hidden_size) * 4 * num_dirrection
hidden_weight_size = (hidden_size * hidden_size) * 4 * num_dirrection
hidden_weight_size = (hidden_size * hidden_size) * 4
if is_bidirec:
weight_size += (input_weight_size + hidden_weight_size) * 2
weight_size += hidden_size * 8 * 2
else:
weight_size += input_weight_size + hidden_weight_size
weight_size += hidden_size * 8
weight_size += input_weight_size + hidden_weight_size
weight_size += hidden_size * 8 * num_dirrection
weight = helper.create_parameter(
attr=helper.param_attr,
......
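As a quick sanity check of the weight_size formula above, here is a hedged worked example using the shapes from TestCUDNNLstmOp further down (input_size = hidden_size = 21, num_layers = 1, uni-directional), which works out to the same 3696 elements the old hard-coded test computed:

# assumed shapes: input_size == hidden_size == 21, num_layers == 1, forward only
hidden_size, input_size, num_direction, num_layers = 21, 21, 1, 1
weight_size = 0
for i in range(num_layers):
    input_weight_size = (input_size * hidden_size) * 4 * num_direction
    hidden_weight_size = (hidden_size * hidden_size) * 4 * num_direction
    weight_size += input_weight_size + hidden_weight_size
    weight_size += hidden_size * 8 * num_direction
assert weight_size == 3696  # 1764 + 1764 + 168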
......@@ -20,14 +20,44 @@ import math
import paddle.fluid.core as core
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import random
random.seed(2)
np.set_printoptions(threshold=np.inf)
paddle.enable_static()
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
class RandomWeight:
def __init__(self):
pass
def updata_weight(self, hidden_size, input_size, dtype):
std = 1.0 / math.sqrt(hidden_size)
self.hidden_size = hidden_size
self.input_size = input_size
self.dtype = dtype
self.weight_ih = np.random.uniform(
low=-std, high=std, size=(4 * self.hidden_size,
self.input_size)).astype(dtype)
self.weight_hh = np.random.uniform(
low=-std, high=std, size=(4 * self.hidden_size,
self.hidden_size)).astype(dtype)
self.bias_ih = np.random.uniform(
low=-std, high=std, size=(4 * self.hidden_size)).astype(dtype)
self.bias_hh = np.random.uniform(
low=-std, high=std, size=(4 * self.hidden_size)).astype(dtype)
weight = RandomWeight()
class LayerMixin(object):
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
......@@ -51,16 +81,13 @@ class LSTMCell(LayerMixin):
self.bias = bias
self.dtype = np.float64
self.parameters = dict()
std = 1.0 / math.sqrt(hidden_size)
self.weight_ih = np.ones(
(4 * hidden_size, input_size), dtype=self.dtype)
self.weight_hh = np.ones((4 * hidden_size,
hidden_size)).astype(self.dtype)
self.weight_ih = weight.weight_ih
self.weight_hh = weight.weight_hh
self.parameters['weight_ih'] = self.weight_ih
self.parameters['weight_hh'] = self.weight_hh
if bias:
self.bias_ih = np.ones((4 * hidden_size)).astype(self.dtype)
self.bias_hh = np.ones((4 * hidden_size)).astype(self.dtype)
self.bias_ih = weight.bias_ih
self.bias_hh = weight.bias_hh
self.parameters['bias_ih'] = self.bias_ih
self.parameters['bias_hh'] = self.bias_hh
else:
......@@ -353,24 +380,26 @@ class LSTM(RNNMixin):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNLstmOp(OpTest):
#TODO(GaoWei8): Need to satisfy the result through the new interface
def get_weight_names(self):
weight_names = []
for i in range(2 * self.num_layers):
weight_names.append('weight{}'.format(i))
for i in range(2 * self.num_layers):
weight_names.append('bias{}'.format(i))
return weight_names
def setUp(self):
self.op_type = "cudnn_lstm"
self.dtype = np.float64
self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32)
self.num_layers = 1
self.set_attrs()
seq_length = 12
batch_size = 5
input_size = 21
hidden_size = 21
input_weight_size = (hidden_size * hidden_size) * 4
hidden_weight_size = (hidden_size * hidden_size) * 4
weight_size = input_weight_size + hidden_weight_size
weight_size += hidden_size * 8
weight_size *= self.num_layers
input = np.random.uniform(
low=-0.1, high=0.1,
size=(seq_length, batch_size, input_size)).astype(self.dtype)
......@@ -379,17 +408,39 @@ class TestCUDNNLstmOp(OpTest):
input[9][3:][:] = 0
input[8][4:][:] = 0
weight.updata_weight(hidden_size, input_size, self.dtype)
rnn1 = LSTM(
input_size,
hidden_size,
self.num_layers,
num_layers=self.num_layers,
time_major=True,
direction="forward")
output, (last_hidden, last_cell) = rnn1(
input, sequence_length=self.sequence_length)
flat_w = np.ones((weight_size)).astype(self.dtype)
flat_w = []
num = 0
for i in range(self.num_layers):
if i == 0:
weight_ih = weight.weight_ih
else:
weight_ih = weight.weight_hh
flat_w.append(("weight" + str(num), weight_ih))
num += 1
for i in range(self.num_layers):
weight_hh = weight.weight_hh
flat_w.append(("weight" + str(num), weight_hh))
num += 1
num = 0
for i in range(self.num_layers):
bias_ih = weight.bias_ih
flat_w.append(("bias" + str(num), bias_ih))
num += 1
for i in range(self.num_layers):
bias_hh = weight.bias_hh
flat_w.append(("bias" + str(num), bias_hh))
num += 1
init_h = np.zeros((self.num_layers, batch_size,
hidden_size)).astype(self.dtype)
init_c = np.zeros((self.num_layers, batch_size,
......@@ -398,7 +449,7 @@ class TestCUDNNLstmOp(OpTest):
self.inputs = {
'Input': input,
'W': flat_w,
'WeightList': flat_w,
'InitH': init_h,
'InitC': init_c,
'SequenceLength': self.sequence_length
......@@ -408,7 +459,7 @@ class TestCUDNNLstmOp(OpTest):
'is_bidirec': False,
'input_size': input_size,
'hidden_size': hidden_size,
'num_layers': 1,
'num_layers': self.num_layers,
}
self.outputs = {
'Out': output,
......@@ -428,16 +479,42 @@ class TestCUDNNLstmOp(OpTest):
def test_grad_with_place(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place,
set(['Input', 'W', 'InitH', 'InitC']),
['Out', 'LastH', 'LastC'])
var_name_list = self.get_weight_names()
for var_name in var_name_list:
self.check_grad_with_place(
place,
set(['Input', var_name, 'InitH', 'InitC']),
['Out', 'LastH', 'LastC'])
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestCUDNNLstmOp2(TestCUDNNLstmOp):
def set_attrs(self):
self.num_layers = 2
class TestCUDNNlstmAPI(unittest.TestCase):
def test_lstm(self):
seq_len = 20
batch_size = 5
hidden_size = 20
dropout_prob = 0.0
num_layers = 1
input = fluid.data(
name='input',
shape=[seq_len, batch_size, hidden_size],
dtype='float64')
init_h = layers.fill_constant([num_layers, batch_size, hidden_size],
'float64', 0.0)
init_c = layers.fill_constant([num_layers, batch_size, hidden_size],
'float64', 0.0)
rnn_out, last_h, last_c = layers.lstm(input, init_h, init_c, seq_len,
hidden_size, num_layers,
dropout_prob, False)
exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
input_i = np.random.uniform(
low=-0.1, high=0.1, size=(seq_len, batch_size,
hidden_size)).astype("float64")
out = exe.run(fluid.default_main_program(),
feed={'input': input_i},
fetch_list=[rnn_out, last_h, last_c, 'cudnn_lstm_0.w_0'])
@unittest.skipIf(not core.is_compiled_with_cuda(),
......@@ -448,7 +525,7 @@ class TestCUDNNlstmAPI(unittest.TestCase):
batch_size = 5
hidden_size = 20
dropout_prob = 0.0
num_layers = 1
num_layers = 2
input = fluid.data(
name='input',
shape=[seq_len, batch_size, hidden_size],
......