diff --git a/lite/api/paddle_api.cc b/lite/api/paddle_api.cc index 4b13ae4ed241eb1a3164a1213feec12306df89f6..bfeff4879820f132a331e9bff56a5f9c494fe775 100644 --- a/lite/api/paddle_api.cc +++ b/lite/api/paddle_api.cc @@ -270,6 +270,16 @@ void CxxConfig::set_xpu_dev_per_thread(int dev_no) { #endif } +void CxxConfig::set_xpu_multi_encoder_precision(const std::string &precision) { +#ifdef LITE_WITH_XPU + lite::Context::_multi_encoder_precision = precision; +#else + LOG(WARNING) << "The invoking of the function " + "'set_xpu_multi_encoder_precision' is " + "ignored, please rebuild it with LITE_WITH_XPU=ON."; +#endif +} + // set model data in combined format, `set_model_from_file` refers to loading // model from file, set_model_from_buffer refers to loading model from memory // buffer diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index b08f2f5c745f87cda2be181bdea2444b2c11313c..f4c7bae753eed39aae6febf6ee29f57c9f7d7777 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -216,6 +216,7 @@ class LITE_API CxxConfig : public ConfigBase { // **DEPRECATED**, use xpu_set_device() at the very beginning of each worker // thread void set_xpu_dev_per_thread(int dev_no = 0); + void set_xpu_multi_encoder_precision(const std::string& precision = "int16"); }; /// MobileConfig is the config for the light weight predictor, it will skip diff --git a/lite/core/context.cc b/lite/core/context.cc index 711c67f8b7f36edcd2d66569d964296d96e8d85c..66d0c3946397610d83f83e65a9a7b95d5019110c 100644 --- a/lite/core/context.cc +++ b/lite/core/context.cc @@ -18,6 +18,7 @@ namespace paddle { namespace lite { #ifdef LITE_WITH_XPU +std::string Context::_multi_encoder_precision; // NOLINT thread_local xdnn::Context* Context::_tls_raw_ctx{nullptr}; int Context::_workspace_l3_size_per_thread{0}; #endif diff --git a/lite/core/context.h b/lite/core/context.h index d50e458472d2d9334a1fe19413b194e79084294d..324b5552acc7d82f463b204df1b81b5209c99e9b 100644 --- a/lite/core/context.h +++ b/lite/core/context.h @@ -178,6 +178,9 @@ class Context { std::string name() const { return "XPUContext"; } + public: + static std::string _multi_encoder_precision; // NOLINT + private: static thread_local xdnn::Context* _tls_raw_ctx; static int _workspace_l3_size_per_thread; diff --git a/lite/core/mir/elimination/identity_dropout_eliminate_pass.cc b/lite/core/mir/elimination/identity_dropout_eliminate_pass.cc index 92401df875da1f500ec09b34b2786d15cea2991b..cc0cc47b76104b68f091b2413b703a19a1f198bc 100644 --- a/lite/core/mir/elimination/identity_dropout_eliminate_pass.cc +++ b/lite/core/mir/elimination/identity_dropout_eliminate_pass.cc @@ -24,13 +24,30 @@ namespace { class Eliminator : public FuseBase { public: + static bool DropoutIsTest(const Node* x) { + if (x && x->IsStmt()) { + auto* op_info = x->stmt()->op_info(); + if (op_info->HasAttr("is_test")) { + auto attr_type = op_info->GetAttrType("is_test"); + if (attr_type == paddle::lite::OpDescAPI::AttrType::INT && + op_info->GetAttr("is_test") == 1) { + return true; + } else if (attr_type == paddle::lite::OpDescAPI::AttrType::BOOLEAN && + op_info->GetAttr("is_test")) { + return true; + } + } + } + return false; + } + void BuildPattern() override { // the previous op's output need updat auto* pre_op = OpNode("preop")->assert_is_not_op_type("conditional_block"); // TODO(Superjomn) check has only one output auto* x = VarNode("x")->assert_is_op_input("dropout", "X"); auto* dropout_op = OpNode("dropout", "dropout") - ->assert_op_attr("is_test", 1) + ->assert_node_satisfied(Eliminator::DropoutIsTest) ->assert_op_attr( "dropout_implementation", "upscale_in_train"); auto* out = VarNode("out")->assert_is_op_output("dropout", "Out"); diff --git a/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc b/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc index a6640f107f5dd46e6570a55cf59d2ad69a2bee1a..d653f87f7b5e4f71998ba1e73ac88398d89d328a 100644 --- a/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc +++ b/lite/core/mir/fusion/__xpu__multi_encoder_fuse_pass.cc @@ -13,8 +13,10 @@ // limitations under the License. #include +#include #include #include "lite/backends/xpu/math.h" +#include "lite/core/context.h" #include "lite/core/mir/pass_registry.h" #include "lite/core/mir/type_precision_cast_pass.h" // For UpdateInputs() #include "lite/core/mir/xpu_pattern_matcher_high_api.h" @@ -125,14 +127,6 @@ class XPUSingleEncoderFuser : public FuseBase { auto* qk_softmax_out = VarNode("qk_softmax_out") ->assert_is_op_output("softmax", "Out") ->AsIntermediate(); - auto* qk_dropout = OpNode("qk_dropout", "dropout")->AsIntermediate(); - auto* qk_dropout_out = VarNode("qk_dropout_out") - ->assert_is_op_output("dropout", "Out") - ->assert_is_op_input("matmul", "X") - ->AsIntermediate(); - auto* qk_dropout_mask = VarNode("qk_dropout_mask") - ->assert_is_op_output("dropout", "Mask") - ->AsIntermediate(); auto* v_mul_y = VarNode("v_mul_y")->assert_is_op_input("mul", "Y")->AsInput(); @@ -203,16 +197,7 @@ class XPUSingleEncoderFuser : public FuseBase { auto* qkv_add = OpNode("qkv_add", "elementwise_add")->AsIntermediate(); auto* qkv_add_out = VarNode("qkv_add_out") ->assert_is_op_output("elementwise_add", "Out") - ->assert_is_op_input("dropout", "X") ->AsIntermediate(); - auto* qkv_dropout = OpNode("qkv_dropout", "dropout")->AsIntermediate(); - auto* qkv_dropout_out = VarNode("qkv_dropout_out") - ->assert_is_op_output("dropout", "Out") - ->assert_is_op_input("elementwise_add", "X") - ->AsIntermediate(); - auto* qkv_dropout_mask = VarNode("qkv_dropout_mask") - ->assert_is_op_output("dropout", "Mask") - ->AsIntermediate(); auto* qkv_add_2 = OpNode("qkv_add_2", "elementwise_add")->AsIntermediate(); auto* qkv_add_2_out = VarNode("qkv_add_2_out") @@ -271,16 +256,7 @@ class XPUSingleEncoderFuser : public FuseBase { auto* qkv_add_4 = OpNode("qkv_add_4", "elementwise_add")->AsIntermediate(); auto* qkv_add_4_out = VarNode("qkv_add_4_out") ->assert_is_op_output("elementwise_add", "Out") - ->assert_is_op_input("dropout", "X") ->AsIntermediate(); - auto* qkv_dropout_4 = OpNode("qkv_dropout_4", "dropout")->AsIntermediate(); - auto* qkv_dropout_4_out = VarNode("qkv_dropout_4_out") - ->assert_is_op_output("dropout", "Out") - ->assert_is_op_input("elementwise_add", "X") - ->AsIntermediate(); - auto* qkv_dropout_4_mask = VarNode("qkv_dropout_4_mask") - ->assert_is_op_output("dropout", "Mask") - ->AsIntermediate(); auto* qkv_add_5 = OpNode("qkv_add_5", "elementwise_add")->AsIntermediate(); auto* qkv_add_5_out = VarNode("qkv_add_5_out") @@ -321,9 +297,8 @@ class XPUSingleEncoderFuser : public FuseBase { *k_transpose2 >> *k_transpose2_xshape; *qk_matmul >> *qk_matmul_out >> *qk_add >> *qk_add_out >> *qk_softmax >> - *qk_softmax_out >> *qk_dropout >> *qk_dropout_out >> *qkv_matmul; + *qk_softmax_out >> *qkv_matmul; *qk_mask >> *qk_add; - *qk_dropout >> *qk_dropout_mask; *input >> *v_mul >> *v_mul_out >> *v_add >> *v_add_out >> *v_reshape2 >> *v_reshape2_out >> *v_transpose2 >> *v_transpose2_out >> *qkv_matmul; @@ -334,13 +309,11 @@ class XPUSingleEncoderFuser : public FuseBase { *qkv_matmul >> *qkv_matmul_out >> *qkv_transpose2 >> *qkv_transpose2_out >> *qkv_reshape2 >> *qkv_reshape2_out >> *qkv_mul >> *qkv_mul_out >> - *qkv_add >> *qkv_add_out >> *qkv_dropout >> *qkv_dropout_out >> - *qkv_add_2; + *qkv_add >> *qkv_add_out >> *qkv_add_2; *qkv_transpose2 >> *qkv_transpose2_xshape; *qkv_reshape2 >> *qkv_reshape2_xshape; *qkv_mul_y >> *qkv_mul; *qkv_add_y >> *qkv_add; - *qkv_dropout >> *qkv_dropout_mask; *input >> *qkv_add_2 >> *qkv_add_2_out >> *qkv_ln_2 >> *qkv_ln_2_out; *qkv_ln_2_scale >> *qkv_ln_2; @@ -350,13 +323,11 @@ class XPUSingleEncoderFuser : public FuseBase { *qkv_ln_2_out >> *qkv_mul_3 >> *qkv_mul_3_out >> *qkv_add_3 >> *qkv_add_3_out >> *qkv_act >> *qkv_act_out >> *qkv_mul_4 >> - *qkv_mul_4_out >> *qkv_add_4 >> *qkv_add_4_out >> *qkv_dropout_4 >> - *qkv_dropout_4_out >> *qkv_add_5; + *qkv_mul_4_out >> *qkv_add_4 >> *qkv_add_4_out >> *qkv_add_5; *qkv_mul_3_y >> *qkv_mul_3; *qkv_add_3_y >> *qkv_add_3; *qkv_mul_4_y >> *qkv_mul_4; *qkv_add_4_y >> *qkv_add_4; - *qkv_dropout_4 >> *qkv_dropout_4_mask; *qkv_ln_2_out >> *qkv_add_5 >> *qkv_add_5_out >> *qkv_ln_5 >> *qkv_ln_5_out; *qkv_ln_5_scale >> *qkv_ln_5; @@ -451,6 +422,9 @@ class XPUSingleEncoderFuser : public FuseBase { class XPUMultiEncoderFuser { public: + explicit XPUMultiEncoderFuser(const std::set& fc_int31_ids) + : fc_int31_ids_(fc_int31_ids) {} + bool IsDirectPredecessorOf(Node* op1, Node* op2) { for (auto* out : op1->outlinks) { for (auto* in : op2->inlinks) { @@ -542,6 +516,8 @@ class XPUMultiEncoderFuser { op_desc.SetAttr("n_layers", all_encoders.size()); op_desc.SetAttr( "act_type", first_encoder_op_info->GetAttr("act_type")); + op_desc.SetAttr("precision", + (fc_int31_ids_.empty() ? "int16" : "int31")); auto* scope = multi_encoder_stmt->op()->scope(); std::vector fc_weight_max(arg_map["FCWeight"].size()); @@ -553,18 +529,33 @@ class XPUMultiEncoderFuser { float* weight_on_host = weight_t->mutable_data(); float max_f = paddle::lite::xpu::math::FindMaxAbs(weight_on_host, weight_len); - - std::unique_ptr weight_int16(new int16_t[weight_len]); - std::unique_ptr weight_trans_int16(new int16_t[weight_len]); - paddle::lite::xpu::math::ConvertFP32ToInt16( - weight_on_host, weight_int16.get(), max_f, weight_len); - paddle::lite::xpu::math::Transpose(weight_int16.get(), - weight_trans_int16.get(), - weight_dims[0], - weight_dims[1]); - memcpy(weight_on_host, - weight_trans_int16.get(), - weight_len * sizeof(int16_t)); + // i ranges from 0 to 6*encoder_num, so we need to do i%6 to get relative + // position in the encoder + if (fc_int31_ids_.find(i % 6) != fc_int31_ids_.end()) { + // FCs in encoder use int31 + VLOG(3) << "Use FC-int31 in FC-" << i << ", " << i / 6 << "-" << i % 6; + std::unique_ptr weight_trans_fp32(new float[weight_len]); + paddle::lite::xpu::math::Transpose(weight_on_host, + weight_trans_fp32.get(), + weight_dims[0], + weight_dims[1]); + + memcpy(weight_on_host, + weight_trans_fp32.get(), + weight_len * sizeof(float)); + } else { + std::unique_ptr weight_int16(new int16_t[weight_len]); + std::unique_ptr weight_trans_int16(new int16_t[weight_len]); + paddle::lite::xpu::math::ConvertFP32ToInt16( + weight_on_host, weight_int16.get(), max_f, weight_len); + paddle::lite::xpu::math::Transpose(weight_int16.get(), + weight_trans_int16.get(), + weight_dims[0], + weight_dims[1]); + memcpy(weight_on_host, + weight_trans_int16.get(), + weight_len * sizeof(int16_t)); + } fc_weight_max[i] = max_f; } @@ -631,6 +622,9 @@ class XPUMultiEncoderFuser { GraphSafeRemoveNodes(graph, to_remove2); } } + + private: + std::set fc_int31_ids_; }; } // namespace fusion @@ -641,15 +635,35 @@ class XPUMultiEncoderFusePass : public ProgramPass { if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return; // TODO(miaotianxiang): backup graph, recover from failed match std::vector act_types{"gelu", "relu"}; + + std::set fc_int31_ids; +#ifdef LITE_WITH_XPU + // TODO(miaotianxiang): core/mir/*_pass.cc are compiled anyway and need to + // access Context::_multi_encoder_precision, but this static member + // variable in class specialization defined in lite/core/context.cc + // is only compiled iff LITE_WITH_XPU==ON. To suppress linkage error, we use + // #ifdef here. Any better idea? + if (GetStringFromEnv("XPU_ENCODER_PRECISION", "int16") == "int31" || + lite::Context::_multi_encoder_precision == "int31") { + fc_int31_ids = {0, 1, 2, 3, 4, 5}; + VLOG(3) << "Use int31 in XPUMultiEncoderOp, " + << "lite::Context<>::_multi_encoder_precision=" + << lite::Context::_multi_encoder_precision; + } else { + VLOG(3) << "Use int16 in XPUMultiEncoderOp, " + << "lite::Context<>::_multi_encoder_precision=" + << lite::Context::_multi_encoder_precision; + } +#endif + for (auto& act_type : act_types) { fusion::XPUSingleEncoderFuser single_encoder_fuser(act_type); single_encoder_fuser(graph.get()); - fusion::XPUMultiEncoderFuser multi_encoder_fuser; + fusion::XPUMultiEncoderFuser multi_encoder_fuser(fc_int31_ids); multi_encoder_fuser(graph.get()); } } }; - } // namespace mir } // namespace lite } // namespace paddle diff --git a/lite/core/mir/pattern_matcher.h b/lite/core/mir/pattern_matcher.h index 90c4359c6d3ade98cf60b5c23411e2026cdeccc9..0cbfbd986ce743985fde64b8e71b9b0e2b135b9e 100644 --- a/lite/core/mir/pattern_matcher.h +++ b/lite/core/mir/pattern_matcher.h @@ -162,6 +162,12 @@ struct PMNode { attr_name, [=](const T& src) { return src == attr; }); } + PMNode* assert_node_satisfied( + const std::function& condition) { + asserts_.push_back(condition); + return this; + } + private: PMNode(PMPattern* pattern, const std::string& name = "", diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index c095ec9697923e51ef48c1992ce56569a00177ef..3c4b6b532dd9f85319089473061f279aa2ad2305 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -76,12 +76,11 @@ class Optimizer { (defined LITE_WITH_ARM) "lite_elementwise_activation_fuse_pass", // #endif + "identity_dropout_eliminate_pass", "__xpu__resnet_fuse_pass", "__xpu__multi_encoder_fuse_pass", "__xpu__embedding_with_eltwise_add_fuse_pass", "__xpu__fc_fuse_pass", - "identity_dropout_eliminate_pass", // should be placed after - // xpu fusion "quantized_op_attributes_inference_pass", // Only for fully // quantized model, infer // the output scale and diff --git a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc index a0ba33110d2b3efd4a5e164da86ea949c95bbb63..781a5482413f27fb6e6c44166f04a2b2ea92bb34 100644 --- a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc +++ b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc @@ -46,26 +46,50 @@ void XPUMultiEncoderCompute::Run() { int batch_size = param.input->dims()[0]; int seq_len = param.input->dims()[1]; - int r = xdnn::bert_encoder_transformer_int16( - ctx.GetRawContext(), /* context */ - batch_size, /* batch_size */ - seq_len, /* from_seq_len */ - seq_len, /* to_seq_len */ - param.head_num, /* head_num */ - param.size_per_head, /* size_per_head */ - param.n_layers, /* n_layers */ - param.input->data(), /* from_tensor */ - param.input->data(), /* to_tensor */ - param.mask->data(), /* att_mask */ - &arg_fc_weight_[0], /* fc_weights */ - &arg_fc_bias_[0], /* fc_biass */ - &arg_ln_scale_[0], /* ln_scales */ - &arg_ln_bias_[0], /* ln_biass */ - param.output->mutable_data(TARGET(kXPU)), /* output */ - param.fc_weight_max->data(), /* fc_weights_max */ - true, /* pretrans_b */ - true, /* use_l3 */ - act_type_ /* act_type */); + int r = -1; + if (param.precision == "int31") { + r = xdnn::bert_encoder_transformer_int31( + ctx.GetRawContext(), /* context */ + batch_size, /* batch_size */ + seq_len, /* from_seq_len */ + seq_len, /* to_seq_len */ + param.head_num, /* head_num */ + param.size_per_head, /* size_per_head */ + param.n_layers, /* n_layers */ + param.input->data(), /* from_tensor */ + param.input->data(), /* to_tensor */ + param.mask->data(), /* att_mask */ + (const float**)(&arg_fc_weight_[0]), /* fc_weights */ + &arg_fc_bias_[0], /* fc_biass */ + &arg_ln_scale_[0], /* ln_scales */ + &arg_ln_bias_[0], /* ln_biass */ + param.output->mutable_data(TARGET(kXPU)), /* output */ + param.fc_weight_max->data(), /* fc_weights_max */ + true, /* pretrans_b */ + true, /* use_l3 */ + act_type_ /* act_type */); + } else { + r = xdnn::bert_encoder_transformer_int16( + ctx.GetRawContext(), /* context */ + batch_size, /* batch_size */ + seq_len, /* from_seq_len */ + seq_len, /* to_seq_len */ + param.head_num, /* head_num */ + param.size_per_head, /* size_per_head */ + param.n_layers, /* n_layers */ + param.input->data(), /* from_tensor */ + param.input->data(), /* to_tensor */ + param.mask->data(), /* att_mask */ + &arg_fc_weight_[0], /* fc_weights */ + &arg_fc_bias_[0], /* fc_biass */ + &arg_ln_scale_[0], /* ln_scales */ + &arg_ln_bias_[0], /* ln_biass */ + param.output->mutable_data(TARGET(kXPU)), /* output */ + param.fc_weight_max->data(), /* fc_weights_max */ + true, /* pretrans_b */ + true, /* use_l3 */ + act_type_ /* act_type */); + } CHECK_EQ(r, 0); } diff --git a/lite/operators/__xpu__multi_encoder_op.cc b/lite/operators/__xpu__multi_encoder_op.cc index 6d8aca942592668831b8d46d3e07ce83a57f1011..5a1d2cb82e5ba05035db5709ae2aae760593d33d 100644 --- a/lite/operators/__xpu__multi_encoder_op.cc +++ b/lite/operators/__xpu__multi_encoder_op.cc @@ -68,6 +68,7 @@ bool XPUMultiEncoderOp::AttachImpl(const cpp::OpDesc& op_desc, param_.head_num = op_desc.GetAttr("head_num"); param_.size_per_head = op_desc.GetAttr("size_per_head"); param_.act_type = op_desc.GetAttr("act_type"); + param_.precision = op_desc.GetAttr("precision"); return true; } diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index d2ae0ceb20d40aac662fd3068be79fd266f9e984..72ae2d7aa780b4c5a99cac06979f63100e15b5c7 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -1492,6 +1492,7 @@ struct XPUMultiEncoderParam : ParamBase { int head_num{}; int size_per_head{}; std::string act_type{}; + std::string precision{}; }; struct XPUEmbeddingWithEltwiseAddParam : ParamBase {