未验证 提交 d8edf487 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Paddle-TRT]fix bug in fill_constant_batch_size_like op (#46334)

* fix beta bug in fill_constant_batch_size_like
上级 2963e6a0
...@@ -44,6 +44,10 @@ class FillConstantBatchSizeLikeOpConverter : public OpConverter { ...@@ -44,6 +44,10 @@ class FillConstantBatchSizeLikeOpConverter : public OpConverter {
PADDLE_GET_CONST(std::string, op_desc.GetAttr("str_value")); PADDLE_GET_CONST(std::string, op_desc.GetAttr("str_value"));
std::vector<int32_t> shape = std::vector<int32_t> shape =
PADDLE_GET_CONST(std::vector<int32_t>, op_desc.GetAttr("shape")); PADDLE_GET_CONST(std::vector<int32_t>, op_desc.GetAttr("shape"));
if (str_value == "") {
float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value"));
str_value = std::to_string(value);
}
float value = std::stof(str_value); float value = std::stof(str_value);
auto* input_shape_tensor = Shape(input); auto* input_shape_tensor = Shape(input);
...@@ -65,7 +69,7 @@ class FillConstantBatchSizeLikeOpConverter : public OpConverter { ...@@ -65,7 +69,7 @@ class FillConstantBatchSizeLikeOpConverter : public OpConverter {
auto layer = TRT_ENGINE_ADD_LAYER( auto layer = TRT_ENGINE_ADD_LAYER(
engine_, Fill, nvinfer1::Dims{}, nvinfer1::FillOperation::kLINSPACE); engine_, Fill, nvinfer1::Dims{}, nvinfer1::FillOperation::kLINSPACE);
std::vector<float> value_vec(1, value); std::vector<float> value_vec(1, value);
std::vector<float> beta_vec(3, 0.); std::vector<float> beta_vec(shape.size(), 0.);
layer->setAlpha(value); layer->setAlpha(value);
layer->setBeta(0.f); layer->setBeta(0.f);
layer->setInput(0, *out_shape_tensor); layer->setInput(0, *out_shape_tensor);
......
...@@ -87,7 +87,9 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest): ...@@ -87,7 +87,9 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
"input_dim_idx": "input_dim_idx":
0, 0,
"str_value": "str_value":
"0.0", "",
"value":
0.0,
"shape": [K * num_layers, -1, hidden_size], "shape": [K * num_layers, -1, hidden_size],
"output_dim_idx": "output_dim_idx":
1, 1,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册