未验证 提交 7039c06d 编写于 作者: P pangyoki 提交者: GitHub

[NPU] Support npu save load (#31893)

* support save load for NPU

* add save load npu unittest

* support np.array transform in NPU

* fix errors

* delete dygraph in unittest

* add Wait

* fix unittest

* fix review comment

* fix unittest problem

* fix little problem
上级 853af66f
......@@ -822,6 +822,29 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
#else
PADDLE_THROW(platform::errors::Unimplemented(
"XPUPlace is not supported when not compiled with XPU"));
#endif
} else if (platform::is_npu_place(tensor.place())) {
#ifdef PADDLE_WITH_ASCEND_CL
constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
std::unique_ptr<char[]> buf(new char[kBufSize]);
auto& npu_dev_ctx =
static_cast<const platform::NPUDeviceContext&>(dev_ctx);
platform::CPUPlace cpu;
uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
while (size != 0) {
size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
memory::Copy(cpu, buf.get(),
BOOST_GET_CONST(platform::NPUPlace, tensor.place()),
reinterpret_cast<const void*>(data), size_to_write,
npu_dev_ctx.stream());
npu_dev_ctx.Wait();
os.write(buf.get(), size_to_write);
data += size_to_write;
size -= size_to_write;
}
#else
PADDLE_THROW(platform::errors::Unimplemented(
"NPUPlace is not supported when not compiled with NPU"));
#endif
} else {
os.write(static_cast<const char*>(data_ptr),
......@@ -877,8 +900,10 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
auto ctx = platform::CPUDeviceContext();
size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
if (platform::is_gpu_place(dev_ctx.GetPlace()) ||
platform::is_xpu_place(dev_ctx.GetPlace())) {
#if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU
platform::is_xpu_place(dev_ctx.GetPlace()) ||
platform::is_npu_place(dev_ctx.GetPlace())) {
#if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU || \
defined PADDLE_WITH_ASCEND_CL
Tensor cpu_tensor;
cpu_tensor.Resize(framework::make_ddim(shape));
framework::VisitDataType(
......@@ -887,13 +912,19 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
is.read(static_cast<char*>(buf), size);
auto dst_place = dev_ctx.GetPlace();
framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
if (platform::is_npu_place(dev_ctx.GetPlace())) {
dev_ctx.Wait();
}
#else
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
PADDLE_THROW(platform::errors::Unimplemented(
"CUDAPlace is not supported when not compiled with CUDA"));
} else {
} else if (platform::is_xpu_place(dev_ctx.GetPlace())) {
PADDLE_THROW(platform::errors::Unimplemented(
"XPUPlace is not supported when not compiled with XPU"));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"NPUPlace is not supported when not compiled with NPU"));
}
#endif
} else {
......@@ -934,8 +965,10 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
auto ctx = platform::CPUDeviceContext();
size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
if (platform::is_gpu_place(dev_ctx.GetPlace()) ||
platform::is_xpu_place(dev_ctx.GetPlace())) {
#if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU
platform::is_xpu_place(dev_ctx.GetPlace()) ||
platform::is_npu_place(dev_ctx.GetPlace())) {
#if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU || \
defined PADDLE_WITH_ASCEND_CL
Tensor cpu_tensor;
cpu_tensor.Resize(framework::make_ddim(dims));
framework::VisitDataType(
......@@ -944,13 +977,19 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
is.read(static_cast<char*>(buf), size);
auto dst_place = dev_ctx.GetPlace();
framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
if (platform::is_npu_place(dev_ctx.GetPlace())) {
dev_ctx.Wait();
}
#else
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
PADDLE_THROW(platform::errors::Unimplemented(
"CUDAPlace is not supported when not compiled with CUDA"));
} else {
} else if (platform::is_xpu_place(dev_ctx.GetPlace())) {
PADDLE_THROW(platform::errors::Unimplemented(
"XPUPlace is not supported when not compiled with XPU"));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"NPUPlace is not supported when not compiled with NPU"));
}
#endif
} else {
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/load_combine_op.h"
// NPU kernel registration for the `load_combine` operator.
// REGISTER_OP_NPU_KERNEL (declared outside this file) is given the operator
// name followed by one LoadCombineOpKernel instantiation per element type;
// every instantiation uses paddle::platform::NPUDeviceContext as its device
// context. Element types listed: float, double, int, int8_t, int64_t.
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
    load_combine,
    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, float>,
    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, double>,
    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, int>,
    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, int8_t>,
    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, int64_t>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/load_op.h"
// NPU kernel registration for the `load` operator.
// REGISTER_OP_NPU_KERNEL (declared outside this file) is given the operator
// name followed by one LoadOpKernel instantiation per element type; every
// instantiation uses paddle::platform::NPUDeviceContext as its device
// context. Element types listed: float, double, int, int8_t, int64_t.
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
    load, ops::LoadOpKernel<paddle::platform::NPUDeviceContext, float>,
    ops::LoadOpKernel<paddle::platform::NPUDeviceContext, double>,
    ops::LoadOpKernel<paddle::platform::NPUDeviceContext, int>,
    ops::LoadOpKernel<paddle::platform::NPUDeviceContext, int8_t>,
    ops::LoadOpKernel<paddle::platform::NPUDeviceContext, int64_t>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/save_combine_op.h"
// NPU kernel registration for the `save_combine` operator.
// REGISTER_OP_NPU_KERNEL (declared outside this file) is given the operator
// name followed by one SaveCombineOpKernel instantiation per element type;
// every instantiation uses paddle::platform::NPUDeviceContext as its device
// context. Element types listed: float, double, int, int64_t.
// NOTE(review): int8_t is not listed here, although the NPU registration of
// `load_combine` does list it -- confirm that the difference is intentional.
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
    save_combine,
    ops::SaveCombineOpKernel<paddle::platform::NPUDeviceContext, float>,
    ops::SaveCombineOpKernel<paddle::platform::NPUDeviceContext, double>,
    ops::SaveCombineOpKernel<paddle::platform::NPUDeviceContext, int>,
    ops::SaveCombineOpKernel<paddle::platform::NPUDeviceContext, int64_t>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/save_op.h"
#include "paddle/fluid/platform/float16.h"
// NPU kernel registration for the `save` operator.
// REGISTER_OP_NPU_KERNEL (declared outside this file) is given the operator
// name followed by one SaveOpKernel instantiation per element type; every
// instantiation uses paddle::platform::NPUDeviceContext as its device
// context. Element types listed: float, double, int, uint8_t, int8_t,
// int64_t and paddle::platform::float16.
// NOTE(review): uint8_t and float16 appear here but not in the NPU
// registration of `load` -- confirm that the asymmetry is intentional.
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
    save, ops::SaveOpKernel<paddle::platform::NPUDeviceContext, float>,
    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, double>,
    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, int>,
    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, uint8_t>,
    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, int8_t>,
    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, int64_t>,
    ops::SaveOpKernel<paddle::platform::NPUDeviceContext,
                      paddle::platform::float16>);
