未验证 提交 c3ba8056 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle-TRT] fix_fill_constant (#44481)

* fix_fill_constant

* fix_fill_constant

* fix_ernie
上级 ba89a3d3
......@@ -32,6 +32,10 @@ class FillConstantOpConverter : public OpConverter {
PADDLE_GET_CONST(std::string, op_desc.GetAttr("str_value"));
std::vector<int64_t> shape =
PADDLE_GET_CONST(std::vector<int64_t>, op_desc.GetAttr("shape"));
if (str_value == "") {
float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value"));
str_value = std::to_string(value);
}
std::unique_ptr<framework::Tensor> out_tensor(new framework::Tensor());
out_tensor->Resize(phi::make_ddim(shape));
nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
......
......@@ -42,8 +42,14 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
for dtype in [5, 2, 3]:
for str_value in ["2", "23", "-1"]:
self.num_input = num_input
value = float(str_value)
if np.random.choice([False, True]):
str_value = str_value
else:
str_value = ""
dics = [{
"str_value": str_value,
"value": value,
"shape": shape,
"dtype": dtype
}, {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册