diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 2f2d69e50cf0984038dd08f8e288e953eba4d5a4..23be7d954bc3de0d34abc61ec3f84bbc8ba2c204 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -457,55 +457,12 @@ PYBIND11_MODULE(core_noavx, m) { return reinterpret_cast(self.mutable_data(place, type)); }) .def("_clear", &Tensor::clear) - .def("set", PyCPUTensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCPUTensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCPUTensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCPUTensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCPUTensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCPUTensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCPUTensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCPUTensorSetFromArray, py::arg("array"), - py::arg("place")) -#ifdef PADDLE_WITH_CUDA - .def("set", PyCUDATensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDATensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDATensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDATensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDATensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDATensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDATensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDATensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDAPinnedTensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDAPinnedTensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDAPinnedTensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDAPinnedTensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDAPinnedTensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDAPinnedTensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDAPinnedTensorSetFromArray, py::arg("array"), - py::arg("place")) - .def("set", PyCUDAPinnedTensorSetFromArray, py::arg("array"), - py::arg("place"), R"DOC( + .def("set", SetTensorFromPyArray, + py::arg("array"), py::arg("place")) + .def("set", SetTensorFromPyArray, + py::arg("array"), py::arg("place")) + .def("set", SetTensorFromPyArray, + py::arg("array"), py::arg("place"), R"DOC( Set the data of LoDTensor on place with given numpy array. Args: @@ -525,7 +482,7 @@ PYBIND11_MODULE(core_noavx, m) { t = fluid.LoDTensor() t.set(np.ndarray([5, 30]), fluid.CPUPlace()) )DOC") -#endif + .def("shape", [](Tensor &self) { return vectorize(self.dims()); }, R"DOC( Return the shape of LoDTensor. diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 08e43bf24ce1a6863f13b6334f9b3272e4414ff5..2aae3e0f8374b8214044e6b9d9f59de69553a9b0 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -30,9 +30,81 @@ limitations under the License. */ namespace py = pybind11; +namespace pybind11 { +namespace detail { + +// Note: use same enum number of float16 in numpy. +// import numpy as np +// print np.dtype(np.float16).num # 23 +constexpr int NPY_FLOAT16_ = 23; + +// Note: Since float16 is not a builtin type in C++, we register +// paddle::platform::float16 as numpy.float16. +// Ref: https://github.com/pybind/pybind11/issues/1776 +template <> +struct npy_format_descriptor { + static py::dtype dtype() { + handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16_); + return reinterpret_borrow(ptr); + } + static std::string format() { + // Note: "e" represents float16. + // Details at: + // https://docs.python.org/3/library/struct.html#format-characters. + return "e"; + } + static PYBIND11_DESCR name() { return _("float16"); } +}; + +} // namespace detail +} // namespace pybind11 + namespace paddle { namespace pybind { +namespace details { + +template +struct ValidDTypeToPyArrayChecker { + static constexpr bool kValue = false; +}; + +#define DECLARE_VALID_DTYPE_TO_PY_ARRAY(type) \ + template <> \ + struct ValidDTypeToPyArrayChecker { \ + static constexpr bool kValue = true; \ + } + +DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(float); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(double); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(int8_t); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(uint8_t); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(int); +DECLARE_VALID_DTYPE_TO_PY_ARRAY(int64_t); + +inline std::string TensorDTypeToPyDTypeStr( + framework::proto::VarType::Type type) { +#define TENSOR_DTYPE_TO_PY_DTYPE(T, proto_type) \ + if (type == proto_type) { \ + if (std::is_same::value) { \ + return "e"; \ + } else { \ + constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker::kValue; \ + PADDLE_ENFORCE_EQ(kIsValidDType, true, \ + "This type of tensor cannot be expose to Python"); \ + return py::format_descriptor::format(); \ + } \ + } + + _ForEachDataType_(TENSOR_DTYPE_TO_PY_DTYPE); +#undef TENSOR_DTYPE_TO_PY_DTYPE + PADDLE_THROW("Unsupported data type %d", static_cast(type)); +} + +} // namespace details + template T TensorGetElement(const framework::Tensor &self, size_t offset) { PADDLE_ENFORCE_LT(offset, self.numel()); @@ -65,6 +137,71 @@ void TensorSetElement(framework::Tensor *self, size_t offset, T elem) { } } +template +void SetTensorFromPyArrayT( + framework::Tensor *self, + py::array_t array, P place) { + std::vector dims; + dims.reserve(array.ndim()); + for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) { + dims.push_back(static_cast(array.shape()[i])); + } + self->Resize(framework::make_ddim(dims)); + auto dst = self->mutable_data(place); + + if (paddle::platform::is_cpu_place(place)) { + std::memcpy(dst, array.data(), array.nbytes()); + } else { +#ifdef PADDLE_WITH_CUDA + if (paddle::platform::is_cuda_pinned_place(place)) { + std::memcpy(dst, array.data(), array.nbytes()); + } else if (paddle::platform::is_gpu_place(place)) { + paddle::platform::GpuMemcpySync(dst, array.data(), array.nbytes(), + cudaMemcpyHostToDevice); + } else { + PADDLE_THROW( + "Incompatible place type: Tensor.set() supports CPUPlace, CUDAPlace " + "and CUDAPinnedPlace, but got %s!", + place); + } +#else + PADDLE_THROW("Not supported GPU, please compile WITH_GPU option"); +#endif + } +} + +template +void SetTensorFromPyArray(framework::Tensor *self, pybind11::array array, + P place) { + if (py::isinstance>(array)) { + SetTensorFromPyArrayT(self, array, place); + } else if (py::isinstance>(array)) { + SetTensorFromPyArrayT(self, array, place); + } else if (py::isinstance>(array)) { + SetTensorFromPyArrayT(self, array, place); + } else if (py::isinstance>(array)) { + SetTensorFromPyArrayT(self, array, place); + } else if (py::isinstance>(array)) { + SetTensorFromPyArrayT(self, array, place); + } else if (py::isinstance>(array)) { + SetTensorFromPyArrayT(self, array, place); + } else if (py::isinstance>(array)) { + SetTensorFromPyArrayT(self, array, place); + } else if (py::isinstance>(array)) { + // TODO(cql): temporary keeping uint16, should be depracated later + SetTensorFromPyArrayT(self, array, place); + } else if (py::isinstance>(array)) { + SetTensorFromPyArrayT(self, array, place); + } else { + PADDLE_THROW( + "Incompatible data or style type: tensor.set() supports bool, float16, " + "float32, " + "float64, " + "int8, int32, int64 and uint8, uint16, but got %s!", + array.dtype()); + } +} + template void PyCPUTensorSetFromArray( framework::Tensor *self, @@ -96,7 +233,6 @@ inline void PyCPUTensorSetFromArray( for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) { dims.push_back(static_cast(array.shape()[i])); } - self->Resize(framework::make_ddim(dims)); auto *dst = self->mutable_data(place); std::memcpy(dst, array.data(), sizeof(uint16_t) * array.size()); @@ -361,7 +497,6 @@ void PyCUDATensorSetFromArray( for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) { dims.push_back(static_cast(array.shape()[i])); } - self->Resize(framework::make_ddim(dims)); auto *dst = self->mutable_data(place); paddle::platform::GpuMemcpySync(dst, array.data(), sizeof(T) * array.size(), @@ -428,49 +563,6 @@ inline void PyCUDAPinnedTensorSetFromArray( } #endif -namespace details { - -template -struct ValidDTypeToPyArrayChecker { - static constexpr bool kValue = false; -}; - -#define DECLARE_VALID_DTYPE_TO_PY_ARRAY(type) \ - template <> \ - struct ValidDTypeToPyArrayChecker { \ - static constexpr bool kValue = true; \ - } - -DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16); -DECLARE_VALID_DTYPE_TO_PY_ARRAY(float); -DECLARE_VALID_DTYPE_TO_PY_ARRAY(double); -DECLARE_VALID_DTYPE_TO_PY_ARRAY(bool); -DECLARE_VALID_DTYPE_TO_PY_ARRAY(int8_t); -DECLARE_VALID_DTYPE_TO_PY_ARRAY(uint8_t); -DECLARE_VALID_DTYPE_TO_PY_ARRAY(int); -DECLARE_VALID_DTYPE_TO_PY_ARRAY(int64_t); - -inline std::string TensorDTypeToPyDTypeStr( - framework::proto::VarType::Type type) { -#define TENSOR_DTYPE_TO_PY_DTYPE(T, proto_type) \ - if (type == proto_type) { \ - if (std::is_same::value) { \ - return "e"; \ - } else { \ - constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker::kValue; \ - PADDLE_ENFORCE(kIsValidDType, \ - "This type of tensor cannot be expose to Python"); \ - return py::format_descriptor::format(); \ - } \ - } - - _ForEachDataType_(TENSOR_DTYPE_TO_PY_DTYPE); -#undef TENSOR_DTYPE_TO_PY_DTYPE - PADDLE_THROW("Unsupported data type %d", static_cast(type)); -} - -} // namespace details - inline py::array TensorToPyArray(const framework::Tensor &tensor) { if (!tensor.IsInitialized()) { return py::array(); diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 85b8dbfa18174ec809109bb61a5b938057d19004..5a3774c8993dc703007701049c6d96d24144d4ba 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -199,8 +199,6 @@ def to_variable(value, block=None, name=None): stop_gradient=True) var = py_var._ivar.value() tensor = var.get_tensor() - if value.dtype == np.float16: - value = value.view(np.uint16) tensor.set(value, framework._current_expected_place()) return py_var elif isinstance(value, framework.Variable): diff --git a/python/paddle/fluid/tests/unittests/gradient_checker.py b/python/paddle/fluid/tests/unittests/gradient_checker.py index 644a9a92ab9ea806c55e2bdfceb1b246e80cd691..d3e285d3a6ed5903409f2947d0ef9fedaace9de5 100644 --- a/python/paddle/fluid/tests/unittests/gradient_checker.py +++ b/python/paddle/fluid/tests/unittests/gradient_checker.py @@ -64,7 +64,7 @@ def _set_item(t, i, e, np_dtype): shape = np_t.shape np_t = np_t.flatten() np_t[i] = e - np_t = np_t.reshape(shape).view(np.uint16) + np_t = np_t.reshape(shape) t.set(np_t, place) elif np_dtype == np.float32: t._set_float_element(i, e) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index aed0008350be7ce4e93e75ee1a5aeb5f75e71175..b78679f0e4b604046e01769541ddd65e26a35ef8 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -99,7 +99,7 @@ def get_numeric_gradient(place, shape = numpy_tensor.shape numpy_tensor = numpy_tensor.flatten() numpy_tensor[i] = e - numpy_tensor = numpy_tensor.reshape(shape).view(np.uint16) + numpy_tensor = numpy_tensor.reshape(shape) tensor.set(numpy_tensor, place) elif tensor_to_check_dtype == np.float32: tensor._set_float_element(i, e) @@ -155,11 +155,6 @@ class OpTest(unittest.TestCase): if not self.call_once: self.call_once = True self.dtype = data_type - # See the comment of np_dtype_to_fluid_dtype - # If the input type is uint16, we assume use float16 - # for lodtensor dtype. - if self.dtype == np.uint16: - self.dtype == np.float16 def infer_dtype_from_inputs_outputs(self, inputs, outputs): def infer_dtype(numpy_dict): @@ -188,25 +183,19 @@ class OpTest(unittest.TestCase): for name, np_value in self.inputs[var_name]: tensor = core.LoDTensor() if isinstance(np_value, tuple): - tensor.set( - OpTest.np_value_to_fluid_value(np_value[0]), place) + tensor.set(np_value[0], place) tensor.set_recursive_sequence_lengths(np_value[1]) else: - tensor.set( - OpTest.np_value_to_fluid_value(np_value), place) + tensor.set(np_value, place) feed_map[name] = tensor else: tensor = core.LoDTensor() if isinstance(self.inputs[var_name], tuple): - tensor.set( - OpTest.np_value_to_fluid_value(self.inputs[var_name][ - 0]), place) + tensor.set(self.inputs[var_name][0], place) tensor.set_recursive_sequence_lengths(self.inputs[var_name][ 1]) else: - tensor.set( - OpTest.np_value_to_fluid_value(self.inputs[var_name]), - place) + tensor.set(self.inputs[var_name], place) feed_map[var_name] = tensor return feed_map @@ -978,39 +967,14 @@ class OpTest(unittest.TestCase): @staticmethod def np_dtype_to_fluid_dtype(input): - """Change the dtype of float16 numpy array - - numpy float16 is binded to paddle::platform::float16 - in tensor_py.h via the help of uint16 data type since - the internal memory representation of float16 is - uint16_t in paddle and np.uint16 in numpy, which are - themselves binded together by pybind. - - Args: - input: input numpy array - - Returns: - input: The dtype of input will be changed to np.uint16 if - it is originally np.float16, such that the internal memory - of input will be reinterpreted as of dtype np.uint16. - """ - if input.dtype == np.float16: - input.dtype = np.uint16 return input @staticmethod def fluid_dtype_to_np_dtype(self, dtype): - """ - See above, convert the dtype to normal type. - """ - if dtype == np.uint16: - dtype = np.float16 return dtype @staticmethod def np_value_to_fluid_value(input): - if input.dtype == np.float16: - input = input.view(np.uint16) return input def _get_gradient(self, diff --git a/python/paddle/fluid/tests/unittests/test_cast_op.py b/python/paddle/fluid/tests/unittests/test_cast_op.py index 53f7df60b8a541fe18946479a84571fbef5d63f1..4cd0966dca10dbe0c31d5ed648938e5eee12f58c 100644 --- a/python/paddle/fluid/tests/unittests/test_cast_op.py +++ b/python/paddle/fluid/tests/unittests/test_cast_op.py @@ -43,8 +43,7 @@ class TestCastOp1(op_test.OpTest): class TestCastOp2(op_test.OpTest): def setUp(self): ipt = np.random.random(size=[10, 10]) - # numpy float16 is binded to fluid float16 via uint16 - self.inputs = {'X': ipt.astype('float16').view(np.uint16)} + self.inputs = {'X': ipt.astype('float16')} self.outputs = {'Out': ipt.astype('float32')} self.attrs = { 'in_dtype': int(core.VarDesc.VarType.FP16), diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py index 8fe814dc50d486c8a59c74f965f7e9c5e9b40d7c..7cd27e2c89c326f80abb2769dc0dc6acd048dd08 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py @@ -132,10 +132,9 @@ class TestFakeQuantizeRangeAbsMaxOp2(OpTest): } x = (np.random.random((8, 16, 7, 7)) - 0.5) * 10 x = x.astype("float32") - scale = np.max(np.abs(x)).astype("float32") - 1.0 + scale = np.array([np.max(np.abs(x)).astype("float32") - 1.0]) out_scales = np.zeros(self.attrs['window_size']).astype("float32") out_scales[0] = scale - self.inputs = { 'X': x, 'Iter': np.zeros(1).astype("int64"), diff --git a/python/paddle/fluid/tests/unittests/test_mix_precision_all_reduce_fuse.py b/python/paddle/fluid/tests/unittests/test_mix_precision_all_reduce_fuse.py index 5ccf855ebc3604389fa6e8b30367b040978c3ed4..a3fa84c224e4f89c9b30bb714fb2468180af1e6f 100644 --- a/python/paddle/fluid/tests/unittests/test_mix_precision_all_reduce_fuse.py +++ b/python/paddle/fluid/tests/unittests/test_mix_precision_all_reduce_fuse.py @@ -71,7 +71,7 @@ class TestResnet(TestParallelExecutorBase): def check_model(self, use_cuda): img, label = init_data( batch_size=batch_size, img_shape=img_shape, label_range=9) - img = np.float16(img).view(np.uint16) + img = np.float16(img) feed_dict = {"image": img, "label": label} TestParallelExecutorBase.check_network_convergence( diff --git a/python/paddle/fluid/tests/unittests/test_mse_loss.py b/python/paddle/fluid/tests/unittests/test_mse_loss.py index 64b4004e4becbc24fa4533f8797f4057b6ae43ce..4e8d9c4955840e5ed0ea7b1862da5557e03e872f 100644 --- a/python/paddle/fluid/tests/unittests/test_mse_loss.py +++ b/python/paddle/fluid/tests/unittests/test_mse_loss.py @@ -34,16 +34,14 @@ class TestMseLoss(unittest.TestCase): input_var = layers.create_tensor(dtype="float32", name="input") label_var = layers.create_tensor(dtype="float32", name="label") - layers.assign(input=input_val, output=input_var) - layers.assign(input=label_val, output=label_var) output = layers.mse_loss(input=input_var, label=label_var) for use_cuda in ([False, True] if core.is_compiled_with_cuda() else [False]): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = Executor(place) result = exe.run(fluid.default_main_program(), - feed={"input": input_var, - "label": label_var}, + feed={"input": input_val, + "label": label_val}, fetch_list=[output]) self.assertTrue(np.isclose(np_result, result).all()) diff --git a/python/paddle/fluid/tests/unittests/test_npair_loss_op.py b/python/paddle/fluid/tests/unittests/test_npair_loss_op.py index d1a015a16e46c38be8d3c8255d1d07cc6aa31572..de9a4366342c219d02c61fcfed84dfac05cbcde2 100644 --- a/python/paddle/fluid/tests/unittests/test_npair_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_npair_loss_op.py @@ -59,6 +59,7 @@ class TestNpairLossOp(unittest.TestCase): place = core.CPUPlace() exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) + embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32) embeddings_positive = np.random.rand(num_data, @@ -71,21 +72,29 @@ class TestNpairLossOp(unittest.TestCase): row_labels, l2_reg=reg_lambda) - anc = fluid.layers.create_tensor( - dtype='float32', persistable=True, name='anc') - pos = fluid.layers.create_tensor( - dtype='float32', persistable=True, name='pos') - lab = fluid.layers.create_tensor( - dtype='float32', persistable=True, name='lab') - fluid.layers.assign(input=embeddings_anchor, output=anc) - fluid.layers.assign(input=embeddings_positive, output=pos) - fluid.layers.assign(input=row_labels, output=lab) + anc = fluid.layers.data( + dtype='float32', + name='anc', + shape=embeddings_anchor.shape, + append_batch_size=False) + pos = fluid.layers.data( + dtype='float32', + name='pos', + shape=embeddings_positive.shape, + append_batch_size=False) + lab = fluid.layers.data( + dtype='float32', + name='lab', + shape=row_labels.shape, + append_batch_size=False) npair_loss_op = fluid.layers.npair_loss( anchor=anc, positive=pos, labels=lab, l2_reg=reg_lambda) - out_tensor = exe.run(feed={'anc': anc, - 'pos': pos, - 'lab': lab}, + out_tensor = exe.run(feed={ + 'anc': embeddings_anchor, + 'pos': embeddings_positive, + 'lab': row_labels + }, fetch_list=[npair_loss_op.name]) self.__assert_close( diff --git a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py index d37731146d9c431bb6a0c333149ac62a0c4efd3b..58929374797265ba2a900bec1617e523c422458d 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py @@ -128,10 +128,7 @@ class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp): loss = cross_entropy(softmax, labels, self.soft_label, self.axis) - self.inputs = { - "Logits": logits.astype(self.dtype).view(np.uint16), - "Label": labels - } + self.inputs = {"Logits": logits.astype(self.dtype), "Label": labels} self.outputs = { "Softmax": softmax.astype(self.dtype), "Loss": loss.astype(self.dtype) diff --git a/python/paddle/fluid/tests/unittests/test_square_error_cost.py b/python/paddle/fluid/tests/unittests/test_square_error_cost.py index 056bbfcd5302b780a8ffe3cdbda426fdc282784a..b83c0ba63450874ea4a2f25ab369bf75d84d8d77 100644 --- a/python/paddle/fluid/tests/unittests/test_square_error_cost.py +++ b/python/paddle/fluid/tests/unittests/test_square_error_cost.py @@ -33,9 +33,6 @@ class TestSquareErrorCost(unittest.TestCase): input_var = layers.create_tensor(dtype="float32", name="input") label_var = layers.create_tensor(dtype="float32", name="label") - - layers.assign(input=input_val, output=input_var) - layers.assign(input=label_val, output=label_var) output = layers.square_error_cost(input=input_var, label=label_var) for use_cuda in ([False, True] @@ -44,8 +41,8 @@ class TestSquareErrorCost(unittest.TestCase): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = Executor(place) result = exe.run(fluid.default_main_program(), - feed={"input": input_var, - "label": label_var}, + feed={"input": input_val, + "label": label_val}, fetch_list=[output]) self.assertTrue(np.isclose(np_result, result).all()) diff --git a/python/paddle/fluid/tests/unittests/testsuite.py b/python/paddle/fluid/tests/unittests/testsuite.py index c4eb26893cd1faac72ac06c70a68c52f26b39182..c92d9a429b6c76e1f37cf0b4f044672bbdbd1abf 100644 --- a/python/paddle/fluid/tests/unittests/testsuite.py +++ b/python/paddle/fluid/tests/unittests/testsuite.py @@ -68,11 +68,6 @@ def create_op(scope, op_type, inputs, outputs, attrs, cache_list=None): def set_input(scope, op, inputs, place): - def np_value_to_fluid_value(input): - if input.dtype == np.float16: - input = input.view(np.uint16) - return input - def __set_input__(var_name, var): if isinstance(var, tuple) or isinstance(var, np.ndarray): tensor = scope.find_var(var_name).get_tensor() @@ -80,7 +75,7 @@ def set_input(scope, op, inputs, place): tensor.set_recursive_sequence_lengths(var[1]) var = var[0] tensor._set_dims(var.shape) - tensor.set(np_value_to_fluid_value(var), place) + tensor.set(var, place) elif isinstance(var, float): scope.find_var(var_name).set_float(var) elif isinstance(var, int): @@ -121,16 +116,6 @@ def append_input_output(block, op_proto, np_list, is_input, dtype): if is_input: shape = list(np_value.shape) lod_level = 0 - # NOTE(dzhwinter): type hacking - # numpy float16 is binded to paddle::platform::float16 - # in tensor_py.h via the help of uint16 datatype. Because - # the internal memory representation of float16 is - # actually uint16_t in paddle. So we use np.uint16 in numpy for - # raw memory, it can pass through the pybind. So in the testcase, - # we feed data use data.view(uint16), but the dtype is float16 in fact. - # The data.view(uint16) means do not cast the data type, but process data as the uint16 - if dtype == np.uint16: - dtype = np.float16 return block.create_var( dtype=dtype, shape=shape, lod_level=lod_level, name=name)