Unverified commit 100a0750, authored by 傅剑寒, committed by GitHub

[Cherry-pick] Add fp16 dtype support for set_value op (#46906)

Fix set_value failure when the source tensor has fp16 dtype and the assigned value is a plain number
(dev PR link: #46801)
Parent 0280c0b9
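In user-facing terms, the patch makes scalar assignment work on float16 tensors. A minimal dygraph repro sketch (assuming a Paddle build that includes this patch; without it, the assignment below raised the InvalidArgument error updated in the diffs):

import paddle

# Scalar assignment to a float16 tensor dispatches to the set_value op.
x = paddle.ones([2, 3], dtype="float16")
x[0, 0] = 3.7  # previously rejected: fp16 was not an accepted dtype
print(x[0, 0])  # stores 3.7 rounded to the nearest float16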
@@ -104,7 +104,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
               framework::proto::VarType::INT32,
               framework::proto::VarType::INT64,
               framework::proto::VarType::FP32,
-              framework::proto::VarType::FP64})
+              framework::proto::VarType::FP64,
+              framework::proto::VarType::FP16})
       .SetDefault(framework::proto::VarType::FP32);
   AddAttr<std::vector<int64_t>>(
       "axes", "(list<int64_t>) Axes that `starts` and `ends` apply to.");
@@ -135,6 +136,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
       .SetDefault({});
   AddAttr<std::vector<double>>("fp64_values", "Store the float64 values.")
       .SetDefault({});
+  AddAttr<std::vector<float>>("fp16_values", "Store the float16 values.")
+      .SetDefault({});
   AddAttr<std::vector<int64_t>>("shape", "(vector<int64_t>) Shape of values.")
       .SetDefault({});
...
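Note that fp16_values is declared as std::vector<float>: there is no dedicated float16 attribute type, so the scalar is carried as a 32-bit float and, presumably inside the kernel, narrowed to float16 when written into the tensor. The narrowing shows up as rounding; a quick numpy illustration (numpy is used here only for demonstration, it is not part of the patch):

import numpy as np

# float16 has a 10-bit mantissa, so 3.7 has no exact representation;
# the value that actually lands in an fp16 tensor is the nearest one:
print(float(np.float16(3.7)))  # 3.69921875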
@@ -1151,11 +1151,15 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
     } else if (self->tensor.dtype() ==
                paddle::experimental::DataType::BOOL) {
       attrs["bool_values"] = std::vector<int>{value_obj_tmp.cast<bool>()};
+    } else if (self->tensor.dtype() ==
+               paddle::experimental::DataType::FLOAT16) {
+      attrs["fp16_values"] =
+          std::vector<float>{value_obj_tmp.cast<float>()};
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "When assign a value to a paddle.Tensor, "
           "the data type of the paddle.Tensor must be bool, "
-          "float32, int32 or int64, "
+          "float32, int32, int64 or float16, "
           "please check the type of tensor."));
     }
     attrs["shape"] = std::vector<int64_t>{1};
...
@@ -964,11 +964,15 @@ void BindImperative(py::module *m_ptr) {
                      framework::proto::VarType::BOOL) {
             attrs["bool_values"] =
                 std::vector<int>{value_obj.cast<bool>()};
+          } else if (self->DataType() ==
+                     framework::proto::VarType::FP16) {
+            attrs["fp16_values"] =
+                std::vector<float>{value_obj.cast<float>()};
           } else {
             PADDLE_THROW(platform::errors::InvalidArgument(
                 "When assign a value to a paddle.Tensor, "
                 "the data type of the paddle.Tensor must be bool, "
-                "float32, int32 or int64, "
+                "float32, int32, int64 or float16, "
                 "please check the type of tensor."));
           }
           attrs["shape"] = std::vector<int64_t>{1};
...
@@ -724,6 +724,21 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
                                  "shape",
                                  "bool_values"},
                                 {"Out"});
+      } else if (ctx.HasAttr("fp16_values") &&
+                 !paddle::any_cast<std::vector<float>>(
+                      ctx.Attr("fp16_values"))
+                      .empty()) {
+        return KernelSignature("set_value",
+                               {"Input"},
+                               {"starts",
+                                "ends",
+                                "steps",
+                                "axes",
+                                "decrease_axes",
+                                "none_axes",
+                                "shape",
+                                "fp16_values"},
+                               {"Out"});
       }
     }
   }
...
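The argument-mapping hunk follows the file's existing pattern: set_value carries one *_values attribute per dtype, and the first non-empty one selects which attribute is forwarded to the kernel. A toy Python analogue of that dispatch (the attribute names match the op, but the check order here is illustrative, not copied from the C++ file):

# Pick the attribute that will feed the kernel: the first non-empty
# *_values list wins, mirroring the chained HasAttr/!empty() checks.
def pick_values_attr(attrs):
    for name in ("bool_values", "fp32_values", "int32_values",
                 "int64_values", "fp64_values", "fp16_values"):
        if attrs.get(name):
            return name
    raise TypeError("no scalar values attribute is set")

print(pick_values_attr({"fp16_values": [3.7]}))  # -> fp16_values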
@@ -576,6 +576,28 @@ create_test_value_int64(TestSetValueItemSlice3)
 create_test_value_int64(TestSetValueItemSlice4)
 
 
+def create_test_value_fp16(parent):
+
+    class TestValueInt(parent):
+
+        def set_value(self):
+            self.value = 3.7
+
+        def set_dtype(self):
+            self.dtype = "float16"
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Valuefp16")
+    TestValueInt.__name__ = cls_name
+    globals()[cls_name] = TestValueInt
+
+
+create_test_value_fp16(TestSetValueItemInt)
+create_test_value_fp16(TestSetValueItemSlice)
+create_test_value_fp16(TestSetValueItemSlice2)
+create_test_value_fp16(TestSetValueItemSlice3)
+create_test_value_fp16(TestSetValueItemSlice4)
+
+
 def create_test_value_fp32(parent):
 
     class TestValueInt(parent):
@@ -1015,7 +1037,6 @@ class TestError(TestSetValueBase):
         paddle.enable_static()
         with paddle.static.program_guard(self.program):
             self._value_type_error()
-            self._dtype_error()
             self._step_error()
             self._bool_list_error()
             self._bool_tensor_error()
...
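The removed self._dtype_error() call is the old negative test that expected fp16 assignment to fail; with fp16 now supported, it would report a false failure. The additions reuse the file's class-factory pattern: each create_test_value_* helper subclasses an existing item/slice test, overrides the value and dtype, and registers the subclass under a derived name so the test runner discovers it. A standalone sketch of the same pattern (the Base class below is a stand-in, not taken from the test file):

import unittest

class Base(unittest.TestCase):
    value = None
    dtype = None

    def test_configured(self):
        if self.dtype is None:  # the base template itself is not a real case
            self.skipTest("base template")
        # A real TestSetValue case would build a tensor of self.dtype,
        # assign self.value into it, and compare against numpy.
        self.assertEqual(self.dtype, "float16")

def create_test_value_fp16(parent):
    class TestValue(parent):
        value = 3.7
        dtype = "float16"
    cls_name = "{0}_{1}".format(parent.__name__, "Valuefp16")
    TestValue.__name__ = cls_name
    globals()[cls_name] = TestValue  # expose for unittest discovery

create_test_value_fp16(Base)  # registers Base_Valuefp16

if __name__ == "__main__":
    unittest.main()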
@@ -730,10 +730,13 @@ def _setitem_impl_(var, item, value):
     elif dtype == core.VarDesc.VarType.INT64:
         value_name = "int64_values"
         values = [int(v) for v in value.flat]
+    elif dtype == core.VarDesc.VarType.FP16:
+        value_name = "fp16_values"
+        values = [float(v) for v in value.flat]
     else:
         raise TypeError(
             "When assign a numpy.ndarray, integer or float to a paddle.Tensor, "
-            "the data type of the paddle.Tensor must be bool, float32, int32 or int64, but "
+            "the data type of the paddle.Tensor must be bool, float32, int32, int64 or float16, but "
             "received %s." % convert_dtype(dtype))
     attrs[value_name] = values
     attrs["shape"] = shape
...
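_setitem_impl_ is the static-graph path: with the new branch, values destined for an fp16 variable travel through the fp16_values attribute as Python floats. A static-mode usage sketch (assuming a patched build whose set_value kernel is registered for fp16 on the chosen place):

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    # Build a float16 variable; the slice assignment is lowered to a
    # set_value op whose fp16_values attribute carries the scalar.
    x = paddle.ones(shape=[2, 3], dtype="float16")
    x[0, 0] = 3.7

exe = paddle.static.Executor(paddle.CPUPlace())
out, = exe.run(main, fetch_list=[x])
print(out[0, 0])  # 3.7 rounded to float16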