未验证 提交 b4989fb7 编写于 作者: L liym27 提交者: GitHub

Support vector<double> as type of op attribute and op set_value suppport...

Support vector<double> as type of op attribute and op set_value suppport vector<double> as value (#30126)
上级 c6296b2b
...@@ -165,6 +165,35 @@ struct ExtractAttribute<float> { ...@@ -165,6 +165,35 @@ struct ExtractAttribute<float> {
const std::string& attr_name_; const std::string& attr_name_;
}; };
template <>
struct ExtractAttribute<std::vector<double>> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
std::vector<double>* operator()(Attribute& attr) const {
if (attr.type() == typeid(std::vector<int>)) { // NOLINT
std::vector<int> val = BOOST_GET_CONST(std::vector<int>, attr);
std::vector<double> vec(val.begin(), val.end());
attr = vec;
} else if (attr.type() == typeid(std::vector<float>)) { // NOLINT
std::vector<float> val = BOOST_GET_CONST(std::vector<float>, attr);
std::vector<double> vec(val.begin(), val.end());
attr = vec;
}
std::vector<double>* attr_value = nullptr;
try {
attr_value = &boost::get<std::vector<double>>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get attribute (%s) by type std::vector<double>, its type is "
"%s.",
attr_name_, paddle::platform::demangle(attr.type().name())));
}
return attr_value;
}
const std::string& attr_name_;
};
template <typename T> template <typename T>
inline proto::AttrType AttrTypeID() { inline proto::AttrType AttrTypeID() {
Attribute tmp = T(); Attribute tmp = T();
......
...@@ -35,6 +35,7 @@ enum AttrType { ...@@ -35,6 +35,7 @@ enum AttrType {
LONG = 9; LONG = 9;
BLOCKS = 10; BLOCKS = 10;
LONGS = 11; LONGS = 11;
FLOAT64S = 12;
} }
// OpDesc describes an instance of a C++ framework::OperatorBase // OpDesc describes an instance of a C++ framework::OperatorBase
...@@ -56,6 +57,7 @@ message OpDesc { ...@@ -56,6 +57,7 @@ message OpDesc {
optional int64 l = 13; optional int64 l = 13;
repeated int32 blocks_idx = 14; repeated int32 blocks_idx = 14;
repeated int64 longs = 15; repeated int64 longs = 15;
repeated double float64s = 16;
}; };
message Var { message Var {
......
...@@ -714,6 +714,10 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> { ...@@ -714,6 +714,10 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
VectorToRepeated(v, attr_->mutable_longs()); VectorToRepeated(v, attr_->mutable_longs());
} }
void operator()(const std::vector<double> &v) const {
VectorToRepeated(v, attr_->mutable_float64s());
}
void operator()(boost::blank) const { void operator()(boost::blank) const {
PADDLE_THROW(platform::errors::Unavailable( PADDLE_THROW(platform::errors::Unavailable(
"Unsupported calling method of SetAttrDescVisitor object for " "Unsupported calling method of SetAttrDescVisitor object for "
......
...@@ -38,11 +38,10 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>; ...@@ -38,11 +38,10 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using VariableValueMap = std::map<std::string, std::vector<Variable*>>; using VariableValueMap = std::map<std::string, std::vector<Variable*>>;
// The order should be as same as framework.proto // The order should be as same as framework.proto
using Attribute = using Attribute = boost::variant<
boost::variant<boost::blank, int, float, std::string, std::vector<int>, boost::blank, int, float, std::string, std::vector<int>, std::vector<float>,
std::vector<float>, std::vector<std::string>, bool, std::vector<std::string>, bool, std::vector<bool>, BlockDesc*, int64_t,
std::vector<bool>, BlockDesc*, int64_t, std::vector<BlockDesc*>, std::vector<int64_t>, std::vector<double>>;
std::vector<BlockDesc*>, std::vector<int64_t>>;
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;
......
...@@ -79,6 +79,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker { ...@@ -79,6 +79,8 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<int64_t>>("int64_values", "store the int64 values") AddAttr<std::vector<int64_t>>("int64_values", "store the int64 values")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<double>>("fp64_values", "store the float64 values")
.SetDefault({});
AddAttr<std::vector<int64_t>>("shape", "(vector<int64_t>) Shape of values.") AddAttr<std::vector<int64_t>>("shape", "(vector<int64_t>) Shape of values.")
.SetDefault({}); .SetDefault({});
......
...@@ -43,9 +43,13 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) { ...@@ -43,9 +43,13 @@ inline std::string GetValueName(framework::proto::VarType::Type data_type) {
case framework::proto::VarType::FP32: case framework::proto::VarType::FP32:
value_name = "fp32_values"; value_name = "fp32_values";
break; break;
case framework::proto::VarType::FP64:
value_name = "fp64_values";
break;
case framework::proto::VarType::BOOL: case framework::proto::VarType::BOOL:
value_name = "bool_values"; value_name = "bool_values";
break; break;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type(code %d) for SetValue operator, only " "Unsupported data type(code %d) for SetValue operator, only "
......
...@@ -1897,9 +1897,10 @@ class Variable(object): ...@@ -1897,9 +1897,10 @@ class Variable(object):
dtype = self.dtype dtype = self.dtype
attrs['dtype'] = dtype attrs['dtype'] = dtype
from .data_feeder import convert_dtype
# 2.1 value is an integer of float # 2.1 value is an integer of float
if isinstance(value, (int, float)): if isinstance(value, (int, float)):
value = np.array([value]) value = np.array([value]).astype(convert_dtype(dtype))
# 2.2 value is a np.ndarray # 2.2 value is a np.ndarray
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
...@@ -1910,6 +1911,9 @@ class Variable(object): ...@@ -1910,6 +1911,9 @@ class Variable(object):
elif dtype == core.VarDesc.VarType.FP32: elif dtype == core.VarDesc.VarType.FP32:
value_name = "fp32_values" value_name = "fp32_values"
values = [float(v) for v in value.flat] values = [float(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.FP64:
value_name = "fp64_values"
values = [float(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.INT32: elif dtype == core.VarDesc.VarType.INT32:
value_name = "int32_values" value_name = "int32_values"
values = [int(v) for v in value.flat] values = [int(v) for v in value.flat]
...@@ -1917,7 +1921,6 @@ class Variable(object): ...@@ -1917,7 +1921,6 @@ class Variable(object):
value_name = "int64_values" value_name = "int64_values"
values = [int(v) for v in value.flat] values = [int(v) for v in value.flat]
else: else:
from .data_feeder import convert_dtype
raise TypeError( raise TypeError(
"When assign a numpy.ndarray, integer or float to a paddle.Tensor, " "When assign a numpy.ndarray, integer or float to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, float32, int32 or int64, but " "the data type of the paddle.Tensor must be bool, float32, int32 or int64, but "
......
...@@ -102,7 +102,7 @@ class TestSetValueItemSlice4(TestSetValueApi): ...@@ -102,7 +102,7 @@ class TestSetValueItemSlice4(TestSetValueApi):
# 2. Test different type of value: int, float, numpy.ndarray, Tensor # 2. Test different type of value: int, float, numpy.ndarray, Tensor
# 2.1 value is int32, int64, float32, bool # 2.1 value is int32, int64, float32, float64, bool
def create_test_value_int32(parent): def create_test_value_int32(parent):
...@@ -165,6 +165,26 @@ create_test_value_fp32(TestSetValueItemSlice3) ...@@ -165,6 +165,26 @@ create_test_value_fp32(TestSetValueItemSlice3)
create_test_value_fp32(TestSetValueItemSlice4) create_test_value_fp32(TestSetValueItemSlice4)
def create_test_value_fp64(parent):
class TestValueInt(parent):
def set_value(self):
self.value = 2.0**127 # float32:[-2^128, 2^128)
def set_dtype(self):
self.dtype = "float64"
cls_name = "{0}_{1}".format(parent.__name__, "ValueFp64")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_fp64(TestSetValueItemInt)
create_test_value_fp64(TestSetValueItemSlice)
create_test_value_fp64(TestSetValueItemSlice2)
create_test_value_fp64(TestSetValueItemSlice3)
create_test_value_fp64(TestSetValueItemSlice4)
def create_test_value_bool(parent): def create_test_value_bool(parent):
class TestValueInt(parent): class TestValueInt(parent):
def set_value(self): def set_value(self):
...@@ -185,7 +205,7 @@ create_test_value_bool(TestSetValueItemSlice3) ...@@ -185,7 +205,7 @@ create_test_value_bool(TestSetValueItemSlice3)
create_test_value_bool(TestSetValueItemSlice4) create_test_value_bool(TestSetValueItemSlice4)
# 2.2 value is numpy.array (int32, int64, float32, bool) # 2.2 value is numpy.array (int32, int64, float32, float64, bool)
def create_test_value_numpy_int32(parent): def create_test_value_numpy_int32(parent):
class TestValueInt(parent): class TestValueInt(parent):
def set_value(self): def set_value(self):
...@@ -246,6 +266,26 @@ create_test_value_numpy_fp32(TestSetValueItemSlice3) ...@@ -246,6 +266,26 @@ create_test_value_numpy_fp32(TestSetValueItemSlice3)
create_test_value_numpy_fp32(TestSetValueItemSlice4) create_test_value_numpy_fp32(TestSetValueItemSlice4)
def create_test_value_numpy_fp64(parent):
class TestValueInt(parent):
def set_value(self):
self.value = np.array([2**127]).astype("float64")
def set_dtype(self):
self.dtype = "float64"
cls_name = "{0}_{1}".format(parent.__name__, "ValueNumpyFp64")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_numpy_fp64(TestSetValueItemInt)
create_test_value_numpy_fp64(TestSetValueItemSlice)
create_test_value_numpy_fp64(TestSetValueItemSlice2)
create_test_value_numpy_fp64(TestSetValueItemSlice3)
create_test_value_numpy_fp64(TestSetValueItemSlice4)
def create_test_value_numpy_bool(parent): def create_test_value_numpy_bool(parent):
class TestValueInt(parent): class TestValueInt(parent):
def set_value(self): def set_value(self):
...@@ -451,7 +491,7 @@ class TestError(TestSetValueBase): ...@@ -451,7 +491,7 @@ class TestError(TestSetValueBase):
TypeError, TypeError,
"When assign a numpy.ndarray, integer or float to a paddle.Tensor, " "When assign a numpy.ndarray, integer or float to a paddle.Tensor, "
): ):
y = paddle.ones(shape=self.shape, dtype="float64") y = paddle.ones(shape=self.shape, dtype="float16")
y[0] = 1 y[0] = 1
def _step_error(self): def _step_error(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册