未验证 提交 d8edf487 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle-TRT]fix bug in fill_constant_batch_size_like op (#46334)

* fix beta bug in fill_constant_batch_size_like
上级 2963e6a0
......@@ -44,6 +44,10 @@ class FillConstantBatchSizeLikeOpConverter : public OpConverter {
PADDLE_GET_CONST(std::string, op_desc.GetAttr("str_value"));
std::vector<int32_t> shape =
PADDLE_GET_CONST(std::vector<int32_t>, op_desc.GetAttr("shape"));
if (str_value == "") {
float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value"));
str_value = std::to_string(value);
}
float value = std::stof(str_value);
auto* input_shape_tensor = Shape(input);
......@@ -65,7 +69,7 @@ class FillConstantBatchSizeLikeOpConverter : public OpConverter {
auto layer = TRT_ENGINE_ADD_LAYER(
engine_, Fill, nvinfer1::Dims{}, nvinfer1::FillOperation::kLINSPACE);
std::vector<float> value_vec(1, value);
std::vector<float> beta_vec(3, 0.);
std::vector<float> beta_vec(shape.size(), 0.);
layer->setAlpha(value);
layer->setBeta(0.f);
layer->setInput(0, *out_shape_tensor);
......
......@@ -87,7 +87,9 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
"input_dim_idx":
0,
"str_value":
"0.0",
"",
"value":
0.0,
"shape": [K * num_layers, -1, hidden_size],
"output_dim_idx":
1,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册