未验证 提交 fa9d3fa5 编写于 作者: G Guo Sheng 提交者: GitHub

Incorporate cudnn_lstm into LSTM api (#27217)

* Incorporate cudnn_lstm into LSTM api.
test=develop

* Make coalesce_tensor support alignment optionally.
test=develop

* Reorganize RNN apis. test=develop

* Fix cudnn rnn layout conversion.
test=develop

* Add sequence_length support for RNN cudnn implement.
Add optional init_h and init_c gradient for cudnn_lstm_op.
test=develop

* Use create_parameter for rnn cudnn impl.
test=develop

* Move `self._flat_weight = self.create_parameter()` in RNNBase to main_program.
test=develop

* Update RNN api unittest to use set_device.
test=develop

* Fix set_place for unit tests of RNN apis.
test=develop

* Fix use_align in coalesce_tensor_op.
test=develop

* Adjust RNN apis arguments according to comments.
test=develop

* Polish documents for SimpleRNN apis.
test=develop

* Refine random seed in cudnn_lstm_op.
Expose rnn params from sublayers to RNN.
test=develop

* Fix RNN saving for jit.save.
Refine cudnn_lstm dropout behavior.
test=develop

* Fix doc of GRU. test=develop

* Use ShareDataWith to avoid copying for cudnn_lstm_op test.
test=develop

* Remove updates on cudnn_lstm temporarily.
test=develop

* Use ShareDataWith to avoid copying for cudnn_lstm_op test.
test=develop

* Refine random seed in cudnn_lstm_op.
test=develop

* Fix test_lstm by adjust ConcreteProgram buffer getter.
test=develop

* Use create_parameter instead of create_var for rnn._flat_weight for static graph usage.
test=develop

* Remove W input for cudnn_lstm to pass unused_var_check.
test=develop

* Add test_predict for RNN unit tests coverage.
test=develop

* Fix code style of rnn.
test=develop

* Fix F.rnn usage in rnn.py.
test=develop
上级 78b1026f
...@@ -67,6 +67,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> { ...@@ -67,6 +67,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
} }
auto in_tensors = context.MultiInput<framework::LoDTensor>("Input"); auto in_tensors = context.MultiInput<framework::LoDTensor>("Input");
bool use_align = context.Attr<bool>("use_align");
if (context.Attr<bool>("check_name")) { if (context.Attr<bool>("check_name")) {
for (size_t i = 0; i < in_var_names.size(); ++i) { for (size_t i = 0; i < in_var_names.size(); ++i) {
...@@ -93,7 +94,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> { ...@@ -93,7 +94,7 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
context.Attr<int>("dtype")); context.Attr<int>("dtype"));
size_t size_of_dtype = framework::SizeOfType(dtype); size_t size_of_dtype = framework::SizeOfType(dtype);
GetMemSizeAndDtype(in_tensors, in_var_names, &numel, size_of_dtype, GetMemSizeAndDtype(in_tensors, in_var_names, &numel, size_of_dtype,
context.GetPlace()); context.GetPlace(), use_align);
// Alloc the continuous space // Alloc the continuous space
auto fused_tensor = context.Output<framework::LoDTensor>("FusedOutput"); auto fused_tensor = context.Output<framework::LoDTensor>("FusedOutput");
...@@ -111,8 +112,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> { ...@@ -111,8 +112,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
framework::TensorCopy(*in_tensors[i], context.GetPlace(), dev_ctx, framework::TensorCopy(*in_tensors[i], context.GetPlace(), dev_ctx,
&sub_tensor); &sub_tensor);
offset += platform::Alignment(len * size_of_dtype, context.GetPlace()) / offset +=
size_of_dtype; use_align
? platform::Alignment(len * size_of_dtype, context.GetPlace()) /
size_of_dtype
: len;
} }
} else if (context.Attr<bool>("set_constant")) { } else if (context.Attr<bool>("set_constant")) {
math::SetConstant<DeviceContext, T> set_constant; math::SetConstant<DeviceContext, T> set_constant;
...@@ -131,8 +135,10 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> { ...@@ -131,8 +135,10 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
->ShareDataWith(fused_tensor->Slice( ->ShareDataWith(fused_tensor->Slice(
static_cast<int64_t>(offset), static_cast<int64_t>(offset + len))) static_cast<int64_t>(offset), static_cast<int64_t>(offset + len)))
.Resize(dim); .Resize(dim);
len = platform::Alignment(len * size_of_dtype, context.GetPlace()) / len = use_align
size_of_dtype; ? platform::Alignment(len * size_of_dtype, context.GetPlace()) /
size_of_dtype
: len;
offset += len; offset += len;
ss << "output(" << out_var_names[i] << ") dim:(" << dim << ")" ss << "output(" << out_var_names[i] << ") dim:(" << dim << ")"
<< " address: " << out_tensors[i]->data<void>() << ", "; << " address: " << out_tensors[i]->data<void>() << ", ";
...@@ -144,7 +150,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> { ...@@ -144,7 +150,8 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
void GetMemSizeAndDtype( void GetMemSizeAndDtype(
const std::vector<const framework::LoDTensor *> &lod_tensors, const std::vector<const framework::LoDTensor *> &lod_tensors,
const std::vector<std::string> var_names, size_t *numel, const std::vector<std::string> var_names, size_t *numel,
const size_t &size_of_dtype, const platform::Place &place) const { const size_t &size_of_dtype, const platform::Place &place,
const bool use_align = true) const {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
lod_tensors.size(), var_names.size(), lod_tensors.size(), var_names.size(),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -167,9 +174,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> { ...@@ -167,9 +174,11 @@ class CoalesceTensorOpKernel : public framework::OpKernel<T> {
ss << "input(" << var_names[i] << ") dim:(" << lod_tensors[i]->dims() ss << "input(" << var_names[i] << ") dim:(" << lod_tensors[i]->dims()
<< ") " << ") "
<< " addres:" << lod_tensors[i]->data<void>() << ", "; << " addres:" << lod_tensors[i]->data<void>() << ", ";
*numel += platform::Alignment(static_cast<size_t>(size) * size_of_dtype, *numel += use_align
place) / ? platform::Alignment(
size_of_dtype; static_cast<size_t>(size) * size_of_dtype, place) /
size_of_dtype
: static_cast<size_t>(size);
} }
VLOG(10) << ss.str(); VLOG(10) << ss.str();
...@@ -223,6 +232,10 @@ class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -223,6 +232,10 @@ class CoalesceTensorOpMaker : public framework::OpProtoAndCheckerMaker {
"Whether to check the name of Input and Output to ensure " "Whether to check the name of Input and Output to ensure "
"they are the same separately.") "they are the same separately.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("use_align",
"Whether to consider memory chunk and take alignment into "
"account for inputs and outputs.")
.SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
CoalesceTensor Operator. CoalesceTensor Operator.
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/cudnn_lstm_cache.h" #include "paddle/fluid/operators/cudnn_lstm_cache.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -156,6 +157,21 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> { ...@@ -156,6 +157,21 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
bool is_test = ctx.Attr<bool>("is_test"); bool is_test = ctx.Attr<bool>("is_test");
int seed = ctx.Attr<int>("seed"); int seed = ctx.Attr<int>("seed");
if (!is_test) {
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && seed == 0) {
// If perform `manual_seed` in python and inner seed is not specified
// (equals 0), use global generator generated seed.
seed = static_cast<int>(gen_cuda->Random64());
} else if (seed == 0) {
// use random generated seed
std::random_device rd;
seed = rd();
} // else use `ctx.Attr<int>("seed")` specified seed
}
bool has_seq_length = ctx.HasInput("SequenceLength"); bool has_seq_length = ctx.HasInput("SequenceLength");
std::vector<int> SequenceLength; std::vector<int> SequenceLength;
if (has_seq_length) { if (has_seq_length) {
...@@ -194,13 +210,25 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> { ...@@ -194,13 +210,25 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
if (!continuous) { if (!continuous) {
LOG_FIRST_N(WARNING, 2) LOG_FIRST_N(WARNING, 2)
<< "If the memory space of the Input WeightList is not " << "If the memory space of the Input WeightList is not continuous, "
"continuous, less efficient calculation will be " "less efficient calculation will be called. Please call "
"called. Please call coalesce_tensor op to make the " "flatten_parameters() to make the input memory continuous.";
"input memory continuous.";
weight_whole.mutable_data<T>({weight_numel}, place); weight_whole.mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, &weight_whole); weight_to_tensor<T>(place, stream, weight_list, &weight_whole);
w_data = weight_whole.data<T>(); w_data = weight_whole.data<T>();
if (is_test) { // maybe also reset small weights' ptr for training
int offset = 0;
for (size_t i = 0; i < weight_list.size(); ++i) {
size_t len = weight_list[i]->numel();
auto dim = weight_list[i]->dims();
const_cast<Tensor *>(weight_list[i])
->ShareDataWith(
weight_whole.Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len)))
.Resize(dim);
offset += len;
}
}
} else { } else {
w_data = const_cast<T *>(weight_list[0]->data<T>()); w_data = const_cast<T *>(weight_list[0]->data<T>());
} }
...@@ -226,12 +254,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> { ...@@ -226,12 +254,6 @@ class CudnnLSTMGPUKernel : public framework::OpKernel<T> {
LSTMInferece<T>(has_seq_length, handle, seq_length, &rnn, x_data, LSTMInferece<T>(has_seq_length, handle, seq_length, &rnn, x_data,
init_h_data, init_c_data, w_data, out_data, last_h_data, init_h_data, init_c_data, w_data, out_data, last_h_data,
last_c_data, &workspace_data_, workspace_size); last_c_data, &workspace_data_, workspace_size);
if (!w_initialized && ctx.HasInput("W") && ctx.HasInput("WeightList")) {
auto *W = const_cast<Tensor *>(ctx.Input<Tensor>("W"));
auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
W->mutable_data<T>({weight_numel}, place);
weight_to_tensor<T>(place, stream, weight_list, W);
}
} else { } else {
if (!has_seq_length) { if (!has_seq_length) {
// for train // for train
......
...@@ -89,6 +89,7 @@ REGISTER_OP_CPU_KERNEL( ...@@ -89,6 +89,7 @@ REGISTER_OP_CPU_KERNEL(
save, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, float>, save, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int16_t>, ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
...@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL(
save, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, float>, save, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, double>, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int>, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int8_t>, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int64_t>, ops::SaveOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SaveOpKernel<paddle::platform::CUDADeviceContext, ops::SaveOpKernel<paddle::platform::CUDADeviceContext,
......
...@@ -592,9 +592,8 @@ class ConcreteProgram(object): ...@@ -592,9 +592,8 @@ class ConcreteProgram(object):
inputs = tuple([class_instance] + list(inputs)) inputs = tuple([class_instance] + list(inputs))
# 2. Gets all ParamBases and buffered VarBases in the function # 2. Gets all ParamBases and buffered VarBases in the function
all_parameters_and_buffers = list( all_parameters_and_buffers = _extract_indeed_params_buffers(
get_parameters(class_instance).values()) + list( class_instance)
get_buffers(class_instance).values())
# 3. Builds program only once and returns the output Variables. # 3. Builds program only once and returns the output Variables.
with param_guard(get_parameters( with param_guard(get_parameters(
...@@ -622,6 +621,17 @@ class ConcreteProgram(object): ...@@ -622,6 +621,17 @@ class ConcreteProgram(object):
startup_program=startup_program) startup_program=startup_program)
def _extract_indeed_params_buffers(class_instance):
"""
To filter not initialzed buffers.
"""
params = list(get_parameters(class_instance).values())
buffers = list(get_buffers(class_instance).values())
buffers = [buffer for buffer in buffers if buffer.shape != []]
return params + buffers
class ProgramCache(object): class ProgramCache(object):
""" """
Wrapper class for the program functions defined by dygraph function. Wrapper class for the program functions defined by dygraph function.
......
...@@ -29,11 +29,13 @@ class TestSimpleRNN(unittest.TestCase): ...@@ -29,11 +29,13 @@ class TestSimpleRNN(unittest.TestCase):
self.time_major = time_major self.time_major = time_major
self.direction = direction self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1 self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \ self.place = place
else paddle.CUDAPlace(0)
def setUp(self): def setUp(self):
paddle.disable_static(self.place) # Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
paddle.disable_static(place)
rnn1 = SimpleRNN( rnn1 = SimpleRNN(
16, 32, 2, time_major=self.time_major, direction=self.direction) 16, 32, 2, time_major=self.time_major, direction=self.direction)
rnn2 = paddle.nn.SimpleRNN( rnn2 = paddle.nn.SimpleRNN(
...@@ -103,11 +105,13 @@ class TestGRU(unittest.TestCase): ...@@ -103,11 +105,13 @@ class TestGRU(unittest.TestCase):
self.time_major = time_major self.time_major = time_major
self.direction = direction self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1 self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \ self.place = place
else paddle.CUDAPlace(0)
def setUp(self): def setUp(self):
paddle.disable_static(self.place) # Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
paddle.disable_static(place)
rnn1 = GRU(16, rnn1 = GRU(16,
32, 32,
2, 2,
...@@ -183,11 +187,13 @@ class TestLSTM(unittest.TestCase): ...@@ -183,11 +187,13 @@ class TestLSTM(unittest.TestCase):
self.time_major = time_major self.time_major = time_major
self.direction = direction self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1 self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \ self.place = place
else paddle.CUDAPlace(0)
def setUp(self): def setUp(self):
paddle.disable_static(self.place) # Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
paddle.disable_static(place)
rnn1 = LSTM( rnn1 = LSTM(
16, 32, 2, time_major=self.time_major, direction=self.direction) 16, 32, 2, time_major=self.time_major, direction=self.direction)
rnn2 = paddle.nn.LSTM( rnn2 = paddle.nn.LSTM(
...@@ -251,10 +257,68 @@ class TestLSTM(unittest.TestCase): ...@@ -251,10 +257,68 @@ class TestLSTM(unittest.TestCase):
np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5) np.testing.assert_allclose(h1, h2.numpy(), atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5) np.testing.assert_allclose(c1, c2.numpy(), atol=1e-8, rtol=1e-5)
def test_predict(self):
place = paddle.set_device(self.place)
paddle.manual_seed(123)
np.random.seed(123)
class Net(paddle.nn.Layer):
def __init__(self):
super(Net, self).__init__()
self.rnn1 = paddle.nn.LSTM(
16, 32, 2, direction="bidirectional", dropout=0.1)
def forward(self, input):
return self.rnn1(input)
x = paddle.randn((4, 10, 16))
x.stop_gradient = False
seq_len = paddle.to_tensor(np.array([10, 6, 8, 5]))
mask = sequence_mask(seq_len, maxlen=10, dtype=x.dtype)
mask = paddle.unsqueeze(mask, [2])
rnn = Net()
y, (h, c) = rnn(x)
y = y * mask
loss = paddle.mean(y)
loss.backward()
optimizer = paddle.optimizer.Adam(
learning_rate=0.1, parameters=rnn.parameters())
optimizer.step()
rnn.eval()
y, (h, c) = rnn(x)
# `jit.to_static` would include a train_program, eval mode might cause
# some errors currently, such as dropout grad op gets `is_test == True`.
rnn.train()
rnn = paddle.jit.to_static(
rnn,
[paddle.static.InputSpec(
shape=[None, None, 16], dtype=x.dtype)])
paddle.jit.save(rnn, "./inference/lstm_infer")
paddle.enable_static()
new_scope = paddle.static.Scope()
with paddle.static.scope_guard(new_scope):
exe = paddle.static.Executor(place)
[inference_program, feed_target_names,
fetch_targets] = paddle.static.load_inference_model(
dirname="./inference",
executor=exe,
model_filename="lstm_infer.pdmodel",
params_filename="lstm_infer.pdiparams")
results = exe.run(inference_program,
feed={feed_target_names[0]: x.numpy()},
fetch_list=fetch_targets)
np.testing.assert_equal(
y.numpy(), results[0]) # eval results equal predict results
paddle.disable_static()
def runTest(self): def runTest(self):
self.test_with_initial_state() self.test_with_initial_state()
self.test_with_zero_state() self.test_with_zero_state()
self.test_with_input_lengths() self.test_with_input_lengths()
self.test_predict()
def load_tests(loader, tests, pattern): def load_tests(loader, tests, pattern):
......
...@@ -30,10 +30,12 @@ class TestSimpleRNN(unittest.TestCase): ...@@ -30,10 +30,12 @@ class TestSimpleRNN(unittest.TestCase):
self.time_major = time_major self.time_major = time_major
self.direction = direction self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1 self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \ self.place = place
else paddle.CUDAPlace(0)
def setUp(self): def setUp(self):
# Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
rnn1 = SimpleRNN( rnn1 = SimpleRNN(
16, 32, 2, time_major=self.time_major, direction=self.direction) 16, 32, 2, time_major=self.time_major, direction=self.direction)
...@@ -48,7 +50,6 @@ class TestSimpleRNN(unittest.TestCase): ...@@ -48,7 +50,6 @@ class TestSimpleRNN(unittest.TestCase):
time_major=self.time_major, time_major=self.time_major,
direction=self.direction) direction=self.direction)
place = self.place
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope() scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope): with paddle.static.scope_guard(scope):
...@@ -172,10 +173,12 @@ class TestGRU(unittest.TestCase): ...@@ -172,10 +173,12 @@ class TestGRU(unittest.TestCase):
self.time_major = time_major self.time_major = time_major
self.direction = direction self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1 self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \ self.place = place
else paddle.CUDAPlace(0)
def setUp(self): def setUp(self):
# Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
rnn1 = GRU(16, rnn1 = GRU(16,
32, 32,
2, 2,
...@@ -192,7 +195,6 @@ class TestGRU(unittest.TestCase): ...@@ -192,7 +195,6 @@ class TestGRU(unittest.TestCase):
time_major=self.time_major, time_major=self.time_major,
direction=self.direction) direction=self.direction)
place = self.place
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope() scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope): with paddle.static.scope_guard(scope):
...@@ -316,10 +318,12 @@ class TestLSTM(unittest.TestCase): ...@@ -316,10 +318,12 @@ class TestLSTM(unittest.TestCase):
self.time_major = time_major self.time_major = time_major
self.direction = direction self.direction = direction
self.num_directions = 2 if direction == "bidirectional" else 1 self.num_directions = 2 if direction == "bidirectional" else 1
self.place = paddle.CPUPlace() if place == "cpu" \ self.place = place
else paddle.CUDAPlace(0)
def setUp(self): def setUp(self):
# Since `set_device` is global, set `set_device` in `setUp` rather than
# `__init__` to avoid using an error device set by another test case.
place = paddle.set_device(self.place)
rnn1 = LSTM( rnn1 = LSTM(
16, 32, 2, time_major=self.time_major, direction=self.direction) 16, 32, 2, time_major=self.time_major, direction=self.direction)
...@@ -334,7 +338,6 @@ class TestLSTM(unittest.TestCase): ...@@ -334,7 +338,6 @@ class TestLSTM(unittest.TestCase):
time_major=self.time_major, time_major=self.time_major,
direction=self.direction) direction=self.direction)
place = self.place
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
scope = paddle.fluid.Scope() scope = paddle.fluid.Scope()
with paddle.static.scope_guard(scope): with paddle.static.scope_guard(scope):
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册