diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index cdd703e679d95cbea55dfda96810ad080a309789..8166c43e65db1fa7fb6e78884e67e695a88dfdd1 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -89,6 +89,7 @@ pass_library(delete_quant_dequant_filter_op_pass inference)
 pass_library(delete_weight_dequant_linear_op_pass inference)
 pass_library(delete_quant_dequant_linear_op_pass inference)
 pass_library(delete_dropout_op_pass inference)
+pass_library(delete_fill_constant_op_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
 pass_library(skip_layernorm_fuse_pass base)
diff --git a/paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc b/paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e86bb2926b640b33eed8378166ab417048aa20db
--- /dev/null
+++ b/paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc
@@ -0,0 +1,103 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/delete_fill_constant_op_pass.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+template <typename T>
+void FillConstData(LoDTensor* out_t, T value) {
+  auto output_data = out_t->mutable_data<T>(platform::CPUPlace());
+  for (int i = 0; i < out_t->numel(); i++) {
+    output_data[i] = value;
+  }
+}
+
+void DeleteFillConstantOpPass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init("delete_fill_constant_op_pass", graph);
+  GraphPatternDetector detector;
+  auto fill_constant_op = detector.mutable_pattern()
+                              ->NewNode("fill_constant")
+                              ->assert_is_op("fill_constant")
+                              ->assert_is_not_op_input("ValueTensor")
+                              ->assert_is_not_op_input("str_value")
+                              ->assert_is_not_op_input("ShapeTensor")
+                              ->assert_is_not_op_input("ShapeTensorList");
+  auto fill_constant_out =
+      detector.mutable_pattern()
+          ->NewNode("fill_constant_out")
+          ->assert_is_op_output("fill_constant")
+          ->assert_more([](Node* x) { return x->outputs.size() == 1UL; });
+  auto next_op = detector.mutable_pattern()
+                     ->NewNode("next_op")
+                     ->assert_is_not_op_type("conditional_block")
+                     ->assert_is_not_op_type("while");
+  // Create the topological connections for the above pattern nodes.
+  fill_constant_op->LinksTo({fill_constant_out});
+  next_op->LinksFrom({fill_constant_out});
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    Node* fill_constant_op_node = subgraph.at(fill_constant_op);
+    Node* fill_constant_out_node = subgraph.at(fill_constant_out);
+    // Get fill_constant's attr
+    auto fill_constant = fill_constant_op_node->Op();
+    auto value = BOOST_GET_CONST(float, fill_constant->GetAttr("value"));
+    auto shape =
+        BOOST_GET_CONST(std::vector<int64_t>, fill_constant->GetAttr("shape"));
+    auto* scope = param_scope();
+    auto fill_constant_out_desc = fill_constant_out_node->Var();
+    fill_constant_out_desc->SetShape(shape);
+    fill_constant_out_desc->SetPersistable(true);
+    auto* fill_constant_out_tensor =
+        scope->Var(fill_constant_out_desc->Name())->GetMutable<LoDTensor>();
+    auto dtype =
+        framework::TransToPhiDataType(fill_constant_out_desc->GetDataType());
+    fill_constant_out_tensor->Resize(phi::make_ddim(shape));
+    switch (dtype) {
+      case paddle::experimental::DataType::BOOL:
+        FillConstData<bool>(fill_constant_out_tensor, static_cast<bool>(value));
+        break;
+      case paddle::experimental::DataType::INT32:
+        FillConstData<int32_t>(fill_constant_out_tensor,
+                               static_cast<int32_t>(value));
+        break;
+      case paddle::experimental::DataType::INT64:
+        FillConstData<int64_t>(fill_constant_out_tensor,
+                               static_cast<int64_t>(value));
+        break;
+      case paddle::experimental::DataType::FLOAT32:
+        FillConstData<float>(fill_constant_out_tensor,
+                             static_cast<float>(value));
+        break;
+      default:
+        LOG(WARNING) << "Unsupported dtype for fill_constant op: " << dtype;
+        return;
+    }
+    // Remove links in graph
+    GraphSafeRemoveNodes(graph, {fill_constant_op_node});
+  };
+
+  detector(graph, handler);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(delete_fill_constant_op_pass,
+              paddle::framework::ir::DeleteFillConstantOpPass);
diff --git a/paddle/fluid/framework/ir/delete_fill_constant_op_pass.h b/paddle/fluid/framework/ir/delete_fill_constant_op_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..33d10f4502f2ab7e3c9d4d363361c7ee920070b2
--- /dev/null
+++ b/paddle/fluid/framework/ir/delete_fill_constant_op_pass.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <vector>
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class Graph;
+
+class DeleteFillConstantOpPass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  virtual ~DeleteFillConstantOpPass() = default;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 8c8d9fdddec851c9854ebb0c784d2b56d6dd8526..f7c1a68c826f0935fb6c551a744776679fc0bb69 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -408,6 +408,13 @@ PDNode *PDNode::assert_is_op(const std::string &op_type) {
   return this;
 }
 
