diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 0c2a65878d665a078081857d43e4fb96458f0a5d..88b7c4ebcc86dc6653f77e028a342fb11c8e4fa9 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -226,6 +226,7 @@ if(WITH_XPU)
       ${XPU_PASS_DEPS})
   pass_library(multi_encoder_xpu_slice_fuse_pass inference DIR xpu)
   pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu)
+  pass_library(link_xpu_op_max_pass inference DIR xpu)
 endif()
 
 cc_library(
diff --git a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
index 22a7229b70eee257b304a22ff57981f702f5091a..54efd1ed897b78e1e43fdf9e12a872640deca83e 100644
--- a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
@@ -159,9 +159,9 @@ Fused subgraph:
   \     |    /      |
    \    |   /       |
     fc_xpu-----------
-      |
-      |
-   act_out
+      |   \
+      |    \
+  act_out  out_max
 */
 class FcXPUFusePass : public FusePassBase {
  protected:
@@ -185,6 +185,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
       for (auto act_type : {
                "relu",
                "gelu",
+               "tanh",
                "",
            }) {
         ApplyImpl(graph, mul_type, with_bias, act_type);
@@ -244,6 +245,18 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
     QuantWeight<int16_t>(mul_w_tensor, mul_w_max_tensor, !transpose_w);
   }
 
+  std::string fc_out_name;
+  if (act_out) {
+    fc_out_name = act_out->Name();
+  } else if (add_out) {
+    fc_out_name = add_out->Name();
+  } else {
+    fc_out_name = mul_out->Name();
+  }
+  std::string fc_out_max_name = fc_out_name + "_max";
+  VarDesc fc_out_max_desc(fc_out_max_name);
+  Node* fc_out_max = graph->CreateVarNode(&fc_out_max_desc);
+
   // Generate fc_xpu op
   framework::OpDesc fc_xpu_op_desc(block);
   fc_xpu_op_desc.SetType("fc_xpu");
@@ -282,25 +295,21 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
           "act_alpha", PADDLE_GET_CONST(float, act->Op()->GetAttr("slope")));
     }
   }
-  if (act_out) {
-    fc_xpu_op_desc.SetOutput("out", {act_out->Name()});
-  } else if (add_out) {
-    fc_xpu_op_desc.SetOutput("out", {add_out->Name()});
-  } else {
-    fc_xpu_op_desc.SetOutput("out", {mul_out->Name()});
-  }
+  fc_xpu_op_desc.SetOutput("out", {fc_out_name});
+  fc_xpu_op_desc.SetOutput("out_max", {fc_out_max_name});
   auto* fc_xpu = graph->CreateOpNode(&fc_xpu_op_desc);
-  SAFE_IR_NODE_LINK_TO(mul_x, fc_xpu);
-  SAFE_IR_NODE_LINK_TO(mul_w, fc_xpu);
-  SAFE_IR_NODE_LINK_TO(mul_w_max, fc_xpu);
+  IR_NODE_LINK_TO(mul_x, fc_xpu);
+  IR_NODE_LINK_TO(mul_w, fc_xpu);
+  IR_NODE_LINK_TO(mul_w_max, fc_xpu);
   SAFE_IR_NODE_LINK_TO(bias, fc_xpu);
   if (act_out) {
-    SAFE_IR_NODE_LINK_TO(fc_xpu, act_out);
+    IR_NODE_LINK_TO(fc_xpu, act_out);
   } else if (add_out) {
-    SAFE_IR_NODE_LINK_TO(fc_xpu, add_out);
+    IR_NODE_LINK_TO(fc_xpu, add_out);
   } else {
-    SAFE_IR_NODE_LINK_TO(fc_xpu, mul_out);
+    IR_NODE_LINK_TO(fc_xpu, mul_out);
   }
+  IR_NODE_LINK_TO(fc_xpu, fc_out_max);
 
   // delete useless node
   std::unordered_set<const Node*> delete_nodes;
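
Note (illustrative name): if the matched activation output is relu_0.tmp_0, the new block above creates a fresh non-persistable var relu_0.tmp_0_max, registers it as fc_xpu's out_max output, and links it into the graph; the pass below then wires that var into the next fused op's x_max input.
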
diff --git a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..86b6da2868714d51b2f75834f7bd5739f8eb0158
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc
@@ -0,0 +1,145 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
+#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+struct FusionXPUOpPattern : public PatternBase {
+  FusionXPUOpPattern(PDPattern* pattern,
+                     const std::string& name_scope,
+                     const std::string& op_type);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(fusion_op);
+  // declare variable node's name
+  PATTERN_DECL_NODE(out);
+  PATTERN_DECL_NODE(out_max);
+
+ private:
+  std::string op_type_;
+};
+
+FusionXPUOpPattern::FusionXPUOpPattern(PDPattern* pattern,
+                                       const std::string& name_scope,
+                                       const std::string& op_type)
+    : PatternBase(pattern, name_scope, name_scope), op_type_(op_type) {
+  auto* fusion_op = pattern->NewNode(fusion_op_repr())->assert_is_op(op_type_);
+  auto* out = pattern->NewNode(out_repr())
+                  ->assert_is_op_output(op_type_, "out")
+                  ->assert_var_not_persistable();
+  auto* out_max = pattern->NewNode(out_max_repr())
+                      ->assert_is_op_output(op_type_, "out_max")
+                      ->assert_var_not_persistable();
+  fusion_op->LinksTo({out, out_max});
+}
+
+}  // namespace patterns
+
+class LinkXPUOpMaxPass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  void ApplyImpl(ir::Graph* graph, const std::string& op_type) const;
+
+  const std::string name_scope_{"link_xpu_op_max_pass"};
+  // ops with x_max/out_max
+  std::set<std::string> op_types_{"fc_xpu", "conv2d_xpu"};
+};
+
+/*
+Origin subgraph:
+          fusion_xpu_op0
+            /       \
+            |       |
+           out0  out0_max
+            |
+            \
+          fusion_xpu_op1
+
+Fused subgraph:
+          fusion_xpu_op0
+            /       \
+            |       |
+           out0  out0_max
+            |       |
+            \      /
+          fusion_xpu_op1
+*/
+void LinkXPUOpMaxPass::ApplyImpl(ir::Graph* graph) const {
+  Init(name_scope_, graph);
+  for (auto op_type : op_types_) {
+    ApplyImpl(graph, op_type);
+  }
+}
+
+void LinkXPUOpMaxPass::ApplyImpl(ir::Graph* graph,
+                                 const std::string& op_type) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph,
+      platform::errors::PreconditionNotMet("graph should not be null."));
+  GraphPatternDetector gpd;
+  patterns::FusionXPUOpPattern pattern(
+      gpd.mutable_pattern(), name_scope_, op_type);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle LinkXPUOpMaxPass fuse";
+    GET_IR_NODE(fusion_op);
+    GET_IR_NODE(out);
+    GET_IR_NODE(out_max);
+    for (auto next_op : out->outputs) {
+      auto* next_op_desc = next_op->Op();
+      if (op_types_.count(next_op_desc->Type()) == 0) continue;
+      next_op_desc->SetInput("x_max", {out_max->Name()});
+      IR_NODE_LINK_TO(out_max, next_op);
+      found_subgraph_count++;
+    }
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(link_xpu_op_max_pass, paddle::framework::ir::LinkXPUOpMaxPass);
+
+REGISTER_PASS_CAPABILITY(link_xpu_op_max_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "fc_xpu", 0));
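
For reference, a minimal sketch of how the new pass ends up running in practice. The model path is a placeholder and EnableXpu's arguments are left at their defaults; because link_xpu_op_max_pass is appended to XpuPassStrategy below, enabling XPU inference is enough to trigger it:

    #include <memory>

    #include "paddle_inference_api.h"

    int main() {
      // Hypothetical model directory; any XPU-enabled config exercises the pass.
      paddle_infer::Config config("./model_dir");
      // XpuPassStrategy (see paddle_pass_builder.cc below) now runs
      // fc_xpu_fuse_pass followed by link_xpu_op_max_pass.
      config.EnableXpu();
      std::shared_ptr<paddle_infer::Predictor> predictor =
          paddle_infer::CreatePredictor(config);
      return 0;
    }
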
diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
index e1b2f999d7daae9de9dd9f936a9368bbed1048cf..687e3581c5e47e862201e2de0b66dac1a6e21dd6 100644
--- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
@@ -224,27 +224,22 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToXpu(Argument *argument) {
   LOG(INFO) << "Sync params from CPU to XPU: "
             << "xpu_device_id - " << argument->xpu_device_id();
 
-  platform::Place place = platform::XPUPlace(argument->xpu_device_id());
+  platform::CPUPlace cpu_place;
+  platform::Place xpu_place = platform::XPUPlace(argument->xpu_device_id());
   auto *scope = argument->scope_ptr();
-  std::vector<std::string> all_vars = scope->LocalVarNames();
-
-  for (auto &var_name : all_vars) {
-    auto *var = scope->FindLocalVar(var_name);
-    PADDLE_ENFORCE_NOT_NULL(
-        var,
-        platform::errors::PreconditionNotMet("The var should not be nullptr"));
-
-    if (var->IsType<phi::DenseTensor>()) {
-      auto *t = var->GetMutable<phi::DenseTensor>();
-
-      platform::CPUPlace cpu_place;
-      phi::DenseTensor temp_tensor;
-      temp_tensor.Resize(t->dims());
-
-      paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
-      t->clear();
-      paddle::framework::TensorCopySync(temp_tensor, place, t);
-    }
+  framework::ir::Graph &graph = argument->main_graph();
+
+  for (auto *node : graph.Nodes()) {
+    if (!node->IsVar() || !node->Var()->Persistable()) continue;
+    auto *var = scope->FindVar(node->Name());
+    if (!var->IsType<phi::DenseTensor>()) continue;
+    auto *tensor = var->GetMutable<phi::DenseTensor>();
+
+    phi::DenseTensor temp_tensor;
+    temp_tensor.Resize(tensor->dims());
+    paddle::framework::TensorCopySync(*tensor, cpu_place, &temp_tensor);
+    tensor->clear();
+    paddle::framework::TensorCopySync(temp_tensor, xpu_place, tensor);
   }
 }
 #endif
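
Note on the rewrite above: the pass now enumerates persistable vars through the main graph rather than every scope var, so only initialized weights are round-tripped through the CPU; non-persistable vars, which may hold no data at this point (such as the freshly created *_max activation outputs), are skipped.
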
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 2b0e8e430d0f854acc57a52898087840c6c26b95..8f193fc8203f8208d946538e88ac54f3db1aec4c 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -522,7 +522,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "multi_encoder_xpu_slice_fuse_pass",
       // "embedding_with_eltwise_add_xpu_fuse_pass",
       "fc_xpu_fuse_pass",
-      // "link_previous_out_max_xpu_pass",
+      "link_xpu_op_max_pass",
   });
   use_xpu_ = true;
 }
diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml
index 9c9e5c0c8b27db1fbb1d1ada42ba3c90ae72fef7..7ba94a9f3da7df25331a0e1c0b3f190d8cf1baf9 100644
--- a/paddle/phi/api/yaml/static_ops.yaml
+++ b/paddle/phi/api/yaml/static_ops.yaml
@@ -1,12 +1,12 @@
 - op : fc_xpu
-  args : (Tensor x, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha)
-  output : Tensor
+  args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha)
+  output : Tensor(out), Tensor(out_max)
   infer_meta :
     func : FcXPUInferMeta
   kernel :
     func : fc_xpu
     data_type : x
-  optional : bias
+  optional : bias, x_max
 
 - op : generate_sequence_xpu
   args : (Tensor x, DataType dtype)
diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc
index 699fcc830ce5751e0d7e5f177432e3bdc8adb0a3..d468b13e17d2ae0907469c2e358bc6805ab20092 100644
--- a/paddle/phi/infermeta/fusion.cc
+++ b/paddle/phi/infermeta/fusion.cc
@@ -22,6 +22,7 @@ limitations under the License. */
 namespace phi {
 
 void FcXPUInferMeta(const MetaTensor& x,
+                    const MetaTensor& x_max,
                     const MetaTensor& w,
                     const MetaTensor& w_max,
                     const MetaTensor& bias,
@@ -31,7 +32,8 @@ void FcXPUInferMeta(const MetaTensor& x,
                     float beta,
                     int act_type,
                     float act_alpha,
-                    MetaTensor* out) {
+                    MetaTensor* out,
+                    MetaTensor* out_max) {
   std::vector<int> out_shape(in_num_col_dims + 1);
   for (int i = 0; i < in_num_col_dims; i++) {
     out_shape[i] = x.dims()[i];
@@ -40,6 +42,9 @@ void FcXPUInferMeta(const MetaTensor& x,
   out->set_dims(DDim(out_shape.data(), out_shape.size()));
   out->set_dtype(x.dtype());
   out->set_layout(x.layout());
+  out_max->set_dims(w_max.dims());
+  out_max->set_dtype(x.dtype());
+  out_max->set_layout(x.layout());
 }
 
 void GenerateSequenceXPUInferMeta(const MetaTensor& x,
diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h
index 7848bb40f001a3a370703ec8366d5bb14a10b7b0..e1fe0c3c112a97ecd06c1710aaa2cf4171f7d4e0 100644
--- a/paddle/phi/infermeta/fusion.h
+++ b/paddle/phi/infermeta/fusion.h
@@ -23,6 +23,7 @@ namespace phi {
 // NOTE: The InferMeta Functions in this file are arranged in alphabetic order.
 
 void FcXPUInferMeta(const MetaTensor& x,
+                    const MetaTensor& x_max,
                     const MetaTensor& w,
                     const MetaTensor& w_max,
                     const MetaTensor& bias,
@@ -32,7 +33,8 @@ void FcXPUInferMeta(const MetaTensor& x,
                     float beta,
                     int act_type,
                     float act_alpha,
-                    MetaTensor* out);
+                    MetaTensor* out,
+                    MetaTensor* out_max);
 
 void GenerateSequenceXPUInferMeta(const MetaTensor& x,
                                   DataType dtype,
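
A worked example of the updated FcXPUInferMeta contract, with illustrative values: for x of dims [2, 5, 64], in_num_col_dims = 2 and w of dims [128, 64], out gets dims [2, 5, 128] (the leading in_num_col_dims dims of x followed by n = w.dims()[0]), while out_max simply inherits w_max's dims, i.e. the small float buffer XPU kernels use to carry per-tensor max values.
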
diff --git a/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc
index fabf0bec6d928d888d710cbc1ffd2735aed77a08..f0f784f324b25bbce636280882c5842b75f85b53 100644
--- a/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc
+++ b/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc
@@ -21,6 +21,7 @@ namespace fusion {
 template <typename T, typename Context>
 void FcXPUKernel(const Context& ctx,
                  const DenseTensor& x,
+                 const paddle::optional<DenseTensor>& x_max,
                  const DenseTensor& w,
                  const DenseTensor& w_max,
                  const paddle::optional<DenseTensor>& bias,
@@ -30,33 +31,35 @@ void FcXPUKernel(const Context& ctx,
                  float beta,
                  int act_type,
                  float act_alpha,
-                 DenseTensor* out) {
+                 DenseTensor* out,
+                 DenseTensor* out_max) {
   auto in_mat_dims = flatten_to_2d(x.dims(), in_num_col_dims);
   int m = in_mat_dims[0];
   int k = in_mat_dims[1];
   int n = w.dims()[0];
+  const float* x_max_data =
+      x_max.get_ptr() == nullptr ? nullptr : x_max.get_ptr()->data<float>();
   const float* bias_data =
-      bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>();
+      bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>();
   xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type));
   if (act_type == 5) {
     act.leaky_alpha = act_alpha;
   } else if (act_type == 15) {
     act.hard_sigmoid_slope = act_alpha;
   }
-  ctx.template Alloc<T>(out);
   int r = xpu::fc_fusion<T, int16_t, T, int16_t>(  // TX, TW, TY, TGEMM
       ctx.x_context(),                             // ctx
       x.data<T>(),                                 // x
       w.data<int16_t>(),                           // w
-      out->data<T>(),                              // y
+      ctx.template Alloc<T>(out),                  // y
       m,                                           // m
       n,                                           // n
       k,                                           // k
       transpose_x,                                 // x_trans
       true,                                        // w_trans
-      nullptr,                                     // x_maxptr
+      x_max_data,                                  // x_maxptr
       w_max.data<float>(),                         // w_maxptr
-      nullptr,                                     // y_maxptr
+      ctx.template Alloc<float>(out_max),          // y_maxptr
       transpose_x ? m : k,                         // ldx
       k,                                           // ldw
       n,                                           // ldy
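
In short, the kernel now forwards x_max into fc_fusion's x_maxptr slot whenever the optional input is wired in by link_xpu_op_max_pass (assuming XDNN's usual max-pointer semantics, this spares the device from recomputing the input max), and it always allocates and fills out_max via y_maxptr so the next fused op can consume it the same way.
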
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1a19b7b716af6b0c6e162bfa1c01b9a1dfe3476
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import hypothesis.strategies as st
+from auto_scan_test import PassAutoScanTest
+from program_config import OpConfig, ProgramConfig, TensorConfig
+
+
+class TestLinkXPUOpMaxPass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_xpu=True)
+        yield config, ["fc_xpu", "fc_xpu"], (1e-3, 1e-3)
+
+    def sample_program_config(self, draw):
+        # 1. matmul0
+        matmul0_x_shape = draw(
+            st.lists(
+                st.integers(min_value=1, max_value=4), min_size=2, max_size=4
+            )
+        )
+        matmul0_y_shape = draw(
+            st.lists(
+                st.integers(min_value=1, max_value=8), min_size=2, max_size=2
+            )
+        )
+        matmul0_y_shape[0] = matmul0_x_shape[-1]
+        # 2. add0
+        add0_bias_shape = [matmul0_y_shape[1]]
+        # 3. matmul1
+        matmul1_y_shape = draw(
+            st.lists(
+                st.integers(min_value=1, max_value=8), min_size=2, max_size=2
+            )
+        )
+        matmul1_y_shape[0] = matmul0_y_shape[-1]
+        # 4. add1
+        add1_bias_shape = [matmul1_y_shape[1]]
+
+        matmul0_op = OpConfig(
+            "matmul_v2",
+            inputs={"X": ["matmul0_x"], "Y": ["matmul0_y"]},
+            outputs={"Out": ["matmul0_out"]},
+            trans_x=False,
+            trans_y=False,
+        )
+        add0_op = OpConfig(
+            "elementwise_add",
+            inputs={"X": ["matmul0_out"], "Y": ["add0_bias"]},
+            outputs={"Out": ["add0_out"]},
+            axis=-1,
+        )
+        matmul1_op = OpConfig(
+            "matmul_v2",
+            inputs={"X": ["add0_out"], "Y": ["matmul1_y"]},
+            outputs={"Out": ["matmul1_out"]},
+            trans_x=False,
+            trans_y=False,
+        )
+        add1_op = OpConfig(
+            "elementwise_add",
+            inputs={"X": ["matmul1_out"], "Y": ["add1_bias"]},
+            outputs={"Out": ["add1_out"]},
+            axis=-1,
+        )
+        ops = [matmul0_op, add0_op, matmul1_op, add1_op]
+
+        program_config = ProgramConfig(
+            ops=ops,
+            weights={
+                "matmul0_y": TensorConfig(shape=matmul0_y_shape),
+                "add0_bias": TensorConfig(shape=add0_bias_shape),
+                "matmul1_y": TensorConfig(shape=matmul1_y_shape),
+                "add1_bias": TensorConfig(shape=add1_bias_shape),
+            },
+            inputs={
+                "matmul0_x": TensorConfig(shape=matmul0_x_shape),
+            },
+            outputs=ops[-1].outputs["Out"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=25,
+            passes=["fc_xpu_fuse_pass", "link_xpu_op_max_pass"],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
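
As a usage note: under the auto-scan harness's usual semantics, the test above builds a matmul-add-matmul-add chain, runs fc_xpu_fuse_pass followed by link_xpu_op_max_pass, and asserts that the optimized program contains exactly the two fc_xpu ops listed in sample_predictor_configs, with XPU outputs matching the baseline within the (1e-3, 1e-3) tolerances.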