From 0c8fde7dce86ca19d48bc303760cec6079fcd42a Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Wed, 1 Aug 2018 14:49:01 +0800 Subject: [PATCH] "cherry picked cpp tests" (#12182) * "cherry picked cpp tests" * "cherry picked" * "cherry picked tests" * "merge develop branch" --- paddle/fluid/framework/CMakeLists.txt | 1 + paddle/fluid/framework/data_type.cc | 4 +- paddle/fluid/framework/data_type_test.cc | 40 +++++++++++ paddle/fluid/framework/op_kernel_type_test.cc | 7 ++ paddle/fluid/framework/operator.cc | 17 +++++ paddle/fluid/framework/tensor_test.cc | 15 ++++ .../paddle/fluid/tests/unittests/op_test.py | 72 ++++++++++++++++--- .../fluid/tests/unittests/test_hsigmoid_op.py | 2 + .../paddle/fluid/tests/unittests/testsuite.py | 26 ++++--- 9 files changed, 163 insertions(+), 21 deletions(-) create mode 100644 paddle/fluid/framework/data_type_test.cc diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 139411f3e0..6440607dbe 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -7,6 +7,7 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3 boost) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) cc_library(data_type SRCS data_type.cc DEPS framework_proto ddim device_context) +cc_test(data_type_test SRCS data_type_test.cc DEPS data_type place tensor) if(WITH_GPU) nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type device_context) else() diff --git a/paddle/fluid/framework/data_type.cc b/paddle/fluid/framework/data_type.cc index 60382faffb..1a9ce746ea 100644 --- a/paddle/fluid/framework/data_type.cc +++ b/paddle/fluid/framework/data_type.cc @@ -17,6 +17,8 @@ #include #include +using float16 = paddle::platform::float16; + namespace paddle { namespace framework { @@ -53,7 +55,7 @@ static DataTypeMap* InitDataTypeMap() { RegisterType(retv, proto_type, #cc_type) // NOTE: Add your customize type here. - RegType(platform::float16, proto::VarType::FP16); + RegType(float16, proto::VarType::FP16); RegType(float, proto::VarType::FP32); RegType(double, proto::VarType::FP64); RegType(int, proto::VarType::INT32); diff --git a/paddle/fluid/framework/data_type_test.cc b/paddle/fluid/framework/data_type_test.cc new file mode 100644 index 0000000000..54c41c55ba --- /dev/null +++ b/paddle/fluid/framework/data_type_test.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "paddle/fluid/framework/data_type.h" + +#include +#include "gtest/gtest.h" +#include "paddle/fluid/framework/tensor.h" + +TEST(DataType, float16) { + using paddle::framework::Tensor; + using paddle::platform::CPUPlace; + using paddle::platform::float16; + namespace f = paddle::framework; + f::proto::VarType::Type dtype = f::proto::VarType::FP16; + + Tensor tensor; + CPUPlace cpu; + tensor.mutable_data(cpu, f::ToTypeIndex(dtype)); + + // test fp16 tensor + EXPECT_EQ(tensor.type(), std::type_index(typeid(float16))); + + // test fp16 size + EXPECT_EQ(f::SizeOfType(f::ToTypeIndex(dtype)), 2u); + + // test debug info + std::string type = "float16"; + EXPECT_STREQ(f::DataTypeToString(dtype).c_str(), type.c_str()); +} diff --git a/paddle/fluid/framework/op_kernel_type_test.cc b/paddle/fluid/framework/op_kernel_type_test.cc index db95861c51..3e17a512ce 100644 --- a/paddle/fluid/framework/op_kernel_type_test.cc +++ b/paddle/fluid/framework/op_kernel_type_test.cc @@ -29,6 +29,13 @@ TEST(OpKernelType, ToString) { ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type), "data_type[float]:data_layout[NCHW]:place[CPUPlace]:library_type[" "CUDNN]"); + + using CUDAPlace = paddle::platform::CUDAPlace; + OpKernelType op_kernel_type2(DataType::FP16, CUDAPlace(0), DataLayout::kNCHW, + LibraryType::kCUDNN); + ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type2), + "data_type[float16]:data_layout[NCHW]:place[CUDAPlace(0)]:library_" + "type[CUDNN]"); } TEST(OpKernelType, Hash) { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 7c1c29fd9a..38c4297380 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -69,6 +69,21 @@ static DDim GetDims(const Scope& scope, const std::string& name, } } +static std::string GetDtype(const Scope& scope, const std::string& name) { + Variable* var = scope.FindVar(name); + if (var == nullptr) { + return ""; + } + if (var->IsType()) { + return DataTypeToString(ToDataType(var->Get().type())); + } else if (var->IsType()) { + return DataTypeToString( + ToDataType(var->Get().value().type())); + } else { + return ""; + } +} + static int GetRowSize(const Scope& scope, const std::string& name) { Variable* var = scope.FindVar(name); if (var == nullptr) { @@ -172,6 +187,8 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const { if (row_size >= 0) { ss << "[row_size=" << row_size << "]"; } + std::string dtype = GetDtype(*scope, input.second[i]); + ss << ":" << dtype; ss << "[" << GetDims(*scope, input.second[i], true) << "]"; ss << "(" << GetLoD(*scope, input.second[i]) << ")"; } diff --git a/paddle/fluid/framework/tensor_test.cc b/paddle/fluid/framework/tensor_test.cc index 0a1cb6d570..cb2061c06a 100644 --- a/paddle/fluid/framework/tensor_test.cc +++ b/paddle/fluid/framework/tensor_test.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/tensor.h" #include #include +#include "paddle/fluid/platform/float16.h" namespace framework = paddle::framework; namespace platform = paddle::platform; @@ -213,3 +214,17 @@ TEST(Tensor, Layout) { src.set_layout(framework::DataLayout::kAnyLayout); ASSERT_EQ(src.layout(), framework::DataLayout::kAnyLayout); } + +TEST(Tensor, FP16) { + using platform::float16; + framework::Tensor src; + float16* src_ptr = src.mutable_data({2, 3}, platform::CPUPlace()); + for (int i = 0; i < 2 * 3; ++i) { + src_ptr[i] = static_cast(i); + } + EXPECT_EQ(src.memory_size(), 2 * 3 * sizeof(float16)); + // EXPECT a human readable error message + // src.data(); + 
+  // Tensor holds the wrong type, it holds N6paddle8platform7float16E at
+  // [/paddle/Paddle/paddle/fluid/framework/tensor_impl.h:43]
+}
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 82b5e7cf0b..2ddfd47fe0 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -66,6 +66,10 @@ def get_numeric_gradient(place,
         tensor_to_check_dtype = np.float32
     elif tensor_to_check_dtype == core.VarDesc.VarType.FP64:
         tensor_to_check_dtype = np.float64
+    elif tensor_to_check_dtype == core.VarDesc.VarType.FP16:
+        tensor_to_check_dtype = np.float16
+        # set delta as np.float16; it is converted to float32/float64 automatically
+        delta = np.array(delta).astype(np.float16)
     else:
         raise ValueError("Not supported data type " + str(
             tensor_to_check_dtype))
@@ -73,13 +77,24 @@
     gradient_flat = np.zeros(shape=(tensor_size, ), dtype=tensor_to_check_dtype)
 
     def __get_elem__(tensor, i):
-        if tensor_to_check_dtype == np.float32:
+        if tensor_to_check_dtype == np.float16:
+            numpy_tensor = np.array(tensor).astype(np.float16)
+            numpy_tensor = numpy_tensor.flatten()
+            return numpy_tensor[i]
+        elif tensor_to_check_dtype == np.float32:
             return tensor._get_float_element(i)
         else:
             return tensor._get_double_element(i)
 
     def __set_elem__(tensor, i, e):
-        if tensor_to_check_dtype == np.float32:
+        if tensor_to_check_dtype == np.float16:
+            numpy_tensor = np.array(tensor).astype(np.float16)
+            shape = numpy_tensor.shape
+            numpy_tensor = numpy_tensor.flatten()
+            numpy_tensor[i] = e
+            numpy_tensor = numpy_tensor.reshape(shape).view(np.uint16)
+            tensor.set(numpy_tensor, place)
+        elif tensor_to_check_dtype == np.float32:
             tensor._set_float_element(i, e)
         else:
             tensor._set_double_element(i, e)
@@ -133,6 +148,11 @@ class OpTest(unittest.TestCase):
         if not self.call_once:
             self.call_once = True
             self.dtype = data_type
+            # See the comment of np_dtype_to_fluid_dtype:
+            # if the input type is uint16, we assume float16 is used
+            # for the lodtensor dtype.
+            if self.dtype == np.uint16:
+                self.dtype = np.float16
 
     def infer_dtype_from_inputs_outputs(self, inputs, outputs):
         def infer_dtype(numpy_dict):
@@ -161,19 +181,25 @@ class OpTest(unittest.TestCase):
                 for name, np_value in self.inputs[var_name]:
                     tensor = core.LoDTensor()
                     if isinstance(np_value, tuple):
-                        tensor.set(np_value[0], place)
+                        tensor.set(
+                            OpTest.np_value_to_fluid_value(np_value[0]), place)
                         tensor.set_recursive_sequence_lengths(np_value[1])
                     else:
-                        tensor.set(np_value, place)
+                        tensor.set(
+                            OpTest.np_value_to_fluid_value(np_value), place)
                     feed_map[name] = tensor
             else:
                 tensor = core.LoDTensor()
                 if isinstance(self.inputs[var_name], tuple):
-                    tensor.set(self.inputs[var_name][0], place)
+                    tensor.set(
+                        OpTest.np_value_to_fluid_value(self.inputs[var_name][
+                            0]), place)
                     tensor.set_recursive_sequence_lengths(self.inputs[var_name][
                         1])
                 else:
-                    tensor.set(self.inputs[var_name], place)
+                    tensor.set(
+                        OpTest.np_value_to_fluid_value(self.inputs[var_name]),
+                        place)
                 feed_map[var_name] = tensor
 
         return feed_map
@@ -307,13 +333,22 @@ class OpTest(unittest.TestCase):
                    np.allclose(
                        actual_t, expect_t, atol=atol),
                    "Output (" + out_name + ") has diff at " + str(place) +
-                   str(actual_t) + "\n" + str(expect_t))
+                   "\nExpect " + str(expect_t) + "\n" + "But Got " +
+                   str(actual_t))
                 if isinstance(expect, tuple):
                     self.assertListEqual(actual.recursive_sequence_lengths(),
                                          expect[1], "Output (" + out_name +
                                          ") has different lod at " + str(place))
 
     def _get_places(self):
+        if self.dtype == np.float16:
+            if core.is_compiled_with_cuda() and core.op_support_gpu(
+                    self.op_type):
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    return [place]
+                else:
+                    return []
         places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
             places.append(core.CUDAPlace(0))
@@ -344,9 +379,9 @@ class OpTest(unittest.TestCase):
         def err_msg():
             offset = np.argmax(diff_mat > max_relative_error)
             return ("%s Variable %s max gradient diff %f over limit %f, "
-                    "the first error element is %d, %f, %f") % (
-                        msg_prefix, name, max_diff, max_relative_error,
-                        offset, a.flatten()[offset], b.flatten()[offset])
+                    "the first error element is %d, expected %f, but got %f"
+                    ) % (msg_prefix, name, max_diff, max_relative_error,
+                         offset, a.flatten()[offset], b.flatten()[offset])
 
         self.assertLessEqual(max_diff, max_relative_error, err_msg())
 
@@ -435,6 +470,21 @@ class OpTest(unittest.TestCase):
             input.dtype = np.uint16
         return input
 
+    @staticmethod
+    def fluid_dtype_to_np_dtype(self, dtype):
+        """
+        See above, convert the dtype to the normal numpy type.
+        """
+        if dtype == np.uint16:
+            dtype = np.float16
+        return dtype
+
+    @staticmethod
+    def np_value_to_fluid_value(input):
+        if input.dtype == np.float16:
+            input = input.view(np.uint16)
+        return input
+
     def _get_gradient(self,
                       input_to_check,
                       place,
@@ -457,7 +507,7 @@ class OpTest(unittest.TestCase):
         if isinstance(place, fluid.CUDAPlace(0)):
             use_cuda = True
             executor = fluid.ParallelExecutor(
-                use_cuda=use_cuda, loss_name=loss.name, main_program=program)
+                use_cuda=use_cuda, loss_name=loss.name, main_program=prog)
         else:
             executor = Executor(place)
         return map(np.array,
diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
index d090960c84..daa5da8d95 100644
--- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
+++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
@@ -17,6 +17,8 @@ import numpy as np
 import math
 from op_test import OpTest
 
+np.random.seed(100)
+
 
 def find_latest_set(num):
     return 1 + int(math.floor(math.log(num, 2)))
diff --git a/python/paddle/fluid/tests/unittests/testsuite.py b/python/paddle/fluid/tests/unittests/testsuite.py
index 55c6e54906..910d9538b0 100644
--- a/python/paddle/fluid/tests/unittests/testsuite.py
+++ b/python/paddle/fluid/tests/unittests/testsuite.py
@@ -18,14 +18,6 @@ import paddle.fluid.core as core
 from paddle.fluid.op import Operator
 
 
-def as_lodtensor(np_array, lod, place):
-    tensor = core.LoDTensor()
-    tensor.set(np_value, place)
-    if lod is not None:
-        tensor.set_recursive_sequence_lengths(lod)
-    return tensor
-
-
 def create_op(scope, op_type, inputs, outputs, attrs):
     kwargs = dict()
 
@@ -69,6 +61,11 @@ def create_op(scope, op_type, inputs, outputs, attrs):
 
 
 def set_input(scope, op, inputs, place):
+    def np_value_to_fluid_value(input):
+        if input.dtype == np.float16:
+            input = input.view(np.uint16)
+        return input
+
     def __set_input__(var_name, var):
         if isinstance(var, tuple) or isinstance(var, np.ndarray):
             tensor = scope.find_var(var_name).get_tensor()
@@ -76,7 +73,7 @@ def set_input(scope, op, inputs, place):
                 tensor.set_recursive_sequence_lengths(var[1])
                 var = var[0]
             tensor._set_dims(var.shape)
-            tensor.set(var, place)
+            tensor.set(np_value_to_fluid_value(var), place)
         elif isinstance(var, float):
             scope.find_var(var_name).set_float(var)
         elif isinstance(var, int):
@@ -104,6 +101,7 @@ def append_input_output(block, op_proto, np_list, is_input, dtype):
         if name not in np_list:
             assert var_proto.intermediate, "{} not found".format(name)
         else:
+            # infer the dtype from the numpy value.
             np_value = np_list[name]
             if isinstance(np_value, tuple):
                 dtype = np_value[0].dtype
@@ -116,6 +114,16 @@ def append_input_output(block, op_proto, np_list, is_input, dtype):
         if is_input:
             shape = list(np_value.shape)
             lod_level = 0
+            # NOTE(dzhwinter): type hacking.
+            # numpy float16 is bound to paddle::platform::float16 in
+            # tensor_py.h with the help of the uint16 datatype, because the
+            # internal memory representation of float16 is actually uint16_t
+            # in paddle. So we use np.uint16 in numpy for the raw memory and
+            # it can pass through pybind. In the test cases we therefore feed
+            # data with data.view(uint16), although the dtype is in fact float16:
+            # data.view(uint16) does not cast the values, it reinterprets the raw memory as uint16.
+            if dtype == np.uint16:
+                dtype = np.float16
         return block.create_var(
             dtype=dtype, shape=shape, lod_level=lod_level, name=name)
-- 
GitLab
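
For reference, the float16 plumbing exercised by these tests can be reproduced with numpy alone. The sketch below is illustrative and not part of the changed files; it assumes nothing beyond numpy, and the array names are arbitrary. It shows why the tests feed half-precision data as uint16: ndarray.view() reinterprets the underlying buffer without casting, so the float16 bit pattern survives the trip through pybind and can be viewed back unchanged.

    import numpy as np

    # float16 values and their raw-bits view as uint16 share the same memory.
    x = np.arange(6, dtype=np.float16).reshape(2, 3)
    raw = x.view(np.uint16)        # reinterpret the buffer, no value cast
    assert raw.dtype == np.uint16
    assert raw.nbytes == x.nbytes  # still 2 bytes per element

    # Viewing the bits back as float16 recovers the original values exactly.
    assert np.array_equal(raw.view(np.float16), x)

This mirrors what the patch does in np_value_to_fluid_value and __set_elem__, which call view(np.uint16) before handing arrays to tensor.set(...), and in __get_elem__, which reads elements back through np.array(tensor).astype(np.float16).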