+PDNode *PDNode::assert_is_not_op_type(const std::string &op_type) {
+  asserts_.emplace_back([op_type](Node *x) {
+    return x && x->IsOp() && x->Op()->Type() != op_type;
+  });
+  return this;
+}
+
 PDNode *PDNode::assert_is_var() {
   asserts_.emplace_back([](Node *x) { return x && x->IsVar(); });
   return this;
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 9e5a82fc4458603da8b2b51587cad39047bc75e9..cab8f82660d901d0a8318ce4c5079adf6231ab54 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -110,6 +110,7 @@ struct PDNode {
   // Assertions, helper functions to simplify the pattern definition.
   PDNode* assert_is_op();
   PDNode* assert_is_op(const std::string& op_type);
+  PDNode* assert_is_not_op_type(const std::string& op_type);
   PDNode* assert_is_var();
   PDNode* assert_var_dtype(proto::VarType::Type dtype);
   PDNode* assert_is_not_ctrl_var();
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 735e1b7be4c1fadacb9fc6fe90fb578863a5c32a..adc3fc46f72ac8898d1ba0565eedc3ded4f65989 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -633,6 +633,11 @@ void AnalysisConfig::Update() {
           (pass == "conv_bn_fuse_pass")) {
         continue;
       }
+      // delete_fill_constant_op_pass is not used under trt dynamic shape
+      if ((!min_input_shape_.empty() || trt_tuned_dynamic_shape_) &&
+          pass == "delete_fill_constant_op_pass") {
+        continue;
+      }
       pass_builder()->AppendPass(pass);
     }
   }
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 6c81997d1356262332464d717044e242b2048811..13f81059df5e3320cb8166708e2f3c795548c504 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1731,6 +1731,10 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<AnalysisConfig>(
 
 #if PADDLE_WITH_TENSORRT
 USE_TRT_CONVERTER(elementwise_add_weight);
+USE_TRT_CONVERTER(elementwise_sub_weight);
+USE_TRT_CONVERTER(elementwise_mul_weight);
+USE_TRT_CONVERTER(elementwise_div_weight);
+USE_TRT_CONVERTER(elementwise_pow_weight);
 USE_TRT_CONVERTER(elementwise_add_tensor);
 USE_TRT_CONVERTER(elementwise_sub_tensor);
 USE_TRT_CONVERTER(elementwise_div_tensor);
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 77203b069e602a828073aa20f4bc0b1a70e64b21..fdb979283f76ecdf38d0082cc5b72470d3032ddf 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -85,6 +85,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
   "adaptive_pool2d_convert_global_pass",
       "shuffle_channel_detect_pass",           //
       "quant_conv2d_dequant_fuse_pass",        //
+      "delete_fill_constant_op_pass",          //
       "delete_quant_dequant_op_pass",          //
       "delete_quant_dequant_filter_op_pass",   //
       "delete_weight_dequant_linear_op_pass",  //
diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
index 8fd0e1bbd068db709130624fd5c68f008608644f..35d3ead0097203b5b45a96b789fbb45579126d6e 100644
--- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
@@ -53,20 +53,14 @@ class ElementwiseWeightOpConverter : public OpConverter {
     auto output_name = op_desc.Output("Out")[0];
     weight_data = engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t);
     nvinfer1::Dims dims_x = X->getDimensions();
+    std::vector<int> dims_y = phi::vectorize<int>(Y_t->dims());
 
     auto regist_eltwise_weight = [&](nvinfer1::ScaleMode scale_mode) {
-      TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT,
-                                           static_cast<void*>(weight_data),
-                                           static_cast<size_t>(Y_t->numel())};
-      TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, nullptr,
-                                           0};
-      TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
-                                           0};
-
       nvinfer1::IShuffleLayer* expand_layer = nullptr;
       nvinfer1::IShuffleLayer* squeeze_layer = nullptr;
       int dynamic_shape_offset = engine_->with_dynamic_shape() ? 1 : 0;
       auto input_dim = X->getDimensions();