......@@ -644,6 +644,7 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor,
}
bool is_gpu_tensor = platform::is_gpu_place(tensor.place());
bool is_xpu_tensor = platform::is_xpu_place(tensor.place());
bool is_npu_tensor = platform::is_npu_place(tensor.place());
const auto &tensor_dims = tensor.dims();
auto tensor_dtype = tensor.type();
size_t sizeof_dtype = framework::SizeOfType(tensor_dtype);
......@@ -662,7 +663,7 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor,
std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(tensor.type());
if (!is_gpu_tensor && !is_xpu_tensor) {
if (!is_gpu_tensor && !is_xpu_tensor && !is_npu_tensor) {
if (!need_deep_copy) {
auto base = py::cast(std::move(tensor));
return py::array(py::dtype(py_dtype_str.c_str()), py_dims, py_strides,
......@@ -729,6 +730,34 @@ inline py::array TensorToPyArray(const framework::Tensor &tensor,
PADDLE_THROW(platform::errors::PermissionDenied(
"Cannot use CUDAPlace in CPU only version, "
"Please recompile or reinstall Paddle with CUDA support."));
#endif
} else if (is_npu_tensor) {
#ifdef PADDLE_WITH_ASCEND_CL
py::array py_arr(py::dtype(py_dtype_str.c_str()), py_dims, py_strides);
PADDLE_ENFORCE_EQ(py_arr.writeable(), true,
platform::errors::InvalidArgument(
"PyArray is not writable, in which case memory leak "
"or double free would occur"));
PADDLE_ENFORCE_EQ(
py_arr.owndata(), true,
platform::errors::InvalidArgument(
"PyArray does not own data, in which case memory leak "
"or double free would occur"));
size_t copy_bytes = sizeof_dtype * numel;
auto p = BOOST_GET_CONST(platform::NPUPlace, tensor.place());
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &ctx = *pool.Get(tensor.place());
paddle::memory::Copy(
platform::CPUPlace(), py_arr.mutable_data(), p, tensor_buf_ptr,
copy_bytes,
reinterpret_cast<const platform::NPUDeviceContext &>(ctx).stream());
ctx.Wait();
return py_arr;
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Cannot use NPUPlace in CPU/GPU/XPU version, "
"Please recompile or reinstall Paddle with NPU support."));
#endif
}
PADDLE_THROW(platform::errors::Unimplemented("Place is not supported"));
......
......@@ -1973,6 +1973,10 @@ def load(program, model_path, executor=None, var_list=None):
p = paddle.fluid.core.Place()
p.set_place(t._place())
place = paddle.fluid.XPUPlace(p.xpu_device_id())
elif p.is_npu_place():
p = paddle.fluid.core.Place()
p.set_place(t._place())
place = paddle.fluid.NPUPlace(p.npu_device_id())
else:
p = paddle.fluid.core.Place()
p.set_place(t._place())
......@@ -2115,8 +2119,8 @@ def load_program_state(model_path, var_list=None):
error_str = "Failed to load model/variables `%s`, please make sure " \
"model/variables file is saved with the following APIs: " \
"save_params, save_persistables, save_vars."
filenames = [var.name for var in vars
] if filename is None else filename
filenames = [var.name for var in
vars] if filename is None else filename
if raise_error:
raise RuntimeError(error_str % filenames)
else:
......@@ -2256,6 +2260,10 @@ def set_program_state(program, state_dict):
p = paddle.fluid.core.Place()
p.set_place(ten_place)
py_place = paddle.fluid.XPUPlace(p.xpu_device_id())
elif ten_place.is_npu_place():
p = paddle.fluid.core.Place()
p.set_place(ten_place)
py_place = paddle.fluid.NPUPlace(p.npu_device_id())
ten.set(new_para_np, py_place)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append("..")
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.nn import Embedding
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import Adam
from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope
from paddle.fluid.executor import global_scope
import numpy as np
import six
import pickle
import os
import errno
from test_static_save_load import *
paddle.enable_static()
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestNPUSaveLoadBase(TestSaveLoadBase):
    """NPU variant of TestSaveLoadBase: only the place selection differs."""

    def set_place(self):
        """Return NPUPlace(0) on an NPU build, otherwise CPUPlace()."""
        if core.is_compiled_with_npu():
            return paddle.NPUPlace(0)
        return fluid.CPUPlace()
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestNPUSaveLoadPartial(TestSaveLoadPartial):
    """NPU variant of TestSaveLoadPartial: only the place selection differs."""

    def set_place(self):
        """Return NPUPlace(0) on an NPU build, otherwise CPUPlace()."""
        if core.is_compiled_with_npu():
            return paddle.NPUPlace(0)
        return fluid.CPUPlace()
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestNPUSaveLoadSetStateDict(TestSaveLoadSetStateDict):
    """NPU variant of TestSaveLoadSetStateDict: only the place differs."""

    def set_place(self):
        """Return NPUPlace(0) on an NPU build, otherwise CPUPlace()."""
        if core.is_compiled_with_npu():
            return paddle.NPUPlace(0)
        return fluid.CPUPlace()
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestNPUProgramStatePartial(TestProgramStatePartial):
    """NPU variant of TestProgramStatePartial: only the place differs."""

    def set_place(self):
        """Return NPUPlace(0) on an NPU build, otherwise CPUPlace()."""
        if core.is_compiled_with_npu():
            return paddle.NPUPlace(0)
        return fluid.CPUPlace()
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestNPULoadFromOldInterface(TestLoadFromOldInterface):
    """NPU variant of TestLoadFromOldInterface: only the place differs."""

    def set_place(self):
        """Return NPUPlace(0) on an NPU build, otherwise CPUPlace()."""
        if core.is_compiled_with_npu():
            return paddle.NPUPlace(0)
        return fluid.CPUPlace()
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestNPULoadFromOldInterfaceSingleFile(TestLoadFromOldInterfaceSingleFile):
    """NPU variant of TestLoadFromOldInterfaceSingleFile: only the place
    selection differs."""

    def set_place(self):
        """Return NPUPlace(0) on an NPU build, otherwise CPUPlace()."""
        if core.is_compiled_with_npu():
            return paddle.NPUPlace(0)
        return fluid.CPUPlace()
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestNPUProgramStateOldSave(TestProgramStateOldSave):
    """NPU variant of TestProgramStateOldSave with `test_dygraph` switched
    off and the place selection overridden."""

    def setUp(self):
        """Set the `test_dygraph` flag to False for every test in this class."""
        self.test_dygraph = False

    def set_place(self):
        """Return NPUPlace(0) on an NPU build, otherwise CPUPlace()."""
        if core.is_compiled_with_npu():
            return paddle.NPUPlace(0)
        return fluid.CPUPlace()
@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestNPUProgramStateOldSaveSingleModel(TestProgramStateOldSaveSingleModel):
    """NPU variant of TestProgramStateOldSaveSingleModel: only the place
    selection differs."""

    def set_place(self):
        """Return NPUPlace(0) on an NPU build, otherwise CPUPlace()."""
        if core.is_compiled_with_npu():
            return paddle.NPUPlace(0)
        return fluid.CPUPlace()
# Script entry point: switch paddle to static mode, then hand control to the
# unittest runner.
if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()
......@@ -18,7 +18,7 @@ import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.dygraph.nn import Embedding
from paddle.nn import Embedding
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import Adam
from paddle.fluid.dygraph.base import to_variable
......@@ -30,6 +30,8 @@ import pickle
import os
import errno
paddle.enable_static()
class SimpleLSTMRNN(fluid.Layer):
def __init__(self,
......@@ -158,11 +160,10 @@ class PtbModel(fluid.Layer):
num_layers=num_layers,
init_scale=init_scale,
dropout=dropout)
self.embedding = Embedding(
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(
self.embedding = paddle.nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=hidden_size,
weight_attr=fluid.ParamAttr(
name='embedding_para',
initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale)))
......@@ -186,6 +187,8 @@ class PtbModel(fluid.Layer):
init_c = fluid.layers.reshape(
init_cell, shape=[self.num_layers, -1, self.hidden_size])
# NPU 'tok_k' kernel only supports `int32` dtype, so cast `input` from `int64` to `int32`.
input = fluid.layers.cast(input, "int32")
x_emb = self.embedding(input)
x_emb = fluid.layers.reshape(
x_emb, shape=[-1, self.num_steps, self.hidden_size])
......@@ -213,6 +216,10 @@ class PtbModel(fluid.Layer):
class TestSaveLoadBase(unittest.TestCase):
def set_place(self):
return fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
def test_ptb_rnn_cpu_float32(self):
seed = 90
hidden_size = 10
......@@ -234,8 +241,7 @@ class TestSaveLoadBase(unittest.TestCase):
num_steps=num_steps,
init_scale=init_scale)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
place = self.set_place()
exe = fluid.Executor(place)
sgd = Adam(learning_rate=1e-3)
x = fluid.layers.data(
......@@ -314,6 +320,10 @@ class TestSaveLoadBase(unittest.TestCase):
class TestSaveLoadPartial(unittest.TestCase):
def set_place(self):
return fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
def test_ptb_rnn_cpu_float32(self):
seed = 90
hidden_size = 10
......@@ -335,8 +345,7 @@ class TestSaveLoadPartial(unittest.TestCase):
num_steps=num_steps,
init_scale=init_scale)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
place = self.set_place()
exe = fluid.Executor(place)
sgd = Adam(learning_rate=1e-3)
x = fluid.layers.data(
......@@ -424,6 +433,10 @@ class TestSaveLoadPartial(unittest.TestCase):
class TestSaveLoadSetStateDict(unittest.TestCase):
def set_place(self):
return fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
def test_ptb_rnn_cpu_float32(self):
seed = 90
hidden_size = 10
......@@ -445,8 +458,7 @@ class TestSaveLoadSetStateDict(unittest.TestCase):
num_steps=num_steps,
init_scale=init_scale)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
place = self.set_place()
exe = fluid.Executor(place)
sgd = Adam(learning_rate=1e-3)
x = fluid.layers.data(
......@@ -525,6 +537,10 @@ class TestSaveLoadSetStateDict(unittest.TestCase):
class TestProgramStatePartial(unittest.TestCase):
def set_place(self):
return fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
def test_ptb_rnn_cpu_float32(self):
seed = 90
hidden_size = 10
......@@ -546,8 +562,7 @@ class TestProgramStatePartial(unittest.TestCase):
num_steps=num_steps,
init_scale=init_scale)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
place = self.set_place()
exe = fluid.Executor(place)
sgd = Adam(learning_rate=1e-3)
x = fluid.layers.data(
......@@ -707,14 +722,17 @@ class TestProgramStatePartial(unittest.TestCase):
class TestVariableInit(unittest.TestCase):
def set_place(self):
return fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
def test_variable_init(self):
x = fluid.data(name="x", shape=[10, 10], dtype='float32')
y = fluid.layers.fc(x, 10)
z = fluid.layers.fc(y, 10)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
place = self.set_place()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
......@@ -737,8 +755,7 @@ class TestVariableInit(unittest.TestCase):
program = fluid.default_main_program()
new_scope = fluid.core.Scope()
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
place = self.set_place()
exe = fluid.Executor(place)
parameter_list = list(
filter(fluid.io.is_parameter, program.list_vars()))
......@@ -797,6 +814,10 @@ class TestLoadFromOldInterface(unittest.TestCase):
if os.path.exists("test_static_load_var_list.pdparams"):
os.remove("test_static_load_var_list.pdparams")
def set_place(self):
return fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
def test_load_from_old_interface(self):
seed = 90
hidden_size = 10
......@@ -818,8 +839,7 @@ class TestLoadFromOldInterface(unittest.TestCase):
num_steps=num_steps,
init_scale=init_scale)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
place = self.set_place()
exe = fluid.Executor(place)
sgd = Adam(learning_rate=1e-3)
x = fluid.layers.data(
......@@ -934,8 +954,7 @@ class TestLoadFromOldInterface(unittest.TestCase):
num_steps=num_steps,
init_scale=init_scale)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
place = self.set_place()
exe = fluid.Executor(place)
sgd = Adam(learning_rate=1e-3)
x = fluid.layers.data(
......@@ -1026,6 +1045,10 @@ class TestLoadFromOldInterface(unittest.TestCase):
class TestLoadFromOldInterfaceSingleFile(unittest.TestCase):
def set_place(self):
return fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
def test_load_from_old_interface(self):
seed = 90
hidden_size = 10
......@@ -1047,8 +1070,7 @@ class TestLoadFromOldInterfaceSingleFile(unittest.TestCase):
num_steps=num_steps,
init_scale=init_scale)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
place = self.set_place()
exe = fluid.Executor(place)
sgd = Adam(learning_rate=1e-3)
x = fluid.layers.data(
......@@ -1170,6 +1192,13 @@ class TestLoadFromOldInterfaceSingleFile(unittest.TestCase):
class TestProgramStateOldSave(unittest.TestCase):
def setUp(self):
self.test_dygraph = True
def set_place(self):
return fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
def test_ptb_rnn_cpu_float32(self):
seed = 90
hidden_size = 10
......@@ -1191,8 +1220,7 @@ class TestProgramStateOldSave(unittest.TestCase):
num_steps=num_steps,
init_scale=init_scale)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
place = self.set_place()
exe = fluid.Executor(place)
sgd = Adam(learning_rate=1e-3)
x = fluid.layers.data(
......@@ -1298,11 +1326,12 @@ class TestProgramStateOldSave(unittest.TestCase):
fluid.set_program_state(main_program, program_state)
self.check_in_static(main_program, base_map)
# make sure `load_program_state` can be used in dynamic graph mode
with fluid.dygraph.guard(place):
load_state = fluid.load_program_state("test_program_1")
for k, v in load_state.items():
self.assertTrue(np.array_equal(base_map[k], v))
if self.test_dygraph:
# make sure `load_program_state` can be used in dynamic graph mode
with fluid.dygraph.guard(place):
load_state = fluid.load_program_state("test_program_1")
for k, v in load_state.items():
self.assertTrue(np.array_equal(base_map[k], v))
def check_in_static(self, main_program, base_map):
for var in main_program.list_vars():
......@@ -1313,40 +1342,11 @@ class TestProgramStateOldSave(unittest.TestCase):
self.assertTrue(np.array_equal(new_t, base_t))
# Save/load round-trip check for a static-graph program built with
# paddle.static.nn.fc(x, LARGE_PARAM), where LARGE_PARAM = 2**26.
# The program is run once, saved with paddle.fluid.save, reloaded with
# paddle.fluid.load, run again on the same input, and the two fetched results
# are required to differ by less than 1e-15 in summed absolute value.
# NOTE(review): the nesting of these lines was flattened in this listing; the
# extent of the `with new_program_scope()` body should be confirmed against
# the original file.
class TestStaticSaveLoadLargeParameters(unittest.TestCase):
def test_large_parameters_static_save(self):
# enable static mode
paddle.enable_static()
LARGE_PARAM = 2**26
with new_program_scope():
# create network
x = paddle.static.data(
name="static_save_load_large_x",
shape=[None, 10],
dtype='float32')
z = paddle.static.nn.fc(x, LARGE_PARAM)
# this test always runs on CPUPlace
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
prog = paddle.static.default_main_program()
# one random float32 sample of shape (1, 10), reused for both runs
inputs = np.random.randn(1, 10).astype("float32")
# reference output before saving
result_z = exe.run(program=prog,
feed={"static_save_load_large_x": inputs},
fetch_list=[z.name])
path = "test_static_save_load_large_param/static_save"
paddle.fluid.save(prog, path)
paddle.fluid.load(prog, path)
# output after the save/load round trip
result_load = exe.run(program=prog,
feed={"static_save_load_large_x": inputs},
fetch_list=[z.name])
# compare results before and after saving
self.assertTrue(
np.sum(np.abs(result_z[0] - result_load[0])) < 1e-15)
class TestProgramStateOldSaveSingleModel(unittest.TestCase):
def set_place(self):
return fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
def test_ptb_rnn_cpu_float32(self):
seed = 90
hidden_size = 10
......@@ -1368,8 +1368,7 @@ class TestProgramStateOldSaveSingleModel(unittest.TestCase):
num_steps=num_steps,
init_scale=init_scale)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
place = self.set_place()
exe = fluid.Executor(place)
sgd = Adam(learning_rate=1e-3)
x = fluid.layers.data(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册