未验证 提交 b414645a 编写于 作者: 1 123malin 提交者: GitHub

fix #17430: int64类型的attr训练非预期 (#18264)

* fix int64_t

* update fill constant op unittest

* add empty line
上级 db212bb9
......@@ -133,6 +133,32 @@ struct ExtractAttribute<std::vector<int64_t>> {
const std::string& attr_name_;
};
template <>
struct ExtractAttribute<float> {
explicit ExtractAttribute(const std::string& attr_name)
: attr_name_(attr_name) {}
float* operator()(Attribute& attr) const {
if (attr.type() == typeid(int)) { // NOLINT
int val = boost::get<int>(attr);
attr = static_cast<float>(val);
} else if (attr.type() == typeid(int64_t)) { // NOLINT
int64_t val = boost::get<int64_t>(attr);
attr = static_cast<float>(val);
}
float* attr_value = nullptr;
try {
attr_value = &boost::get<float>(attr);
} catch (boost::bad_get& bad_get) {
PADDLE_THROW("Cannot get attribute %s by type float, its type is %s",
attr_name_, paddle::platform::demangle(attr.type().name()));
}
return attr_value;
}
const std::string& attr_name_;
};
template <typename T>
inline proto::AttrType AttrTypeID() {
Attribute tmp = T();
......
......@@ -77,6 +77,15 @@ struct paddle_variant_caster<V<Ts...>> {
}
}
if (std::is_same<T, float>::value) {
auto caster_int64 = make_caster<int64_t>();
if (caster_int64.load(src, convert)) {
VLOG(4) << "this value are float and int64 satisfy simula.";
value = cast_op<int64_t>(caster_int64);
return true;
}
}
value = cast_op<T>(caster);
return true;
}
......
......@@ -50,6 +50,34 @@ class TestFillConstantOp2(OpTest):
self.check_output()
class TestFillConstantOp3(OpTest):
def setUp(self):
'''Test fill_constant op with specified int64 value
'''
self.op_type = "fill_constant"
self.inputs = {}
self.attrs = {'shape': [123, 92], 'value': 10000000000}
self.outputs = {'Out': np.full((123, 92), 10000000000)}
def test_check_output(self):
self.check_output()
class TestFillConstantOp4(OpTest):
def setUp(self):
'''Test fill_constant op with specified int value
'''
self.op_type = "fill_constant"
self.inputs = {}
self.attrs = {'shape': [123, 92], 'value': 3}
self.outputs = {'Out': np.full((123, 92), 3)}
def test_check_output(self):
self.check_output()
class TestFillConstantOpWithSelectedRows(OpTest):
def check_with_place(self, place):
scope = core.Scope()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册