+      // reshape
       if (input_dim.nbDims < 3 + dynamic_shape_offset) {
         nvinfer1::Dims expand_shape;
         expand_shape.nbDims = 3 + dynamic_shape_offset;
@@ -85,17 +79,45 @@ class ElementwiseWeightOpConverter : public OpConverter {
         expand_layer->setName(
             ("Elewise: Shuffle: (Output: " + output_name + ")").c_str());
       }
+      // eltwise_ops
+      TensorRTEngine::Weight shift_weights{nvinfer1::DataType::kFLOAT, nullptr,
+                                           0};
+      TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, nullptr,
+                                           0};
+      TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
+                                           0};
       if (op_type_ == "add") {
-        nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER(
-            engine_, ScaleNd, *X, scale_mode, shift_weights.get(),
-            scale_weights.get(), power_weights.get(), dynamic_shape_offset);
-        layer = scale_layer;
+        shift_weights = TensorRTEngine::Weight(
+            nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data),
+            static_cast<size_t>(Y_t->numel()));
+      } else if (op_type_ == "sub") {
+        for (int i = 0; i < Y_t->numel(); i++) {
+          weight_data[i] = -weight_data[i];
+        }
+        shift_weights = TensorRTEngine::Weight(
+            nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data),
+            static_cast<size_t>(Y_t->numel()));
       } else if (op_type_ == "mul") {
-        nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER(
-            engine_, Scale, *X, scale_mode, scale_weights.get(),
-            shift_weights.get(), power_weights.get());
-        layer = scale_layer;
+        scale_weights = TensorRTEngine::Weight(
+            nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data),
+            static_cast<size_t>(Y_t->numel()));
+      } else if (op_type_ == "div") {
+        for (int i = 0; i < Y_t->numel(); i++) {
+          weight_data[i] = 1.f / weight_data[i];
+        }
+        scale_weights = TensorRTEngine::Weight(
+            nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data),
+            static_cast<size_t>(Y_t->numel()));
+      } else if (op_type_ == "pow") {
+        power_weights = TensorRTEngine::Weight(
+            nvinfer1::DataType::kFLOAT, static_cast<void*>(weight_data),
+            static_cast<size_t>(Y_t->numel()));
       }
+      nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER(
+          engine_, ScaleNd, *X, scale_mode, shift_weights.get(),
+          scale_weights.get(), power_weights.get(), dynamic_shape_offset);
+      layer = scale_layer;
+      // reshape
       if (input_dim.nbDims < 3 + dynamic_shape_offset) {
         nvinfer1::Dims squeeze_shape;
         squeeze_shape.nbDims = input_dim.nbDims;
@@ -113,71 +135,43 @@ class ElementwiseWeightOpConverter : public OpConverter {
       }
     };
 
+    // dynamic shape
     if (engine_->with_dynamic_shape()) {
-      if (Y_t->dims().size() == 1) {
-        auto scale_mode = nvinfer1::ScaleMode::kCHANNEL;
-        PADDLE_ENFORCE_EQ(Y_t->dims()[0], dims_x.d[1],
-                          platform::errors::InvalidArgument(
-                              "The Bias's size(%d) should be equal to the "
-                              "first dim(%d) of the Input.",
-                              Y_t->dims()[0], dims_x.d[1]));
-        regist_eltwise_weight(scale_mode);
+      if (dims_y.size() == 1 && dims_y[0] == dims_x.d[1]) {
+        regist_eltwise_weight(nvinfer1::ScaleMode::kCHANNEL);
+      } else if (dims_y.size() == 1 && dims_y[0] == 1) {
+        regist_eltwise_weight(nvinfer1::ScaleMode::kUNIFORM);
+      } else if (dims_y.size() == static_cast<size_t>(dims_x.nbDims)) {
+        regist_eltwise_weight(nvinfer1::ScaleMode::kELEMENTWISE);
       } else {
         PADDLE_THROW(platform::errors::InvalidArgument(
-            "The size of input bias's dims is %d, but TensorRT dynamic shape "
-            "only support size = 1 for Elementwise op!",
-            Y_t->dims().size()));
+            "The size of input_y's dims is %d, but TensorRT dynamic shape "
+            "only support size = 1 or size = input_x.size() for Elementwise "
+            "op!",
+            dims_y.size()));
       }
       return;
     }
 
+    // static shape with dynamic batch
     std::vector<int> no_batch_dims;
     int start_index = 0;
