From 100a0750438b0961a05f4c458323a5cd701e7746 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?= Date: Thu, 13 Oct 2022 10:30:34 +0800 Subject: [PATCH] [Cherry-pick] Add fp16 dtype support for set_value op (#46906) Fix set_value failure when source tensor is fp16 Dtype and destiny value is a number (dev PR link:#46801) --- paddle/fluid/operators/set_value_op.cc | 5 +++- paddle/fluid/pybind/eager_method.cc | 6 ++++- paddle/fluid/pybind/imperative.cc | 6 ++++- paddle/phi/ops/compat/set_value_sig.cc | 15 ++++++++++++ .../tests/unittests/test_set_value_op.py | 23 ++++++++++++++++++- python/paddle/fluid/variable_index.py | 5 +++- 6 files changed, 55 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/set_value_op.cc b/paddle/fluid/operators/set_value_op.cc index 074642e1b02..12cdccc5031 100644 --- a/paddle/fluid/operators/set_value_op.cc +++ b/paddle/fluid/operators/set_value_op.cc @@ -104,7 +104,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker { framework::proto::VarType::INT32, framework::proto::VarType::INT64, framework::proto::VarType::FP32, - framework::proto::VarType::FP64}) + framework::proto::VarType::FP64, + framework::proto::VarType::FP16}) .SetDefault(framework::proto::VarType::FP32); AddAttr>( "axes", "(list) Axes that `starts` and `ends` apply to."); @@ -135,6 +136,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker { .SetDefault({}); AddAttr>("fp64_values", "Store the float64 values.") .SetDefault({}); + AddAttr>("fp16_values", "Store the float16 values.") + .SetDefault({}); AddAttr>("shape", "(vector) Shape of values.") .SetDefault({}); diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index b384544bb60..c782b4df585 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -1151,11 +1151,15 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self, } else if (self->tensor.dtype() == paddle::experimental::DataType::BOOL) { attrs["bool_values"] = std::vector{value_obj_tmp.cast()}; + } else if (self->tensor.dtype() == + paddle::experimental::DataType::FLOAT16) { + attrs["fp16_values"] = + std::vector{value_obj_tmp.cast()}; } else { PADDLE_THROW(platform::errors::InvalidArgument( "When assign a value to a paddle.Tensor, " "the data type of the paddle.Tensor must be bool, " - "float32, int32 or int64, " + "float32, int32, int64 or float16, " "please check the type of tensor.")); } attrs["shape"] = std::vector{1}; diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 0fb00e911d9..26b3b307ef8 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -964,11 +964,15 @@ void BindImperative(py::module *m_ptr) { framework::proto::VarType::BOOL) { attrs["bool_values"] = std::vector{value_obj.cast()}; + } else if (self->DataType() == + framework::proto::VarType::FP16) { + attrs["fp16_values"] = + std::vector{value_obj.cast()}; } else { PADDLE_THROW(platform::errors::InvalidArgument( "When assign a value to a paddle.Tensor, " "the data type of the paddle.Tensor must be bool, " - "float32, int32 or int64, " + "float32, int32, int64 or float16, " "please check the type of tensor.")); } attrs["shape"] = std::vector{1}; diff --git a/paddle/phi/ops/compat/set_value_sig.cc b/paddle/phi/ops/compat/set_value_sig.cc index 6ff94a6e263..8c98606600a 100644 --- a/paddle/phi/ops/compat/set_value_sig.cc +++ b/paddle/phi/ops/compat/set_value_sig.cc @@ -724,6 +724,21 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) { "shape", "bool_values"}, {"Out"}); + } else if (ctx.HasAttr("fp16_values") && + !paddle::any_cast>( + ctx.Attr("fp16_values")) + .empty()) { + return KernelSignature("set_value", + {"Input"}, + {"starts", + "ends", + "steps", + "axes", + "decrease_axes", + "none_axes", + "shape", + "fp16_values"}, + {"Out"}); } } } diff --git a/python/paddle/fluid/tests/unittests/test_set_value_op.py b/python/paddle/fluid/tests/unittests/test_set_value_op.py index 59ccff3973f..fad47fc158c 100644 --- a/python/paddle/fluid/tests/unittests/test_set_value_op.py +++ b/python/paddle/fluid/tests/unittests/test_set_value_op.py @@ -576,6 +576,28 @@ create_test_value_int64(TestSetValueItemSlice3) create_test_value_int64(TestSetValueItemSlice4) +def create_test_value_fp16(parent): + + class TestValueInt(parent): + + def set_value(self): + self.value = 3.7 + + def set_dtype(self): + self.dtype = "float16" + + cls_name = "{0}_{1}".format(parent.__name__, "Valuefp16") + TestValueInt.__name__ = cls_name + globals()[cls_name] = TestValueInt + + +create_test_value_fp16(TestSetValueItemInt) +create_test_value_fp16(TestSetValueItemSlice) +create_test_value_fp16(TestSetValueItemSlice2) +create_test_value_fp16(TestSetValueItemSlice3) +create_test_value_fp16(TestSetValueItemSlice4) + + def create_test_value_fp32(parent): class TestValueInt(parent): @@ -1015,7 +1037,6 @@ class TestError(TestSetValueBase): paddle.enable_static() with paddle.static.program_guard(self.program): self._value_type_error() - self._dtype_error() self._step_error() self._bool_list_error() self._bool_tensor_error() diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py index a0a778759b0..1a13e830419 100644 --- a/python/paddle/fluid/variable_index.py +++ b/python/paddle/fluid/variable_index.py @@ -730,10 +730,13 @@ def _setitem_impl_(var, item, value): elif dtype == core.VarDesc.VarType.INT64: value_name = "int64_values" values = [int(v) for v in value.flat] + elif dtype == core.VarDesc.VarType.FP16: + value_name = "fp16_values" + values = [float(v) for v in value.flat] else: raise TypeError( "When assign a numpy.ndarray, integer or float to a paddle.Tensor, " - "the data type of the paddle.Tensor must be bool, float32, int32 or int64, but " + "the data type of the paddle.Tensor must be bool, float32, int32, int64 or float16, but " "received %s." % convert_dtype(dtype)) attrs[value_name] = values attrs["shape"] = shape -- GitLab