Unverified · Commit ad2813b1 authored by cc, committed by GitHub

[cherry-pick] Add weight quantization in post_training_quantization (#22445) (#22493)

* Add weight quantization in post_training_quantization (#22445)

* [cherry-pick] Support int16 for Tensor (#22423)

* add int16 support, test=develop, test=release/1.7
Co-authored-by: Leo Chen <chenqiuliang@baidu.com>
Parent: ff4f0757
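For orientation, the weight quantization added in this commit maps float32 weights of the supported ops onto int8/int16 with an abs_max scale and restores them to float32 at inference time. Below is a minimal NumPy sketch of that round trip; it is illustrative only, and the function names are not part of the Paddle API:

    import numpy as np

    def quantize_weight_abs_max(weight, bits=8):
        # abs_max scaling: map [-max|w|, +max|w|] onto the signed integer range
        quantize_range = (1 << (bits - 1)) - 1       # 127 for int8, 32767 for int16
        threshold = np.max(np.abs(weight))
        scale = threshold / quantize_range
        dtype = np.int8 if bits == 8 else np.int16
        return np.around(weight / scale).astype(dtype), scale

    def dequantize_weight(quantized, scale):
        # inference-time reconstruction of the float32 weight
        return quantized.astype(np.float32) * scale

    w = np.random.randn(16, 8).astype(np.float32)
    q, s = quantize_weight_abs_max(w, bits=8)
    w_restored = dequantize_weight(q, s)             # differs from w by at most ~scale/2 per element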
@@ -90,4 +90,5 @@ REGISTER_OP_CPU_KERNEL(
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
+    ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int16_t>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
@@ -106,9 +106,10 @@ DECLARE_VALID_DTYPE_TO_PY_ARRAY(float);
 DECLARE_VALID_DTYPE_TO_PY_ARRAY(double);
 DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool);
 DECLARE_VALID_DTYPE_TO_PY_ARRAY(int8_t);
-DECLARE_VALID_DTYPE_TO_PY_ARRAY(uint8_t);
+DECLARE_VALID_DTYPE_TO_PY_ARRAY(int16_t);
 DECLARE_VALID_DTYPE_TO_PY_ARRAY(int);
 DECLARE_VALID_DTYPE_TO_PY_ARRAY(int64_t);
+DECLARE_VALID_DTYPE_TO_PY_ARRAY(uint8_t);

 inline std::string TensorDTypeToPyDTypeStr(
     framework::proto::VarType::Type type) {
@@ -218,13 +219,16 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
     SetTensorFromPyArrayT<double, P>(self, array, place, zero_copy);
   } else if (py::isinstance<py::array_t<int8_t>>(array)) {
     SetTensorFromPyArrayT<int8_t, P>(self, array, place, zero_copy);
+  } else if (py::isinstance<py::array_t<int16_t>>(array)) {
+    SetTensorFromPyArrayT<int16_t, P>(self, array, place, zero_copy);
   } else if (py::isinstance<py::array_t<uint8_t>>(array)) {
     SetTensorFromPyArrayT<uint8_t, P>(self, array, place, zero_copy);
   } else if (py::isinstance<py::array_t<paddle::platform::float16>>(array)) {
     SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place,
                                                         zero_copy);
   } else if (py::isinstance<py::array_t<uint16_t>>(array)) {
-    // TODO(cql): temporary keeping uint16, should be depracated later
+    // TODO(cql): temporarily keep uint16, which was used for casting float16
+    // before. It should be deprecated later.
     SetTensorFromPyArrayT<paddle::platform::float16, P>(self, array, place,
                                                         zero_copy);
   } else if (py::isinstance<py::array_t<bool>>(array)) {
@@ -234,7 +238,7 @@ void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
         "Incompatible data or style type: tensor.set() supports bool, float16, "
         "float32, "
         "float64, "
-        "int8, int32, int64 and uint8, uint16, but got %s!",
+        "int8, int16, int32, int64 and uint8, uint16, but got %s!",
         array.dtype());
   }
 }
@@ -435,16 +439,18 @@ inline framework::Tensor *_sliceTensor(const framework::Tensor &self,
       return _sliceAndConcat<float>(self, obj, dim);
     case framework::proto::VarType::FP64:
       return _sliceAndConcat<double>(self, obj, dim);
+    case framework::proto::VarType::INT8:
+      return _sliceAndConcat<int8_t>(self, obj, dim);
+    case framework::proto::VarType::INT16:
+      return _sliceAndConcat<int16_t>(self, obj, dim);
     case framework::proto::VarType::INT32:
       return _sliceAndConcat<int>(self, obj, dim);
     case framework::proto::VarType::INT64:
       return _sliceAndConcat<int64_t>(self, obj, dim);
     case framework::proto::VarType::BOOL:
       return _sliceAndConcat<bool>(self, obj, dim);
-    case framework::proto::VarType::INT16:
-      return _sliceAndConcat<bool>(self, obj, dim);
     case framework::proto::VarType::UINT8:
-      return _sliceAndConcat<bool>(self, obj, dim);
+      return _sliceAndConcat<uint8_t>(self, obj, dim);
     default:
       PADDLE_THROW("Not support type %d", src_type);
   }
@@ -16,10 +16,10 @@ import os
 import re
 import logging
 import numpy as np
-from ....executor import global_scope
 from .... import io
 from .... import core
 from .... import framework
+from ....executor import global_scope, Executor
 from ....framework import IrGraph
 from ....log_helper import get_logger
 from .quantization_pass import QuantizationTransformPass
@@ -27,12 +27,31 @@ from .quantization_pass import QuantizationFreezePass
 from .quantization_pass import AddQuantDequantPass
 from .quantization_pass import _op_real_in_out_name

-__all__ = ['PostTrainingQuantization']
+__all__ = ['PostTrainingQuantization', 'WeightQuantization']

 _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')


+def _load_variable_data(scope, var_name):
+    '''
+    Load variable value from scope.
+    '''
+    return np.array(scope.find_var(var_name).get_tensor())
+
+
+def _set_variable_data(scope, place, var_name, np_value):
+    '''
+    Set the value of the variable in the scope by name, if the variable exists.
+    '''
+    assert isinstance(np_value, np.ndarray), \
+        'The type of value should be numpy array.'
+    var_node = scope.find_var(var_name)
+    if var_node is not None:
+        tensor = var_node.get_tensor()
+        tensor.set(np_value, place)
+
+
 class PostTrainingQuantization(object):
     def __init__(self,
                  executor,
@@ -297,12 +316,12 @@ class PostTrainingQuantization(object):
         '''
         for var_name in self._quantized_weight_var_name:
             if var_name not in self._sampling_data:
-                var_tensor = self._load_var_value(var_name)
+                var_tensor = _load_variable_data(self._scope, var_name)
                 self._sampling_data[var_name] = var_tensor

         if self._is_use_cache_file:
             for var_name in self._quantized_act_var_name:
-                var_tensor = self._load_var_value(var_name)
+                var_tensor = _load_variable_data(self._scope, var_name)
                 var_tensor = var_tensor.ravel()
                 save_path = os.path.join(self._cache_dir,
                                          var_name + "_" + str(iter) + ".npy")
@@ -311,7 +330,7 @@ class PostTrainingQuantization(object):
             for var_name in self._quantized_act_var_name:
                 if var_name not in self._sampling_data:
                     self._sampling_data[var_name] = []
-                var_tensor = self._load_var_value(var_name)
+                var_tensor = _load_variable_data(self._scope, var_name)
                 var_tensor = var_tensor.ravel()
                 self._sampling_data[var_name].append(var_tensor)
@@ -397,11 +416,17 @@ class PostTrainingQuantization(object):
         # save scale factor to scale var node
         for key, val in self._quantized_var_scale_factor.items():
-            self._set_var_node_value(
-                key + ".scale", np.array(
+            _set_variable_data(
+                self._scope,
+                self._place,
+                key + ".scale",
+                np.array(
                     [val], dtype=np.float32))
-            self._set_var_node_value(
-                key + ".quant_dequant.scale", np.array(
+            _set_variable_data(
+                self._scope,
+                self._place,
+                key + ".quant_dequant.scale",
+                np.array(
                     [val], dtype=np.float32))

         # apply QuantizationFreezePass, and obtain the final quant model
@@ -430,23 +455,6 @@ class PostTrainingQuantization(object):
                             self._quantized_var_scale_factor[
                                 output_var_name])

-    def _load_var_value(self, var_name):
-        '''
-        Load variable value from scope
-        '''
-        return np.array(self._scope.find_var(var_name).get_tensor())
-
-    def _set_var_node_value(self, var_node_name, np_value):
-        '''
-        Set the value of var node by name, if the node exits,
-        '''
-        assert isinstance(np_value, np.ndarray), \
-            'The type of value should be numpy array.'
-        var_node = self._scope.find_var(var_node_name)
-        if var_node != None:
-            tensor = var_node.get_tensor()
-            tensor.set(np_value, self._place)
-
     def _is_input_all_not_persistable(self, op, persistable_var_names):
         '''
         Analyze the real inputs of the op are all not persistable.
@@ -566,3 +574,132 @@ class PostTrainingQuantization(object):
                 tmp_sum1 += p_idx * (math.log(Q_sum * p_idx))
                 tmp_sum2 += p_idx * (math.log(P_sum * q_idx))
         return (tmp_sum1 - tmp_sum2) / P_sum
+
+
+class WeightQuantization(object):
+    _supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
+
+    def __init__(self, model_dir, model_filename=None, params_filename=None):
+        '''
+        This class quantizes the weights of some ops to reduce the size of
+        the model or to improve inference performance.
+
+        Args:
+            model_dir(str): The path of the fp32 model to be quantized;
+                the model and params files are under this path.
+            model_filename(str, optional): The name of the file that stores
+                the inference program. If it is None, the default filename
+                '__model__' is used. Default is None.
+            params_filename(str, optional): The name of the file that stores
+                all parameters. If all parameters were saved in a single
+                binary file, set it to that filename. If parameters were
+                saved in separate files, set it to None. Default is None.
+        '''
+        self._model_dir = model_dir
+        self._model_filename = model_filename
+        self._params_filename = params_filename
+
+    def quantize_weight_to_int(self,
+                               save_model_dir,
+                               save_model_filename=None,
+                               save_params_filename=None,
+                               quantizable_op_type=["conv2d", "mul"],
+                               quantize_weight_bits=8,
+                               threshold_rate=0.0):
+        '''
+        In order to reduce the size of the model, this api quantizes the
+        weights of some ops from float32 to int8/16. In the inference stage,
+        the quantized weights are dequantized back to float32.
+
+        Args:
+            save_model_dir(str): The path to save the quantized model.
+            save_model_filename(str, optional): The name of the file to
+                save the inference program. If it is None, the default
+                filename '__model__' is used. Default is None.
+            save_params_filename(str, optional): The name of the file to
+                save all parameters. If it is None, parameters are saved
+                in separate files. If it is not None, all parameters are
+                saved in a single binary file.
+            quantizable_op_type(list[str], optional): The list of ops to be
+                quantized; each op type should be contained in
+                ["conv2d", "depthwise_conv2d", "mul"].
+                Default is ["conv2d", "mul"].
+            quantize_weight_bits(int, optional): The number of bits for the
+                quantized weights; it should be 8 or 16. Default is 8.
+            threshold_rate(float, optional): This api uses the abs_max method
+                to quantize the weights from float32 to int8/16, so the abs
+                max value matters for the quantization error. When the
+                abs_max value is far away from the center of the numerical
+                distribution, set threshold_rate between 1e-6 and 1e-8 to
+                clip the abs max value before quantization. Default is 0.0.
+        '''
+        for op_type in quantizable_op_type:
+            assert op_type in self._supported_quantizable_op_type, \
+                "input error: " + op_type + \
+                " is not supported for weight quantization."
+        assert quantize_weight_bits in [8, 16], \
+            "input error: quantize_weight_bits should be 8 or 16."
+        quantize_range = (1 << (quantize_weight_bits - 1)) - 1
+        save_weight_dtype = np.int8 if quantize_weight_bits == 8 else np.int16
+
+        place = core.CPUPlace()
+        exe = Executor(place)
+        scope = global_scope()
+        [program, feed_list, fetch_list] = \
+            io.load_inference_model(dirname=self._model_dir,
+                                    executor=exe,
+                                    model_filename=self._model_filename,
+                                    params_filename=self._params_filename)
+
+        persistable_var_names = []
+        for var in program.list_vars():
+            if var.persistable:
+                persistable_var_names.append(var.name)
+        for op in program.global_block().ops:
+            if op.type in quantizable_op_type:
+                for var_name in op.input_arg_names:
+                    if var_name in persistable_var_names:
+                        var_tensor_data = _load_variable_data(scope, var_name)
+                        if abs(threshold_rate) < 1e-10:
+                            threshold_value = np.max(np.abs(var_tensor_data))
+                        else:
+                            threshold_value = self._calculate_threshold(
+                                var_tensor_data, threshold_rate)
+                            var_tensor_data[var_tensor_data >
+                                            threshold_value] = threshold_value
+                            var_tensor_data[var_tensor_data <
+                                            -threshold_value] = -threshold_value
+                        scale = threshold_value / quantize_range
+                        quantized_var_tensor_data = \
+                            np.around(var_tensor_data / scale)
+                        quantized_var_tensor_data = \
+                            quantized_var_tensor_data.astype(save_weight_dtype)
+                        _set_variable_data(scope, place, var_name,
+                                           quantized_var_tensor_data)
+                        op._set_attr(var_name + "_quant_scale", [scale])
+                        op._set_attr('quantize_weight_bits',
+                                     quantize_weight_bits)
+
+        io.save_inference_model(
+            dirname=save_model_dir,
+            feeded_var_names=feed_list,
+            target_vars=fetch_list,
+            executor=exe,
+            main_program=program,
+            model_filename=save_model_filename,
+            params_filename=save_params_filename)
+
+    def _calculate_threshold(self, input, threshold_rate, histogram_bins=5000):
+        input_abs = np.abs(input)
+        hist, hist_edges = np.histogram(
+            input_abs, bins=histogram_bins, range=(0, np.max(input_abs)))
+        hist = hist / float(sum(hist))
+        hist_sum = 0
+        hist_index = 0
+        for i in range(len(hist)):
+            hist_sum += hist[i]
+            if hist_sum >= 1.0 - threshold_rate:
+                hist_index = i + 1
+                break
+        bin_width = hist_edges[1] - hist_edges[0]
+        return hist_index * bin_width
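Going by the docstring above and the unit test added below, a minimal usage sketch of the new API looks roughly like this; the model paths are placeholders, not directories from this commit:

    from paddle.fluid.contrib.slim.quantization import WeightQuantization

    # '/path/to/fp32_model' stands in for a directory holding a saved fp32 inference model
    weight_quant = WeightQuantization(model_dir='/path/to/fp32_model')
    weight_quant.quantize_weight_to_int(
        save_model_dir='/path/to/quantized_model',
        quantize_weight_bits=8,
        quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
        threshold_rate=0.0)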
@@ -58,6 +58,7 @@ if(WIN32)
     list(REMOVE_ITEM TEST_OPS test_light_nas)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
+    list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
 endif()

 # int8 image classification python api test
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import time
from paddle.dataset.common import download, DATA_HOME
from paddle.fluid.contrib.slim.quantization import WeightQuantization
class TestWeightQuantization(unittest.TestCase):
    def setUp(self):
        self.weight_quantization_dir = 'weight_quantization'
        self.cache_folder = os.path.join(DATA_HOME,
                                         self.weight_quantization_dir)

    def download_model(self, model_name, data_url, data_md5):
        download(data_url, self.weight_quantization_dir, data_md5)
        file_name = data_url.split('/')[-1]
        file_path = os.path.join(self.cache_folder, file_name)
        print(model_name + ' is downloaded at ' + file_path)

        unzipped_path = os.path.join(self.cache_folder, model_name)
        self.cache_unzipping(unzipped_path, file_path)
        print(model_name + ' is unzipped at ' + unzipped_path)
        return unzipped_path

    def cache_unzipping(self, target_folder, zip_path):
        if not os.path.exists(target_folder):
            cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder,
                                                          zip_path)
            os.system(cmd)

    def run_test(self, model_name, model_data_url, model_data_md5,
                 quantize_weight_bits, quantizable_op_type, threshold_rate):
        model_dir = self.download_model(model_name, model_data_url,
                                        model_data_md5)

        timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
        save_model_dir = os.path.join(
            os.getcwd(),
            model_name + "_wq_" + str(quantize_weight_bits) + "_" + timestamp)

        weight_quant = WeightQuantization(model_dir=model_dir + "/model")
        weight_quant.quantize_weight_to_int(
            save_model_dir=save_model_dir,
            quantize_weight_bits=quantize_weight_bits,
            quantizable_op_type=quantizable_op_type,
            threshold_rate=threshold_rate)
        print("finish weight quantization for " + model_name + "\n")

        try:
            os.system("rm -rf {}".format(save_model_dir))
        except Exception as e:
            print("Failed to delete {} due to {}".format(save_model_dir,
                                                         str(e)))


class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
    model_name = "mobilenetv1"
    model_data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz"
    model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"

    def test_weight_quantization_mobilenetv1_8bit(self):
        quantize_weight_bits = 8
        quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
        threshold_rate = 0.0
        self.run_test(self.model_name, self.model_data_url,
                      self.model_data_md5, quantize_weight_bits,
                      quantizable_op_type, threshold_rate)

    def test_weight_quantization_mobilenetv1_16bit(self):
        quantize_weight_bits = 16
        quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
        threshold_rate = 1e-9
        self.run_test(self.model_name, self.model_data_url,
                      self.model_data_md5, quantize_weight_bits,
                      quantizable_op_type, threshold_rate)


if __name__ == '__main__':
    unittest.main()
@@ -22,6 +22,12 @@ import numbers
 class TestTensor(unittest.TestCase):
+    def setUp(self):
+        self.support_dtypes = [
+            'bool', 'uint8', 'int8', 'int16', 'int32', 'int64', 'float16',
+            'float32', 'float64'
+        ]
+
     def test_int_tensor(self):
         scope = core.Scope()
         var = scope.var("test_tensor")
@@ -184,15 +190,15 @@ class TestTensor(unittest.TestCase):
         tensor_array = numpy.array(tensor)
         self.assertEqual((0, 1), tensor_array.shape)

-    def run_sliece_tensor(self, place):
+    def run_slice_tensor(self, place, dtype):

         tensor = fluid.Tensor()
         shape = [3, 3, 3]
         tensor._set_dims(shape)

-        tensor_array = numpy.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
-                                    [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
-                                    [[19, 20, 21], [22, 23, 24], [25, 26, 27]]])
+        tensor_array = numpy.array(
+            [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+             [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
+             [[19, 20, 21], [22, 23, 24], [25, 26, 27]]]).astype(dtype)
         tensor.set(tensor_array, place)

         n1 = tensor[1]
@@ -227,14 +233,15 @@ class TestTensor(unittest.TestCase):
         t8 = tensor_array[0::1, 0::-1, 2:]
         self.assertTrue((numpy.array(n8) == numpy.array(t8)).all())

-    def test_sliece_tensor(self):
-        # run cpu first
-        place = core.CPUPlace()
-        self.run_sliece_tensor(place)
+    def test_slice_tensor(self):
+        for dtype in self.support_dtypes:
+            # run cpu first
+            place = core.CPUPlace()
+            self.run_slice_tensor(place, dtype)

-        if core.is_compiled_with_cuda():
-            place = core.CUDAPlace(0)
-            self.run_sliece_tensor(place)
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                self.run_slice_tensor(place, dtype)

     def test_print_tensor(self):
         scope = core.Scope()
@@ -299,6 +306,25 @@ class TestTensor(unittest.TestCase):
         self.assertEqual(tensor._dtype(), core.VarDesc.VarType.FP16)
         self.assertTrue(numpy.array_equal(numpy.array(tensor), array))

+    def test_tensor_set_int16(self):
+        array = numpy.random.randint(100, size=(300, 500)).astype("int16")
+        tensor = fluid.Tensor()
+        place = core.CPUPlace()
+        tensor.set(array, place)
+        self.assertEqual(tensor._dtype(), core.VarDesc.VarType.INT16)
+        self.assertTrue(numpy.array_equal(numpy.array(tensor), array))
+
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            tensor.set(array, place)
+            self.assertEqual(tensor._dtype(), core.VarDesc.VarType.INT16)
+            self.assertTrue(numpy.array_equal(numpy.array(tensor), array))
+
+            place = core.CUDAPinnedPlace()
+            tensor.set(array, place)
+            self.assertEqual(tensor._dtype(), core.VarDesc.VarType.INT16)
+            self.assertTrue(numpy.array_equal(numpy.array(tensor), array))
+
     def test_tensor_set_from_array_list(self):
         array = numpy.random.randint(1000, size=(200, 300))
         list_array = [array, array]