-
-    for (; start_index < dims_x.nbDims; start_index++)
+    for (; start_index < dims_x.nbDims; start_index++) {
       no_batch_dims.push_back(dims_x.d[start_index]);
-
-    auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
-
-    std::vector<int> dims_y = phi::vectorize<int>(Y_t->dims());
-    if (dims_y.size() == no_batch_dims.size() + 1) {
-      if (dims_y[0] == 1) dims_y.erase(dims_y.begin());
     }
-
     if (dims_y.size() == 1 && dims_y[0] == no_batch_dims[0]) {
-      scale_mode = nvinfer1::ScaleMode::kCHANNEL;
-    } else if (dims_y.size() == no_batch_dims.size() &&
-               dims_y[0] == no_batch_dims[0]) {
-      scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
-      for (size_t i = 1; i < no_batch_dims.size(); i++) {
-        if (dims_y[i] != no_batch_dims[i]) {
-          scale_mode = nvinfer1::ScaleMode::kCHANNEL;
-          break;
-        }
-      }
-      if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) {
-        for (size_t i = 1; i < no_batch_dims.size(); i++) {
-          if (dims_y[i] != 1)
-            PADDLE_THROW(platform::errors::InvalidArgument(
-                "The bias's %d dim is %d, but TensorRT dynamic shape only "
-                "support it equals to 1 for Elementwise op!",
-                i, dims_y[i]));
-        }
-      }
+      regist_eltwise_weight(nvinfer1::ScaleMode::kCHANNEL);
+    } else if (dims_y.size() == 1 && dims_y[0] == 1) {
+      regist_eltwise_weight(nvinfer1::ScaleMode::kUNIFORM);
+    } else if (dims_y.size() == no_batch_dims.size() + 1) {
+      regist_eltwise_weight(nvinfer1::ScaleMode::kELEMENTWISE);
     } else {
-      if (dims_y.size() >= 1) {
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "The size of bias's dims is %d and bias's size is %d. TensorRT "
-            "doesn't support this shape for Elementwise op!",
-            dims_y.size(), dims_y[0]));
-      } else {
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "The size of bias's dims is %d. TensorRT doesn't support "
-            "this shape for Elementwise op!",
-            dims_y.size()));
-      }
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "The size of input_y's dims is %d, but TensorRT dynamic shape "
+          "only support size = 1 or size = input_x.size() for Elementwise "
+          "op!",
+          dims_y.size()));
     }
-    regist_eltwise_weight(scale_mode);
   }
 
  protected:
@@ -215,7 +209,6 @@ class ElementwiseTensorOpConverter : public OpConverter {
     auto common_func = [&](nvinfer1::ILayer* layer) {
       RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode);
     };
-
     if (dims_x.nbDims == dims_y.nbDims) {
       // The two input tensor should have the same dims
       VLOG(3) << "Convert a fluid elementwise op to TensorRT IElementWiseLayer";
@@ -244,7 +237,6 @@ class ElementwiseTensorOpConverter : public OpConverter {
         auto* plugin_layer = engine_->AddPlugin(
             inputs.data(), inputs.size(),
             reinterpret_cast<plugin::PluginTensorRT*>(plugin));
-
         layer = plugin_layer;
       }
     }
@@ -278,6 +270,21 @@ class ElementwiseWeightMulOpConverter : public ElementwiseWeightOpConverter {
   ElementwiseWeightMulOpConverter() { op_type_ = "mul"; }
 };
 
