未验证 提交 ff25c5b3 编写于 作者: L liym27 提交者: GitHub

Fix bug: GetAttrValue should deal with attr with attrType vector<double> (#30536)

上级 572c466d
......@@ -69,6 +69,15 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
}
return val;
}
case proto::AttrType::FLOAT64S: {
std::vector<double> val(attr_desc.float64s_size());
for (int i = 0; i < attr_desc.float64s_size(); ++i) {
val[i] = attr_desc.float64s(i);
}
return val;
}
default:
PADDLE_THROW(platform::errors::Unavailable("Unsupport attribute type %d.",
attr_desc.type()));
......
......@@ -95,6 +95,18 @@ def test_set_value(x):
return x
class LayerWithSetValue(paddle.nn.Layer):
def __init__(self, input_dim, hidden):
super(LayerWithSetValue, self).__init__()
self.linear = paddle.nn.Linear(input_dim, hidden)
@paddle.jit.to_static
def forward(self, x):
x = self.linear(x)
x[0] = 1
return x
class TestSliceWithoutControlFlow(unittest.TestCase):
def setUp(self):
self.init_input()
......@@ -152,5 +164,17 @@ class TestSetValue(TestSliceWithoutControlFlow):
self.dygraph_func = test_set_value
class TestSetValueWithLayerAndSave(unittest.TestCase):
def test_set_value_with_save(self):
prog_trans.enable(True)
model = LayerWithSetValue(input_dim=10, hidden=1)
x = paddle.full(shape=[5, 10], fill_value=5.0, dtype="float32")
paddle.jit.save(
layer=model,
path="./layer_use_set_value",
input_spec=[x],
output_spec=None)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册