提交 2db5023d 编写于 作者: C Cwndmiao 提交者: GitHub

[LITE][XPU] 1. Add precision switch(int16/int31) in XPUMultiEncoderOp; 2. Fix...

[LITE][XPU] 1. Add precision switch(int16/int31) in XPUMultiEncoderOp; 2. Fix identity_dropout_eliminate_pass, |AttrType| of 'is_test' in OpDesc can be INT or BOOLEAN; 3. Enhance |__xpu__multi_encoder_fuse_pass|; (#3596)

* [LITE][XPU] Add precision switch(int16/int31) in XPUMultiEncoderOp

* [LITE][XPU] fix identity_dropout_eliminate_pass, |AttrType| of 'is_test' in OpDesc can be INT or BOOLEAN

* test=develop

* [LITE][XPU] suppress linkage error
test=develop

* [LITE][XPU] 1. Reorder |identity_dropout_eliminate_pass| before |__xpu__multi_encoder_fuse_pass|; 2. Enhance |__xpu__multi_encoder_fuse_pass|, it works well in more scenarios;
test=develop

* [LITE][XPU] Remove XPUConfig
test=develop
上级 a66b29d7
...@@ -270,6 +270,16 @@ void CxxConfig::set_xpu_dev_per_thread(int dev_no) { ...@@ -270,6 +270,16 @@ void CxxConfig::set_xpu_dev_per_thread(int dev_no) {
#endif #endif
} }
void CxxConfig::set_xpu_multi_encoder_precision(const std::string &precision) {
#ifdef LITE_WITH_XPU
lite::Context<TargetType::kXPU>::_multi_encoder_precision = precision;
#else
LOG(WARNING) << "The invoking of the function "
"'set_xpu_multi_encoder_precision' is "
"ignored, please rebuild it with LITE_WITH_XPU=ON.";
#endif
}
// set model data in combined format, `set_model_from_file` refers to loading // set model data in combined format, `set_model_from_file` refers to loading
// model from file, set_model_from_buffer refers to loading model from memory // model from file, set_model_from_buffer refers to loading model from memory
// buffer // buffer
......
...@@ -216,6 +216,7 @@ class LITE_API CxxConfig : public ConfigBase { ...@@ -216,6 +216,7 @@ class LITE_API CxxConfig : public ConfigBase {
// **DEPRECATED**, use xpu_set_device() at the very beginning of each worker // **DEPRECATED**, use xpu_set_device() at the very beginning of each worker
// thread // thread
void set_xpu_dev_per_thread(int dev_no = 0); void set_xpu_dev_per_thread(int dev_no = 0);
void set_xpu_multi_encoder_precision(const std::string& precision = "int16");
}; };
/// MobileConfig is the config for the light weight predictor, it will skip /// MobileConfig is the config for the light weight predictor, it will skip
......
...@@ -18,6 +18,7 @@ namespace paddle { ...@@ -18,6 +18,7 @@ namespace paddle {
namespace lite { namespace lite {
#ifdef LITE_WITH_XPU #ifdef LITE_WITH_XPU
std::string Context<TargetType::kXPU>::_multi_encoder_precision; // NOLINT
thread_local xdnn::Context* Context<TargetType::kXPU>::_tls_raw_ctx{nullptr}; thread_local xdnn::Context* Context<TargetType::kXPU>::_tls_raw_ctx{nullptr};
int Context<TargetType::kXPU>::_workspace_l3_size_per_thread{0}; int Context<TargetType::kXPU>::_workspace_l3_size_per_thread{0};
#endif #endif
......
...@@ -178,6 +178,9 @@ class Context<TargetType::kXPU> { ...@@ -178,6 +178,9 @@ class Context<TargetType::kXPU> {
std::string name() const { return "XPUContext"; } std::string name() const { return "XPUContext"; }
public:
static std::string _multi_encoder_precision; // NOLINT
private: private:
static thread_local xdnn::Context* _tls_raw_ctx; static thread_local xdnn::Context* _tls_raw_ctx;
static int _workspace_l3_size_per_thread; static int _workspace_l3_size_per_thread;
......
...@@ -24,13 +24,30 @@ namespace { ...@@ -24,13 +24,30 @@ namespace {
class Eliminator : public FuseBase { class Eliminator : public FuseBase {
public: public:
static bool DropoutIsTest(const Node* x) {
if (x && x->IsStmt()) {
auto* op_info = x->stmt()->op_info();
if (op_info->HasAttr("is_test")) {
auto attr_type = op_info->GetAttrType("is_test");
if (attr_type == paddle::lite::OpDescAPI::AttrType::INT &&
op_info->GetAttr<int>("is_test") == 1) {
return true;
} else if (attr_type == paddle::lite::OpDescAPI::AttrType::BOOLEAN &&
op_info->GetAttr<bool>("is_test")) {
return true;
}
}
}
return false;
}
void BuildPattern() override { void BuildPattern() override {
// the previous op's output need updat // the previous op's output need updat
auto* pre_op = OpNode("preop")->assert_is_not_op_type("conditional_block"); auto* pre_op = OpNode("preop")->assert_is_not_op_type("conditional_block");
// TODO(Superjomn) check has only one output // TODO(Superjomn) check has only one output
auto* x = VarNode("x")->assert_is_op_input("dropout", "X"); auto* x = VarNode("x")->assert_is_op_input("dropout", "X");
auto* dropout_op = OpNode("dropout", "dropout") auto* dropout_op = OpNode("dropout", "dropout")
->assert_op_attr<int>("is_test", 1) ->assert_node_satisfied(Eliminator::DropoutIsTest)
->assert_op_attr<std::string>( ->assert_op_attr<std::string>(
"dropout_implementation", "upscale_in_train"); "dropout_implementation", "upscale_in_train");
auto* out = VarNode("out")->assert_is_op_output("dropout", "Out"); auto* out = VarNode("out")->assert_is_op_output("dropout", "Out");
......
...@@ -13,8 +13,10 @@ ...@@ -13,8 +13,10 @@
// limitations under the License. // limitations under the License.
#include <memory> #include <memory>
#include <set>
#include <vector> #include <vector>
#include "lite/backends/xpu/math.h" #include "lite/backends/xpu/math.h"
#include "lite/core/context.h"
#include "lite/core/mir/pass_registry.h" #include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/type_precision_cast_pass.h" // For UpdateInputs() #include "lite/core/mir/type_precision_cast_pass.h" // For UpdateInputs()
#include "lite/core/mir/xpu_pattern_matcher_high_api.h" #include "lite/core/mir/xpu_pattern_matcher_high_api.h"
...@@ -125,14 +127,6 @@ class XPUSingleEncoderFuser : public FuseBase { ...@@ -125,14 +127,6 @@ class XPUSingleEncoderFuser : public FuseBase {
auto* qk_softmax_out = VarNode("qk_softmax_out") auto* qk_softmax_out = VarNode("qk_softmax_out")
->assert_is_op_output("softmax", "Out") ->assert_is_op_output("softmax", "Out")
->AsIntermediate(); ->AsIntermediate();
auto* qk_dropout = OpNode("qk_dropout", "dropout")->AsIntermediate();
auto* qk_dropout_out = VarNode("qk_dropout_out")
->assert_is_op_output("dropout", "Out")
->assert_is_op_input("matmul", "X")
->AsIntermediate();
auto* qk_dropout_mask = VarNode("qk_dropout_mask")
->assert_is_op_output("dropout", "Mask")
->AsIntermediate();
auto* v_mul_y = auto* v_mul_y =
VarNode("v_mul_y")->assert_is_op_input("mul", "Y")->AsInput(); VarNode("v_mul_y")->assert_is_op_input("mul", "Y")->AsInput();
...@@ -203,15 +197,6 @@ class XPUSingleEncoderFuser : public FuseBase { ...@@ -203,15 +197,6 @@ class XPUSingleEncoderFuser : public FuseBase {
auto* qkv_add = OpNode("qkv_add", "elementwise_add")->AsIntermediate(); auto* qkv_add = OpNode("qkv_add", "elementwise_add")->AsIntermediate();
auto* qkv_add_out = VarNode("qkv_add_out") auto* qkv_add_out = VarNode("qkv_add_out")
->assert_is_op_output("elementwise_add", "Out") ->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("dropout", "X")
->AsIntermediate();
auto* qkv_dropout = OpNode("qkv_dropout", "dropout")->AsIntermediate();
auto* qkv_dropout_out = VarNode("qkv_dropout_out")
->assert_is_op_output("dropout", "Out")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto* qkv_dropout_mask = VarNode("qkv_dropout_mask")
->assert_is_op_output("dropout", "Mask")
->AsIntermediate(); ->AsIntermediate();
auto* qkv_add_2 = OpNode("qkv_add_2", "elementwise_add")->AsIntermediate(); auto* qkv_add_2 = OpNode("qkv_add_2", "elementwise_add")->AsIntermediate();
...@@ -271,15 +256,6 @@ class XPUSingleEncoderFuser : public FuseBase { ...@@ -271,15 +256,6 @@ class XPUSingleEncoderFuser : public FuseBase {
auto* qkv_add_4 = OpNode("qkv_add_4", "elementwise_add")->AsIntermediate(); auto* qkv_add_4 = OpNode("qkv_add_4", "elementwise_add")->AsIntermediate();
auto* qkv_add_4_out = VarNode("qkv_add_4_out") auto* qkv_add_4_out = VarNode("qkv_add_4_out")
->assert_is_op_output("elementwise_add", "Out") ->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("dropout", "X")
->AsIntermediate();
auto* qkv_dropout_4 = OpNode("qkv_dropout_4", "dropout")->AsIntermediate();
auto* qkv_dropout_4_out = VarNode("qkv_dropout_4_out")
->assert_is_op_output("dropout", "Out")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto* qkv_dropout_4_mask = VarNode("qkv_dropout_4_mask")
->assert_is_op_output("dropout", "Mask")
->AsIntermediate(); ->AsIntermediate();
auto* qkv_add_5 = OpNode("qkv_add_5", "elementwise_add")->AsIntermediate(); auto* qkv_add_5 = OpNode("qkv_add_5", "elementwise_add")->AsIntermediate();
...@@ -321,9 +297,8 @@ class XPUSingleEncoderFuser : public FuseBase { ...@@ -321,9 +297,8 @@ class XPUSingleEncoderFuser : public FuseBase {
*k_transpose2 >> *k_transpose2_xshape; *k_transpose2 >> *k_transpose2_xshape;
*qk_matmul >> *qk_matmul_out >> *qk_add >> *qk_add_out >> *qk_softmax >> *qk_matmul >> *qk_matmul_out >> *qk_add >> *qk_add_out >> *qk_softmax >>
*qk_softmax_out >> *qk_dropout >> *qk_dropout_out >> *qkv_matmul; *qk_softmax_out >> *qkv_matmul;
*qk_mask >> *qk_add; *qk_mask >> *qk_add;
*qk_dropout >> *qk_dropout_mask;
*input >> *v_mul >> *v_mul_out >> *v_add >> *v_add_out >> *v_reshape2 >> *input >> *v_mul >> *v_mul_out >> *v_add >> *v_add_out >> *v_reshape2 >>
*v_reshape2_out >> *v_transpose2 >> *v_transpose2_out >> *qkv_matmul; *v_reshape2_out >> *v_transpose2 >> *v_transpose2_out >> *qkv_matmul;
...@@ -334,13 +309,11 @@ class XPUSingleEncoderFuser : public FuseBase { ...@@ -334,13 +309,11 @@ class XPUSingleEncoderFuser : public FuseBase {
*qkv_matmul >> *qkv_matmul_out >> *qkv_transpose2 >> *qkv_transpose2_out >> *qkv_matmul >> *qkv_matmul_out >> *qkv_transpose2 >> *qkv_transpose2_out >>
*qkv_reshape2 >> *qkv_reshape2_out >> *qkv_mul >> *qkv_mul_out >> *qkv_reshape2 >> *qkv_reshape2_out >> *qkv_mul >> *qkv_mul_out >>
*qkv_add >> *qkv_add_out >> *qkv_dropout >> *qkv_dropout_out >> *qkv_add >> *qkv_add_out >> *qkv_add_2;
*qkv_add_2;
*qkv_transpose2 >> *qkv_transpose2_xshape; *qkv_transpose2 >> *qkv_transpose2_xshape;
*qkv_reshape2 >> *qkv_reshape2_xshape; *qkv_reshape2 >> *qkv_reshape2_xshape;
*qkv_mul_y >> *qkv_mul; *qkv_mul_y >> *qkv_mul;
*qkv_add_y >> *qkv_add; *qkv_add_y >> *qkv_add;
*qkv_dropout >> *qkv_dropout_mask;
*input >> *qkv_add_2 >> *qkv_add_2_out >> *qkv_ln_2 >> *qkv_ln_2_out; *input >> *qkv_add_2 >> *qkv_add_2_out >> *qkv_ln_2 >> *qkv_ln_2_out;
*qkv_ln_2_scale >> *qkv_ln_2; *qkv_ln_2_scale >> *qkv_ln_2;
...@@ -350,13 +323,11 @@ class XPUSingleEncoderFuser : public FuseBase { ...@@ -350,13 +323,11 @@ class XPUSingleEncoderFuser : public FuseBase {
*qkv_ln_2_out >> *qkv_mul_3 >> *qkv_mul_3_out >> *qkv_add_3 >> *qkv_ln_2_out >> *qkv_mul_3 >> *qkv_mul_3_out >> *qkv_add_3 >>
*qkv_add_3_out >> *qkv_act >> *qkv_act_out >> *qkv_mul_4 >> *qkv_add_3_out >> *qkv_act >> *qkv_act_out >> *qkv_mul_4 >>
*qkv_mul_4_out >> *qkv_add_4 >> *qkv_add_4_out >> *qkv_dropout_4 >> *qkv_mul_4_out >> *qkv_add_4 >> *qkv_add_4_out >> *qkv_add_5;
*qkv_dropout_4_out >> *qkv_add_5;
*qkv_mul_3_y >> *qkv_mul_3; *qkv_mul_3_y >> *qkv_mul_3;
*qkv_add_3_y >> *qkv_add_3; *qkv_add_3_y >> *qkv_add_3;
*qkv_mul_4_y >> *qkv_mul_4; *qkv_mul_4_y >> *qkv_mul_4;
*qkv_add_4_y >> *qkv_add_4; *qkv_add_4_y >> *qkv_add_4;
*qkv_dropout_4 >> *qkv_dropout_4_mask;
*qkv_ln_2_out >> *qkv_add_5 >> *qkv_add_5_out >> *qkv_ln_5 >> *qkv_ln_5_out; *qkv_ln_2_out >> *qkv_add_5 >> *qkv_add_5_out >> *qkv_ln_5 >> *qkv_ln_5_out;
*qkv_ln_5_scale >> *qkv_ln_5; *qkv_ln_5_scale >> *qkv_ln_5;
...@@ -451,6 +422,9 @@ class XPUSingleEncoderFuser : public FuseBase { ...@@ -451,6 +422,9 @@ class XPUSingleEncoderFuser : public FuseBase {
class XPUMultiEncoderFuser { class XPUMultiEncoderFuser {
public: public:
explicit XPUMultiEncoderFuser(const std::set<int>& fc_int31_ids)
: fc_int31_ids_(fc_int31_ids) {}
bool IsDirectPredecessorOf(Node* op1, Node* op2) { bool IsDirectPredecessorOf(Node* op1, Node* op2) {
for (auto* out : op1->outlinks) { for (auto* out : op1->outlinks) {
for (auto* in : op2->inlinks) { for (auto* in : op2->inlinks) {
...@@ -542,6 +516,8 @@ class XPUMultiEncoderFuser { ...@@ -542,6 +516,8 @@ class XPUMultiEncoderFuser {
op_desc.SetAttr<int>("n_layers", all_encoders.size()); op_desc.SetAttr<int>("n_layers", all_encoders.size());
op_desc.SetAttr<std::string>( op_desc.SetAttr<std::string>(
"act_type", first_encoder_op_info->GetAttr<std::string>("act_type")); "act_type", first_encoder_op_info->GetAttr<std::string>("act_type"));
op_desc.SetAttr<std::string>("precision",
(fc_int31_ids_.empty() ? "int16" : "int31"));
auto* scope = multi_encoder_stmt->op()->scope(); auto* scope = multi_encoder_stmt->op()->scope();
std::vector<float> fc_weight_max(arg_map["FCWeight"].size()); std::vector<float> fc_weight_max(arg_map["FCWeight"].size());
...@@ -553,7 +529,21 @@ class XPUMultiEncoderFuser { ...@@ -553,7 +529,21 @@ class XPUMultiEncoderFuser {
float* weight_on_host = weight_t->mutable_data<float>(); float* weight_on_host = weight_t->mutable_data<float>();
float max_f = float max_f =
paddle::lite::xpu::math::FindMaxAbs(weight_on_host, weight_len); paddle::lite::xpu::math::FindMaxAbs(weight_on_host, weight_len);
// i ranges from 0 to 6*encoder_num, so we need to do i%6 to get relative
// position in the encoder
if (fc_int31_ids_.find(i % 6) != fc_int31_ids_.end()) {
// FCs in encoder use int31
VLOG(3) << "Use FC-int31 in FC-" << i << ", " << i / 6 << "-" << i % 6;
std::unique_ptr<float[]> weight_trans_fp32(new float[weight_len]);
paddle::lite::xpu::math::Transpose(weight_on_host,
weight_trans_fp32.get(),
weight_dims[0],
weight_dims[1]);
memcpy(weight_on_host,
weight_trans_fp32.get(),
weight_len * sizeof(float));
} else {
std::unique_ptr<int16_t[]> weight_int16(new int16_t[weight_len]); std::unique_ptr<int16_t[]> weight_int16(new int16_t[weight_len]);
std::unique_ptr<int16_t[]> weight_trans_int16(new int16_t[weight_len]); std::unique_ptr<int16_t[]> weight_trans_int16(new int16_t[weight_len]);
paddle::lite::xpu::math::ConvertFP32ToInt16( paddle::lite::xpu::math::ConvertFP32ToInt16(
...@@ -565,6 +555,7 @@ class XPUMultiEncoderFuser { ...@@ -565,6 +555,7 @@ class XPUMultiEncoderFuser {
memcpy(weight_on_host, memcpy(weight_on_host,
weight_trans_int16.get(), weight_trans_int16.get(),
weight_len * sizeof(int16_t)); weight_len * sizeof(int16_t));
}
fc_weight_max[i] = max_f; fc_weight_max[i] = max_f;
} }
...@@ -631,6 +622,9 @@ class XPUMultiEncoderFuser { ...@@ -631,6 +622,9 @@ class XPUMultiEncoderFuser {
GraphSafeRemoveNodes(graph, to_remove2); GraphSafeRemoveNodes(graph, to_remove2);
} }
} }
private:
std::set<int> fc_int31_ids_;
}; };
} // namespace fusion } // namespace fusion
...@@ -641,15 +635,35 @@ class XPUMultiEncoderFusePass : public ProgramPass { ...@@ -641,15 +635,35 @@ class XPUMultiEncoderFusePass : public ProgramPass {
if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return; if (GetBoolFromEnv("XPU_ENABLE_XTCL")) return;
// TODO(miaotianxiang): backup graph, recover from failed match // TODO(miaotianxiang): backup graph, recover from failed match
std::vector<std::string> act_types{"gelu", "relu"}; std::vector<std::string> act_types{"gelu", "relu"};
std::set<int> fc_int31_ids;
#ifdef LITE_WITH_XPU
// TODO(miaotianxiang): core/mir/*_pass.cc are compiled anyway and need to
// access Context<kXPU>::_multi_encoder_precision, but this static member
// variable in class specialization defined in lite/core/context.cc
// is only compiled iff LITE_WITH_XPU==ON. To suppress linkage error, we use
// #ifdef here. Any better idea?
if (GetStringFromEnv("XPU_ENCODER_PRECISION", "int16") == "int31" ||
lite::Context<TargetType::kXPU>::_multi_encoder_precision == "int31") {
fc_int31_ids = {0, 1, 2, 3, 4, 5};
VLOG(3) << "Use int31 in XPUMultiEncoderOp, "
<< "lite::Context<>::_multi_encoder_precision="
<< lite::Context<TargetType::kXPU>::_multi_encoder_precision;
} else {
VLOG(3) << "Use int16 in XPUMultiEncoderOp, "
<< "lite::Context<>::_multi_encoder_precision="
<< lite::Context<TargetType::kXPU>::_multi_encoder_precision;
}
#endif
for (auto& act_type : act_types) { for (auto& act_type : act_types) {
fusion::XPUSingleEncoderFuser single_encoder_fuser(act_type); fusion::XPUSingleEncoderFuser single_encoder_fuser(act_type);
single_encoder_fuser(graph.get()); single_encoder_fuser(graph.get());
fusion::XPUMultiEncoderFuser multi_encoder_fuser; fusion::XPUMultiEncoderFuser multi_encoder_fuser(fc_int31_ids);
multi_encoder_fuser(graph.get()); multi_encoder_fuser(graph.get());
} }
} }
}; };
} // namespace mir } // namespace mir
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
......
...@@ -162,6 +162,12 @@ struct PMNode { ...@@ -162,6 +162,12 @@ struct PMNode {
attr_name, [=](const T& src) { return src == attr; }); attr_name, [=](const T& src) { return src == attr; });
} }
PMNode* assert_node_satisfied(
const std::function<bool(const Node*)>& condition) {
asserts_.push_back(condition);
return this;
}
private: private:
PMNode(PMPattern* pattern, PMNode(PMPattern* pattern,
const std::string& name = "", const std::string& name = "",
......
...@@ -76,12 +76,11 @@ class Optimizer { ...@@ -76,12 +76,11 @@ class Optimizer {
(defined LITE_WITH_ARM) (defined LITE_WITH_ARM)
"lite_elementwise_activation_fuse_pass", // "lite_elementwise_activation_fuse_pass", //
#endif #endif
"identity_dropout_eliminate_pass",
"__xpu__resnet_fuse_pass", "__xpu__resnet_fuse_pass",
"__xpu__multi_encoder_fuse_pass", "__xpu__multi_encoder_fuse_pass",
"__xpu__embedding_with_eltwise_add_fuse_pass", "__xpu__embedding_with_eltwise_add_fuse_pass",
"__xpu__fc_fuse_pass", "__xpu__fc_fuse_pass",
"identity_dropout_eliminate_pass", // should be placed after
// xpu fusion
"quantized_op_attributes_inference_pass", // Only for fully "quantized_op_attributes_inference_pass", // Only for fully
// quantized model, infer // quantized model, infer
// the output scale and // the output scale and
......
...@@ -46,7 +46,30 @@ void XPUMultiEncoderCompute::Run() { ...@@ -46,7 +46,30 @@ void XPUMultiEncoderCompute::Run() {
int batch_size = param.input->dims()[0]; int batch_size = param.input->dims()[0];
int seq_len = param.input->dims()[1]; int seq_len = param.input->dims()[1];
int r = xdnn::bert_encoder_transformer_int16<int16_t>( int r = -1;
if (param.precision == "int31") {
r = xdnn::bert_encoder_transformer_int31(
ctx.GetRawContext(), /* context */
batch_size, /* batch_size */
seq_len, /* from_seq_len */
seq_len, /* to_seq_len */
param.head_num, /* head_num */
param.size_per_head, /* size_per_head */
param.n_layers, /* n_layers */
param.input->data<float>(), /* from_tensor */
param.input->data<float>(), /* to_tensor */
param.mask->data<float>(), /* att_mask */
(const float**)(&arg_fc_weight_[0]), /* fc_weights */
&arg_fc_bias_[0], /* fc_biass */
&arg_ln_scale_[0], /* ln_scales */
&arg_ln_bias_[0], /* ln_biass */
param.output->mutable_data<float>(TARGET(kXPU)), /* output */
param.fc_weight_max->data<float>(), /* fc_weights_max */
true, /* pretrans_b */
true, /* use_l3 */
act_type_ /* act_type */);
} else {
r = xdnn::bert_encoder_transformer_int16<int16_t>(
ctx.GetRawContext(), /* context */ ctx.GetRawContext(), /* context */
batch_size, /* batch_size */ batch_size, /* batch_size */
seq_len, /* from_seq_len */ seq_len, /* from_seq_len */
...@@ -66,6 +89,7 @@ void XPUMultiEncoderCompute::Run() { ...@@ -66,6 +89,7 @@ void XPUMultiEncoderCompute::Run() {
true, /* pretrans_b */ true, /* pretrans_b */
true, /* use_l3 */ true, /* use_l3 */
act_type_ /* act_type */); act_type_ /* act_type */);
}
CHECK_EQ(r, 0); CHECK_EQ(r, 0);
} }
......
...@@ -68,6 +68,7 @@ bool XPUMultiEncoderOp::AttachImpl(const cpp::OpDesc& op_desc, ...@@ -68,6 +68,7 @@ bool XPUMultiEncoderOp::AttachImpl(const cpp::OpDesc& op_desc,
param_.head_num = op_desc.GetAttr<int>("head_num"); param_.head_num = op_desc.GetAttr<int>("head_num");
param_.size_per_head = op_desc.GetAttr<int>("size_per_head"); param_.size_per_head = op_desc.GetAttr<int>("size_per_head");
param_.act_type = op_desc.GetAttr<std::string>("act_type"); param_.act_type = op_desc.GetAttr<std::string>("act_type");
param_.precision = op_desc.GetAttr<std::string>("precision");
return true; return true;
} }
......
...@@ -1492,6 +1492,7 @@ struct XPUMultiEncoderParam : ParamBase { ...@@ -1492,6 +1492,7 @@ struct XPUMultiEncoderParam : ParamBase {
int head_num{}; int head_num{};
int size_per_head{}; int size_per_head{};
std::string act_type{}; std::string act_type{};
std::string precision{};
}; };
struct XPUEmbeddingWithEltwiseAddParam : ParamBase { struct XPUEmbeddingWithEltwiseAddParam : ParamBase {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册