From 2341ed5ea233b4e066ab121c87be864dca89846a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Tue, 11 Oct 2022 13:09:08 +0800 Subject: [PATCH] Fix set_value failure when source tensor is fp16 Dtype (#46801) * add fp16 data type for set_value * cancel flip modification * add fp16 dtype support for set_value --- paddle/fluid/operators/set_value_op.cc | 5 +++- paddle/fluid/pybind/eager_method.cc | 6 ++++- paddle/fluid/pybind/imperative.cc | 6 ++++- paddle/phi/ops/compat/set_value_sig.cc | 13 +++++++++++ .../tests/unittests/test_set_value_op.py | 23 ++++++++++++++++++- python/paddle/fluid/variable_index.py | 5 +++- 6 files changed, 53 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/set_value_op.cc b/paddle/fluid/operators/set_value_op.cc index b1fe2dedcb2..86049bec1eb 100644 --- a/paddle/fluid/operators/set_value_op.cc +++ b/paddle/fluid/operators/set_value_op.cc @@ -104,7 +104,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker { framework::proto::VarType::INT32, framework::proto::VarType::INT64, framework::proto::VarType::FP32, - framework::proto::VarType::FP64}) + framework::proto::VarType::FP64, + framework::proto::VarType::FP16}) .SetDefault(framework::proto::VarType::FP32); AddAttr>( "axes", "(list) Axes that `starts` and `ends` apply to."); @@ -135,6 +136,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker { .SetDefault({}); AddAttr>("fp64_values", "Store the float64 values.") .SetDefault({}); + AddAttr>("fp16_values", "Store the float16 values.") + .SetDefault({}); AddAttr>("shape", "(vector) Shape of values.") .SetDefault({}); diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 30e2ee25c83..1bce2d48e58 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -1217,11 +1217,15 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self, } else if (self->tensor.dtype() == paddle::experimental::DataType::BOOL) { attrs["bool_values"] = std::vector{value_obj_tmp.cast()}; + } else if (self->tensor.dtype() == + paddle::experimental::DataType::FLOAT16) { + attrs["fp16_values"] = + std::vector{value_obj_tmp.cast()}; } else { PADDLE_THROW(platform::errors::InvalidArgument( "When assign a value to a paddle.Tensor, " "the data type of the paddle.Tensor must be bool, " - "float32, int32 or int64, " + "float32, int32, int64 or float16, " "please check the type of tensor.")); } attrs["shape"] = std::vector{1}; diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 420a9839474..1c0c6488bda 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -964,11 +964,15 @@ void BindImperative(py::module *m_ptr) { framework::proto::VarType::BOOL) { attrs["bool_values"] = std::vector{value_obj.cast()}; + } else if (self->DataType() == + framework::proto::VarType::FP16) { + attrs["fp16_values"] = + std::vector{value_obj.cast()}; } else { PADDLE_THROW(platform::errors::InvalidArgument( "When assign a value to a paddle.Tensor, " "the data type of the paddle.Tensor must be bool, " - "float32, int32 or int64, " + "float32, int32, int64 or float16, " "please check the type of tensor.")); } attrs["shape"] = std::vector{1}; diff --git a/paddle/phi/ops/compat/set_value_sig.cc b/paddle/phi/ops/compat/set_value_sig.cc index 6ff94a6e263..c967f139939 100644 --- a/paddle/phi/ops/compat/set_value_sig.cc +++ b/paddle/phi/ops/compat/set_value_sig.cc @@ -724,6 +724,19 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "bool_values"}, {"Out"}); + } else if (ctx.HasAttr("fp16_values")) { + // NOTE(LiuYang):Here any_cast doesn't support fp16 values. + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp16_values"}, + {"Out"}); } } } diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py index 8b74e209e1e..d47db363dcd 100644 --- a/python/paddle/fluid/tests/unittests/test_set_value_op.py +++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py @@ -574,6 +574,28 @@ create_test_value_int64(TestSetValueItemSlice3) create_test_value_int64(TestSetValueItemSlice4) +def create_test_value_fp16(parent): + + class TestValueInt(parent): + + def set_value(self): + self.value = 3.7 + + def set_dtype(self): + self.dtype = "float16" + + cls_name = "{0}_{1}".format(parent.__name__, "Valuefp16") + TestValueInt.__name__ = cls_name + globals()[cls_name] = TestValueInt + + +create_test_value_fp16(TestSetValueItemInt) +create_test_value_fp16(TestSetValueItemSlice) +create_test_value_fp16(TestSetValueItemSlice2) +create_test_value_fp16(TestSetValueItemSlice3) +create_test_value_fp16(TestSetValueItemSlice4) + + def create_test_value_fp32(parent): class TestValueInt(parent): @@ -1013,7 +1035,6 @@ class TestError(TestSetValueBase): paddle.enable_static() with paddle.static.program_guard(self.program): self._value_type_error() - self._dtype_error() self._step_error() self._bool_list_error() self._bool_tensor_error() diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index f16de1ce060..3d67cd7e230 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -730,10 +730,13 @@ def _setitem_impl_(var, item, value): elif dtype == core.VarDesc.VarType.INT64: value_name = "int64_values" values = [int(v) for v in value.flat] + elif dtype == core.VarDesc.VarType.FP16: + value_name = "fp16_values" + values = [float(v) for v in value.flat] else: raise TypeError( "When assign a numpy.ndarray, integer or float to a paddle.Tensor, " - "the data type of the paddle.Tensor must be bool, float32, int32 or int64, but " + "the data type of the paddle.Tensor must be bool, float32, int32, int64 or float16, but " "received %s." % convert_dtype(dtype)) attrs[value_name] = values attrs["shape"] = shape -- GitLab