未验证 提交 100a0750 编写于 作者: 傅剑寒 提交者: GitHub

[Cherry-pick] Add fp16 dtype support for set_value op (#46906)

Fix set_value failure when source tensor is fp16 Dtype and destiny value is a number
(dev PR link:#46801)
上级 0280c0b9
......@@ -104,7 +104,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
framework::proto::VarType::INT32,
framework::proto::VarType::INT64,
framework::proto::VarType::FP32,
framework::proto::VarType::FP64})
framework::proto::VarType::FP64,
framework::proto::VarType::FP16})
.SetDefault(framework::proto::VarType::FP32);
AddAttr<std::vector<int64_t>>(
"axes", "(list<int64_t>) Axes that `starts` and `ends` apply to.");
......@@ -135,6 +136,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault({});
AddAttr<std::vector<double>>("fp64_values", "Store the float64 values.")
.SetDefault({});
AddAttr<std::vector<float>>("fp16_values", "Store the float16 values.")
.SetDefault({});
AddAttr<std::vector<int64_t>>("shape", "(vector<int64_t>) Shape of values.")
.SetDefault({});
......
......@@ -1151,11 +1151,15 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
} else if (self->tensor.dtype() ==
paddle::experimental::DataType::BOOL) {
attrs["bool_values"] = std::vector<int>{value_obj_tmp.cast<bool>()};
} else if (self->tensor.dtype() ==
paddle::experimental::DataType::FLOAT16) {
attrs["fp16_values"] =
std::vector<float>{value_obj_tmp.cast<float>()};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, int32 or int64, "
"float32, int32, int64 or float16, "
"please check the type of tensor."));
}
attrs["shape"] = std::vector<int64_t>{1};
......
......@@ -964,11 +964,15 @@ void BindImperative(py::module *m_ptr) {
framework::proto::VarType::BOOL) {
attrs["bool_values"] =
std::vector<int>{value_obj.cast<bool>()};
} else if (self->DataType() ==
framework::proto::VarType::FP16) {
attrs["fp16_values"] =
std::vector<float>{value_obj.cast<float>()};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, int32 or int64, "
"float32, int32, int64 or float16, "
"please check the type of tensor."));
}
attrs["shape"] = std::vector<int64_t>{1};
......
......@@ -724,6 +724,21 @@ KernelSignature SetValueOpArgumentMapping(const ArgumentMappingContext& ctx) {
"shape",
"bool_values"},
{"Out"});
} else if (ctx.HasAttr("fp16_values") &&
!paddle::any_cast<std::vector<float>>(
ctx.Attr("fp16_values"))
.empty()) {
return KernelSignature("set_value",
{"Input"},
{"starts",
"ends",
"steps",
"axes",
"decrease_axes",
"none_axes",
"shape",
"fp16_values"},
{"Out"});
}
}
}
......
......@@ -576,6 +576,28 @@ create_test_value_int64(TestSetValueItemSlice3)
create_test_value_int64(TestSetValueItemSlice4)
def create_test_value_fp16(parent):
class TestValueInt(parent):
def set_value(self):
self.value = 3.7
def set_dtype(self):
self.dtype = "float16"
cls_name = "{0}_{1}".format(parent.__name__, "Valuefp16")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_fp16(TestSetValueItemInt)
create_test_value_fp16(TestSetValueItemSlice)
create_test_value_fp16(TestSetValueItemSlice2)
create_test_value_fp16(TestSetValueItemSlice3)
create_test_value_fp16(TestSetValueItemSlice4)
def create_test_value_fp32(parent):
class TestValueInt(parent):
......@@ -1015,7 +1037,6 @@ class TestError(TestSetValueBase):
paddle.enable_static()
with paddle.static.program_guard(self.program):
self._value_type_error()
self._dtype_error()
self._step_error()
self._bool_list_error()
self._bool_tensor_error()
......
......@@ -730,10 +730,13 @@ def _setitem_impl_(var, item, value):
elif dtype == core.VarDesc.VarType.INT64:
value_name = "int64_values"
values = [int(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.FP16:
value_name = "fp16_values"
values = [float(v) for v in value.flat]
else:
raise TypeError(
"When assign a numpy.ndarray, integer or float to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, float32, int32 or int64, but "
"the data type of the paddle.Tensor must be bool, float32, int32, int64 or float16, but "
"received %s." % convert_dtype(dtype))
attrs[value_name] = values
attrs["shape"] = shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册