+class ElementwiseWeightSubOpConverter : public ElementwiseWeightOpConverter {
+ public:
+  ElementwiseWeightSubOpConverter() { op_type_ = "sub"; }
+};
+
+class ElementwiseWeightDivOpConverter : public ElementwiseWeightOpConverter {
+ public:
+  ElementwiseWeightDivOpConverter() { op_type_ = "div"; }
+};
+
+class ElementwiseWeightPowOpConverter : public ElementwiseWeightOpConverter {
+ public:
+  ElementwiseWeightPowOpConverter() { op_type_ = "pow"; }
+};
+
 class ElementwiseTensorAddOpConverter : public ElementwiseTensorOpConverter {
  public:
   ElementwiseTensorAddOpConverter() { op_type_ = "add"; }
@@ -321,6 +328,12 @@ REGISTER_TRT_OP_CONVERTER(elementwise_add_weight,
                           ElementwiseWeightAddOpConverter);
 REGISTER_TRT_OP_CONVERTER(elementwise_mul_weight,
                           ElementwiseWeightMulOpConverter);
+REGISTER_TRT_OP_CONVERTER(elementwise_sub_weight,
+                          ElementwiseWeightSubOpConverter);
+REGISTER_TRT_OP_CONVERTER(elementwise_div_weight,
+                          ElementwiseWeightDivOpConverter);
+REGISTER_TRT_OP_CONVERTER(elementwise_pow_weight,
+                          ElementwiseWeightPowOpConverter);
 
 REGISTER_TRT_OP_CONVERTER(elementwise_add_tensor,
                           ElementwiseTensorAddOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h
index f7eb7f859afaa3700ff3703992291e02188f1a2a..0a99b12edc25c0b27fbccdc2972f3f653bd2111f 100644
--- a/paddle/fluid/inference/tensorrt/convert/op_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -67,10 +67,8 @@ class OpConverter {
     if (op_desc.Type().find("elementwise") != std::string::npos) {
       static std::unordered_set<std::string> add_tensor_op_set{
           "add", "mul", "sub", "div", "max", "min", "pow"};
-      // TODO(xingzhaolong): all mul, sub, div
-      // static std::unordered_set<std::string> add_weight_op_set {"add", "mul",
-      // "sub", "div"};
-      static std::unordered_set<std::string> add_weight_op_set{"add", "mul"};
+      static std::unordered_set<std::string> add_weight_op_set{
+          "add", "mul", "sub", "div", "pow"};
       PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL,
                         platform::errors::InvalidArgument(
                             "The input op's Input(\"Y\")."
diff --git a/paddle/fluid/inference/tensorrt/convert/strided_slice_op.cc b/paddle/fluid/inference/tensorrt/convert/strided_slice_op.cc
index 26046d38bcbd9f47dbedc9fdef29280cb69d4055..9680e90b2e29d624e457ba829efdb3c9884f34e3 100644
--- a/paddle/fluid/inference/tensorrt/convert/strided_slice_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/strided_slice_op.cc
@@ -39,7 +39,7 @@ class StridedSliceOpConverter : public OpConverter {
     framework::OpDesc op_desc(op, nullptr);
     auto* input = engine_->GetITensor(op_desc.Input("Input")[0]);
     nvinfer1::Dims input_dims = input->getDimensions();
-
+    auto output_name = op_desc.Output("Out")[0];
     std::vector<int> axes =
         BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
     std::vector<int> starts =
@@ -48,79 +48,116 @@ class StridedSliceOpConverter : public OpConverter {
         BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("ends"));
     std::vector<int> strides =
         BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides"));
-
-    nvinfer1::Dims start;
-    start.nbDims = input_dims.nbDims;
     int axes_size = axes.size();
-    for (int i = 0; i < start.nbDims; i++) {
-      start.d[i] = 0;
-    }
-    for (int i = 0; i < axes_size; i++) {
-      start.d[axes[i]] = starts[i];
-    }
-
+    nvinfer1::Dims start;
     nvinfer1::Dims stride;
-    stride.nbDims = input_dims.nbDims;
-    for (int i = 0; i < stride.nbDims; i++) {
-      stride.d[i] = 1;
-    }
-    for (int i = 0; i < axes_size; i++) {
-      stride.d[axes[i]] = strides[i];
-    }
-
     nvinfer1::Dims size;
+    start.nbDims = input_dims.nbDims;
+    stride.nbDims = input_dims.nbDims;
     size.nbDims = input_dims.nbDims;
-    for (int i = 0; i < size.nbDims; i++) {
-      size.d[i] = 1;
+    for (int i = 0; i < input_dims.nbDims; i++) {
+      start.d[i] = 0;
+      stride.d[i] = 1;
+      size.d[i] = input_dims.d[i];
     }
 
-    auto output_name = op_desc.Output("Out")[0];
-
-    auto create_weights = [&](const std::vector<int>& data,
-                              const std::string& type) -> int* {
-      std::unique_ptr<framework::Tensor> tmp_tensor(new framework::Tensor());
-      int data_size = data.size();
-      tmp_tensor->Resize({data_size});
-      auto* tmp_data = tmp_tensor->mutable_data<int>(platform::CPUPlace());
-      for (int i = 0; i < data_size; i++) {
-        tmp_data[i] = data[i];
+    if (!engine_->with_dynamic_shape()) {
+      for (int i = 0; i < axes_size; i++) {
+        start.d[axes[i] - 1] = starts[i];
+      }
+      for (int i = 0; i < axes_size; i++) {
+        stride.d[axes[i] - 1] = strides[i];
+      }
+      for (int i = 0; i < axes_size; ++i) {
+        int dim = size.d[axes[i] - 1];
+        if (dim > 0) {
+          int start = starts[i] < 0 ? (starts[i] + dim) : starts[i];
+          int end = ends[i] < 0 ? (ends[i] + dim) : ends[i];
+          int stride = std::abs(strides[i]);
+          start = std::max(start, 0);
+          end = std::max(end, 0);
+          end = std::min(end, dim);
+          size.d[axes[i] - 1] = (std::abs(end - start) + stride - 1) / stride;
+        }
+      }
+      auto* layer =
+          TRT_ENGINE_ADD_LAYER(engine_, Slice, *input, start, size, stride);
+      RreplenishLayerAndOutput(layer, "strided_slice", {output_name},
+                               test_mode);
+    } else {
+      for (int i = 0; i < axes_size; i++) {
+        start.d[axes[i]] = starts[i];
+      }
+      for (int i = 0; i < axes_size; i++) {
+        stride.d[axes[i]] = strides[i];
+      }
+      for (int i = 0; i < axes_size; ++i) {
+        int dim = size.d[axes[i]];
+        if (dim > 0) {
+          int start = starts[i] < 0 ? (starts[i] + dim) : starts[i];
+          int end = ends[i] < 0 ? (ends[i] + dim) : ends[i];
+          int stride = std::abs(strides[i]);
+          start = std::max(start, 0);
+          end = std::max(end, 0);
+          end = std::min(end, dim);
+          size.d[axes[i]] = (std::abs(end - start) + stride - 1) / stride;
+        }
       }
 
-      engine_->SetWeights(output_name + "_add_slice_op_" + type,
-                          std::move(tmp_tensor));
-      return tmp_data;
-    };
+      auto create_weights = [&](const std::vector<int>& data,
+                                const std::string& type) -> int* {
+        std::unique_ptr<framework::Tensor> tmp_tensor(new framework::Tensor());
+        int data_size = data.size();
+        tmp_tensor->Resize({data_size});
+        auto* tmp_data = tmp_tensor->mutable_data<int>(platform::CPUPlace());
+        for (int i = 0; i < data_size; i++) {
+          tmp_data[i] = data[i];
+        }
+
+        engine_->SetWeights(output_name + "_add_slice_op_" + type,
+                            std::move(tmp_tensor));
+        return tmp_data;
+      };
+
+      std::vector<int> const_weight(input_dims.nbDims, 0);
+      for (int i = 0; i < axes_size; i++) {
+        int dim = input_dims.d[axes[i]];
+        int start = starts[i] < 0 ? (starts[i] + dim) : starts[i];
+        int end = ends[i] < 0 ? (ends[i] + dim) : ends[i];
+        int stride = std::abs(strides[i]);
+        start = std::max(start, 0);
+        end = std::max(end, 0);
+        end = std::min(end, dim);
+        const_weight[axes[i]] =
+            dim - ((std::abs(end - start) + stride - 1) / stride);
+      }
 
-    std::vector<int> const_weight(input_dims.nbDims, 1);
-    for (int i = 0; i < axes_size; i++) {
-      const_weight[axes[i]] = strides[i];
+      int* weight_data = create_weights(const_weight, "size");
+
+      TensorRTEngine::Weight weight{nvinfer1::DataType::kINT32,
+                                    static_cast<void*>(weight_data),
+                                    static_cast<size_t>(input_dims.nbDims)};
+
+      int input_dim_size = input_dims.nbDims;
+      nvinfer1::Dims input_shape;
+      input_shape.nbDims = 1;
+      input_shape.d[0] = input_dim_size;
+
+      auto const_layer =
+          TRT_ENGINE_ADD_LAYER(engine_, Constant, input_shape, weight.get());
+
+      auto shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
+      // slice layer
+      auto* layer =
+          TRT_ENGINE_ADD_LAYER(engine_, Slice, *input, start, size, stride);
+      // elementwise layer for get size tensor
+      auto size_layer = TRT_ENGINE_ADD_LAYER(
+          engine_, ElementWise, *shape_layer->getOutput(0),
+          *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kSUB);
+      layer->setInput(2, *size_layer->getOutput(0));
+      RreplenishLayerAndOutput(layer, "strided_slice", {output_name},
+                               test_mode);
     }
-
-    int* weight_data = create_weights(const_weight, "size");
-
-    TensorRTEngine::Weight weight{nvinfer1::DataType::kINT32,
-                                  static_cast<void*>(weight_data),
-                                  static_cast<size_t>(input_dims.nbDims)};
-
-    int input_dim_size = input_dims.nbDims;
-    nvinfer1::Dims input_shape;
-    input_shape.nbDims = 1;
-    input_shape.d[0] = input_dim_size;
-
-    auto const_layer =
-        TRT_ENGINE_ADD_LAYER(engine_, Constant, input_shape, weight.get());
-
-    auto shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
-
-    auto size_layer = TRT_ENGINE_ADD_LAYER(
-        engine_, ElementWise, *shape_layer->getOutput(0),
-        *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kDIV);
-
-    auto* layer =
-        TRT_ENGINE_ADD_LAYER(engine_, Slice, *input, start, size, stride);
-    layer->setInput(2, *size_layer->getOutput(0));
-
-    RreplenishLayerAndOutput(layer, "strided_slice", {output_name}, test_mode);
   }
 };
 
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index ba5b28a4dfed9fa114fe38276c6dc18d3931610d..cbe151294db099040c70be79342fd94ef9106658 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -79,6 +79,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "elementwise_sub",
       "elementwise_mul",
       "elementwise_div",
+      "elementwise_pow",
       "dropout",
       "prelu",
       "conv2d_transpose",
@@ -145,6 +146,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "elementwise_sub",
       "elementwise_mul",
       "elementwise_div",
+      "elementwise_pow",
       "dropout",
       "prelu",
       "conv2d_transpose",
@@ -958,9 +960,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
           << "strided_slice converter does not support trt versions below 7.0";
       return false;
 #endif
-      if (!with_dynamic_shape) {
-        return false;
-      }
       if (!desc.HasAttr("axes") || !desc.HasAttr("starts") ||
           !desc.HasAttr("ends") || !desc.HasAttr("strides")) {
         VLOG(3)
@@ -1026,7 +1025,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     }
 
     if (op_type == "elementwise_add" || op_type == "elementwise_mul" ||
-        op_type == "elementwise_sub" || op_type == "elementwise_div") {
+        op_type == "elementwise_sub" || op_type == "elementwise_div" ||
+        op_type == "elementwise_pow") {
       if (desc.Input("X").size() != 1) {
         VLOG(3) << "The input op's Input(\"X\").size() "
                    "should equal to 1, but received Input(\"X\").size() = "
@@ -1056,32 +1056,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       auto* y_var_desc = block->FindVar(desc.Input("Y")[0]);
       const auto x_shape = x_var_desc->GetShape();
       const auto y_shape = y_var_desc->GetShape();
-      if (op_type == "elementwise_add" && y_var_desc->Persistable()) {
-        if (y_shape.size() != 1) {
-          return false;
-        }
-        if (y_shape[0] != x_shape[1]) {
-          return false;
-        }
-      }
       if (x_shape.size() == 1 && y_shape.size() == 1) {
         VLOG(3) << "Now trt may not support two 1d tensor elementwise op.";
         return false;
       }
-      if (op_type == "elementwise_add" || op_type == "elementwise_mul") {
-        if (x_var_desc->Persistable()) {
-          VLOG(3) << "Input X is a parameter which is not supported for "
-                     "elementwise_add/elementwise_mul in tensorrt, swap x and "
-                     "y will work";
-          return false;
-        }
-      }
-      if (op_type == "elementwise_sub" || op_type == "elementwise_div") {
-        if (x_var_desc->Persistable() || y_var_desc->Persistable()) {
-          VLOG(3) << "Input X or Input Y is a parameter which is not supported "
-                     "for elementwise_sub/elementwise_div in tensorrt";
-          return false;
-        }
+      if (x_var_desc->Persistable()) {
+        VLOG(3) << "Input X is a parameter which is not supported for "
+                   "elementwise_add/elementwise_mul in tensorrt, swap x and "
+                   "y will work";
+        return false;
       }
     }
 
diff --git a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu
index c9163e62a2e19ea9c4449a5eaffd637844710d6d..1070a88cee7372cdbe6bcbef83681c624b7470a2 100644
--- a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu
@@ -35,6 +35,19 @@ template <typename T>
 struct Div {
   __device__ T operator()(const T &a, const T &b) const { return a / b; }
 };
+
+template <typename T>
+struct Sub {
+  __device__ T operator()(const T &a, const T &b) const { return a - b; }
+};
+
+template <typename T>
+struct Pow {
+  __device__ T operator()(const T &a, const T &b) const {
+    return static_cast<T>(::powf(static_cast<float>(a), static_cast<float>(b)));
+  }
+};
+
 }  // namespace details
 
 template <typename T, typename Operator>
@@ -139,6 +152,14 @@ int ElementWisePlugin::enqueue(int batch_size, const void *const *inputs,
     elementwise_kernel<<<block, thread, 0, stream>>>(
         num, x, y, out, prev_size_, batch_size * midd_size_, post_size_,
         details::Div<float>());
+  } else if (type_ == "sub") {
+    elementwise_kernel<<<block, thread, 0, stream>>>(
+        num, x, y, out, prev_size_, batch_size * midd_size_, post_size_,
+        details::Sub<float>());
+  } else if (type_ == "pow") {
+    elementwise_kernel<<<block, thread, 0, stream>>>(
+        num, x, y, out, prev_size_, batch_size * midd_size_, post_size_,
+        details::Pow<float>());
   } else {
     PADDLE_THROW(platform::errors::Fatal(
         "The %s type elementwise is not implemented in trt plugin.", type_));
@@ -254,12 +275,18 @@ int ElementwisePluginDynamic::enqueue(
   } else if (type_ == "div") {
     elementwise_kernel<<<block, thread, 0, stream>>>(
         num, x, y, out, prev_size, midd_size, post_size, details::Div<float>());
+  } else if (type_ == "sub") {
+    elementwise_kernel<<<block, thread, 0, stream>>>(
+        num, x, y, out, prev_size, midd_size, post_size, details::Sub<float>());
+  } else if (type_ == "pow") {
+    elementwise_kernel<<<block, thread, 0, stream>>>(
+        num, x, y, out, prev_size, midd_size, post_size, details::Pow<float>());
   } else {
-    PADDLE_THROW(
-        platform::errors::Unimplemented("Paddle-TRT only support elementwise "
-                                        "operation: {add, mul, div} currently, "
-                                        "but got %s.",
-                                        type_));
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Paddle-TRT only support elementwise "
+        "operation: {add, mul, div, sub, pow} currently, "
+        "but got %s.",
+        type_));
   }
 
   return cudaGetLastError() != cudaSuccess;
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py
index ec02a357a48b6a79150bd82705122e354fdc3364..27d8247aded5a26a7f535b6ce99727c995eebc1a 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_elementwise.py
@@ -150,7 +150,7 @@ class TrtConvertElementwiseTest_two_input_without_broadcast(
         for shape in [[4], [4, 32], [2, 64, 32], [1, 8, 16, 32]]:
             for op_type in [
                     "elementwise_add", "elementwise_mul", "elementwise_sub",
-                    "elementwise_div"
+                    "elementwise_div", "elementwise_pow"
             ]:
                 for axis in [0, -1]:
                     self.dims = len(shape)
@@ -309,7 +309,7 @@ class TrtConvertElementwiseTest_two_input_with_broadcast(TrtLayerAutoScanTest):
                 input2_shape = input2_shape_list[j][i]
                 for op_type in [
                         "elementwise_add", "elementwise_mul", "elementwise_sub",
-                        "elementwise_div"
+                        "elementwise_div", "elementwise_pow"
                 ]:
                     for axis in axis_list[j][i]:
                         self.shape1 = input1_shape
@@ -411,7 +411,7 @@ class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest):
                           [batch, 32, 16, 32]]:
                 for op_type in [
                         "elementwise_add", "elementwise_mul", "elementwise_sub",
-                        "elementwise_div"
+                        "elementwise_div", "elementwise_pow"
                 ]:
                     for axis in [-1 if len(shape) == 1 else 1]:
                         self.dims = len(shape)
@@ -511,18 +511,11 @@ class TrtConvertElementwiseTest_one_input_corner_case(TrtLayerAutoScanTest):
             for weight_name in program_config.weights:
                 if weight_name in input_x_names:
                     return True
-            op_type = program_config.ops[0].type
-            if op_type in ["elementwise_sub", "elementwise_div"]:
-                input_y_names = program_config.ops[0].inputs["Y"]
-                for weight_name in program_config.weights:
-                    if weight_name in input_y_names:
-                        return True
             return False
 
         self.add_skip_case(
             teller1, SkipReasons.TRT_NOT_SUPPORT,
-            "Input X should not be parameters in elementwise op and Input Y should not be parameters in elementwise_sub or elementwise_div op"
-        )
+            "Input X should not be parameters in elementwise op.")
 
     def test(self):
         self.add_skip_trt_case()
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_strided_slice.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_strided_slice.py
index 6a204ebbad27d7a5738cc28e62c89756502a329f..8bc48047c1397409d843efcbbeca342041ef8b10 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_strided_slice.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_strided_slice.py
@@ -113,6 +113,12 @@ class TrtConvertStridedSliceTest(TrtLayerAutoScanTest):
             for i in range(len(program_config.ops))
         ]
 
+        # for static_shape
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False), 1e-5
+
         # for dynamic_shape
         generate_dynamic_shape(attrs)
         self.trt_param.precision = paddle_infer.PrecisionType.Float32
@@ -121,3 +127,7 @@ class TrtConvertStridedSliceTest(TrtLayerAutoScanTest):
 
     def test(self):
         self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()