diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 4150d0ca555c9d2ddc706ef3d17ff05bde02c360..449849762cb10190f5eedffdc2206e8e2e933999 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2263,15 +2263,34 @@ PDNode *patterns::QuantizePlacement::operator()( PDNode *patterns::Bfloat16Placement::operator()( const std::unordered_set &bfloat16_enabled_op_types) { std::unordered_set supported_op_types = - std::unordered_set( - {"concat", "conv2d", "conv2d_transpose", - "elementwise_add", "elementwise_mul", "fc", - "fusion_gru", "fusion_lstm", "gelu", - "layer_norm", "matmul", "matmul_v2", - "pool2d", "prelu", "relu", - "reshape2", "softmax", "split", - "squeeze", "squeeze2", "sum", - "transpose2"}); + std::unordered_set({"cast", + "clip", + "concat", + "conv2d", + "conv2d_transpose", + "elementwise_add", + "elementwise_mul", + "expand_v2", + "fc", + "fusion_gru", + "fusion_lstm", + "gelu", + "layer_norm", + "matmul", + "matmul_v2", + "pool2d", + "prelu", + "relu", + "reshape2", + "scale", + "sigmoid", + "slice", + "softmax", + "split", + "squeeze", + "squeeze2", + "sum", + "transpose2"}); if (!bfloat16_enabled_op_types.empty()) { supported_op_types = bfloat16_enabled_op_types; }