From 4df009399ccd9d1d016c31abf91cf663d595737d Mon Sep 17 00:00:00 2001
From: zhupengyang
Date: Thu, 20 Jul 2023 13:39:39 +0800
Subject: [PATCH] [XPU] fuse cast to conv2d/fc in mixed precision model (#54493)

---
 paddle/fluid/framework/ir/CMakeLists.txt      |   6 +
 .../xpu/cast_mixed_precision_op_fuse_pass.cc  | 206 ++++++++++++++++++
 .../cast_mixed_precision_op_fuse_pass_test.cc |  71 ++++++
 .../framework/ir/xpu/conv2d_xpu_fuse_pass.cc  |  17 +-
 .../framework/ir/xpu/fc_xpu_fuse_pass.cc      |   9 +-
 .../inference/api/paddle_pass_builder.cc      |   4 +-
 paddle/phi/api/yaml/fused_ops.yaml            |   4 +-
 paddle/phi/infermeta/fusion.cc                |   7 +-
 paddle/phi/infermeta/fusion.h                 |   4 +-
 .../kernels/fusion/xpu/conv2d_xpu_kernel.cc   | 138 +++++++++---
 .../phi/kernels/fusion/xpu/fc_xpu_kernel.cc   | 127 +++++++----
 11 files changed, 502 insertions(+), 91 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass.cc
 create mode 100644 paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass_test.cc

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 4faa9cd2183..f4b88178499 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -236,6 +236,8 @@ if(WITH_XPU)
     SRCS xpu/pass_utils.cc
     DEPS pass xpu_quant_utils)
   set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
+  pass_library(cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
   pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(redundant_onnx_ops_elimination_pass inference DIR xpu DEPS
@@ -550,6 +552,10 @@ if(WITH_MKLDNN)
 endif()
 
 if(WITH_XPU)
+  cc_test(
+    test_cast_mixed_precision_op_fuse_pass
+    SRCS xpu/cast_mixed_precision_op_fuse_pass_test.cc
+    DEPS cast_mixed_precision_op_fuse_pass)
   cc_test(
     test_delete_isolated_node_pass
     SRCS xpu/delete_isolated_node_pass_test.cc
diff --git a/paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass.cc
new file mode 100644
index 00000000000..ef8759153b0
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass.cc
@@ -0,0 +1,206 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "glog/logging.h"
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+namespace patterns {
+struct CastBeforePattern : public PatternBase {
+  CastBeforePattern(PDPattern* pattern,
+                    const std::string& name_scope,
+                    const std::string& mixed_precision_op_type);
+
+  PATTERN_DECL_NODE(cast_in);
+  PATTERN_DECL_NODE(cast);
+  PATTERN_DECL_NODE(cast_out);
+  PATTERN_DECL_NODE(mixed_precision_op);
+};
+
+CastBeforePattern::CastBeforePattern(PDPattern* pattern,
+                                     const std::string& name_scope,
+                                     const std::string& mixed_precision_op_type)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* cast_in =
+      pattern->NewNode(cast_in_repr())->assert_is_op_input("cast", "X");
+  auto* cast = pattern->NewNode(cast_repr())
+                   ->assert_is_op("cast")
+                   ->assert_more([&](Node* node) {
+                     auto* op_desc = node->Op();
+                     return op_desc->GetAttrIfExists<int>("in_dtype") == 5 &&
+                            op_desc->GetAttrIfExists<int>("out_dtype") == 4;
+                   });
+  auto* cast_out = pattern->NewNode(cast_out_repr())
+                       ->assert_is_op_output("cast", "Out")
+                       ->assert_is_op_input(mixed_precision_op_type, "x")
+                       ->assert_has_n_outputs(1);
+  auto* mixed_precision_op = pattern->NewNode(mixed_precision_op_repr())
+                                 ->assert_is_op(mixed_precision_op_type);
+
+  cast->LinksFrom({cast_in}).LinksTo({cast_out});
+  mixed_precision_op->LinksFrom({cast_out});
+}
+
+struct CastAfterPattern : public PatternBase {
+  CastAfterPattern(PDPattern* pattern,
+                   const std::string& name_scope,
+                   const std::string& mixed_precision_op_type);
+
+  PATTERN_DECL_NODE(mixed_precision_op);
+  PATTERN_DECL_NODE(cast_in);
+  PATTERN_DECL_NODE(cast);
+  PATTERN_DECL_NODE(cast_out);
+};
+
+CastAfterPattern::CastAfterPattern(PDPattern* pattern,
+                                   const std::string& name_scope,
+                                   const std::string& mixed_precision_op_type)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* mixed_precision_op = pattern->NewNode(mixed_precision_op_repr())
+                                 ->assert_is_op(mixed_precision_op_type);
+  auto* cast_in = pattern->NewNode(cast_in_repr())
+                      ->assert_is_op_output(mixed_precision_op_type, "out")
+                      ->assert_is_op_input("cast", "X")
+                      ->assert_has_n_outputs(1);
+  auto* cast = pattern->NewNode(cast_repr())
+                   ->assert_is_op("cast")
+                   ->assert_more([&](Node* node) {
+                     auto* op_desc = node->Op();
+                     return op_desc->GetAttrIfExists<int>("in_dtype") == 4 &&
+                            op_desc->GetAttrIfExists<int>("out_dtype") == 5;
+                   });
+  auto* cast_out =
+      pattern->NewNode(cast_out_repr())->assert_is_op_output("cast", "Out");
+
+  mixed_precision_op->LinksTo({cast_in});
+  cast->LinksFrom({cast_in}).LinksTo({cast_out});
+}
+
+}  // namespace patterns
+
+class CastMixedPrecisionOpFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  int ApplyCastBeforePass(ir::Graph* graph,
+                          const std::string& mixed_precision_op_type) const;
+  int ApplyCastAfterPass(ir::Graph* graph,
+                         const std::string& mixed_precision_op_type) const;
+
+  const std::string name_scope_{"cast_mixed_precision_op_fuse_pass"};
+};
+
+int CastMixedPrecisionOpFusePass::ApplyCastBeforePass(
+    ir::Graph* graph, const std::string& mixed_precision_op_type) const {
+  GraphPatternDetector gpd;
+  patterns::CastBeforePattern pattern(
+      gpd.mutable_pattern(), name_scope_, mixed_precision_op_type);
+  int found_subgraph_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle ApplyCastBeforePass";
+    GET_IR_NODE(cast_in);
+    GET_IR_NODE(cast);
+    GET_IR_NODE(cast_out);
+    GET_IR_NODE(mixed_precision_op);
+
+    mixed_precision_op->Op()->RenameInput(cast_out->Name(), cast_in->Name());
+    IR_NODE_LINK_TO(cast_in, mixed_precision_op);
+
+    // delete useless nodes
+    std::unordered_set<const Node*> delete_nodes = {cast, cast_out};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  return found_subgraph_count;
+}
+
+int CastMixedPrecisionOpFusePass::ApplyCastAfterPass(
+    ir::Graph* graph, const std::string& mixed_precision_op_type) const {
+  GraphPatternDetector gpd;
+  patterns::CastAfterPattern pattern(
+      gpd.mutable_pattern(), name_scope_, mixed_precision_op_type);
+  int found_subgraph_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle ApplyCastAfterPass";
+    GET_IR_NODE(mixed_precision_op);
+    GET_IR_NODE(cast_in);
+    GET_IR_NODE(cast);
+    GET_IR_NODE(cast_out);
+
+    mixed_precision_op->Op()->RenameOutput(cast_in->Name(), cast_out->Name());
+    int out_dtype = proto::VarType::Type::VarType_Type_FP32;
+    mixed_precision_op->Op()->SetAttr("out_dtype", out_dtype);
+    IR_NODE_LINK_TO(mixed_precision_op, cast_out);
+
+    // delete useless nodes
+    std::unordered_set<const Node*> delete_nodes = {cast_in, cast};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  return found_subgraph_count;
+}
+
+void CastMixedPrecisionOpFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  Init(name_scope_, graph);
+
+  int count = 0;
+  for (auto op_type : {"conv2d_xpu", "fc_xpu"}) {
+    count += ApplyCastBeforePass(graph, op_type);
+    count += ApplyCastAfterPass(graph, op_type);
+  }
+  AddStatis(count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(cast_mixed_precision_op_fuse_pass,
+              paddle::framework::ir::CastMixedPrecisionOpFusePass);
+
+REGISTER_PASS_CAPABILITY(cast_mixed_precision_op_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "cast", 0));
diff --git a/paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass_test.cc b/paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass_test.cc
new file mode 100644
index 00000000000..855f09176f0
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/cast_mixed_precision_op_fuse_pass_test.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+TEST(CastMixedPrecisionOpFusePass, cast_before) {
+  Layers layers;
+  auto* block = layers.Block();
+
+  auto* cast_in = layers.data("cast_in");
+  auto* cast_out = layers.cast(cast_in, 5, 4);
+  OpDesc* conv2d_xpu = block->AppendOp();
+  conv2d_xpu->SetType("conv2d_xpu");
+  conv2d_xpu->SetInput("x", {cast_out->Name()});
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  auto pass = PassRegistry::Instance().Get("cast_mixed_precision_op_fuse_pass");
+  pass->Apply(graph.get());
+  auto num = GetNumOpNodes(graph, "cast");
+  PADDLE_ENFORCE_EQ(
+      num,
+      0,
+      platform::errors::PreconditionNotMet(
+          "cast op should be removed from graph, but graph still has %d ops.",
+          num));
+}
+
+TEST(CastMixedPrecisionOpFusePass, cast_after) {
+  Layers layers;
+  auto* block = layers.Block();
+
+  auto* cast_in = layers.data("cast_in");
+  OpDesc* conv2d_xpu = block->AppendOp();
+  conv2d_xpu->SetType("conv2d_xpu");
+  conv2d_xpu->SetOutput("out", {cast_in->Name()});
+  layers.cast(cast_in, 4, 5);
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  auto pass = PassRegistry::Instance().Get("cast_mixed_precision_op_fuse_pass");
+  pass->Apply(graph.get());
+  auto num = GetNumOpNodes(graph, "cast");
+  PADDLE_ENFORCE_EQ(
+      num,
+      0,
+      platform::errors::PreconditionNotMet(
+          "cast op should be removed from graph, but graph still has %d ops.",
+          num));
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(cast_mixed_precision_op_fuse_pass);
diff --git a/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc
index 40af4f8c000..f893da660d0 100644
--- a/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc
@@ -429,10 +429,13 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
     auto* filter_t =
         scope->FindVar(conv_filter->Name())->GetMutable<phi::DenseTensor>();
     // conv_filter fp16 --> fp32
-    auto tensor_type = filter_t->dtype();
-    if (tensor_type == phi::DataType::FLOAT16) {
+    auto filter_dtype = filter_t->dtype();
+    int out_dtype = proto::VarType::Type::VarType_Type_FP32;
+    if (filter_dtype == phi::DataType::FLOAT16) {
+      out_dtype = proto::VarType::Type::VarType_Type_FP16;
       CastToFp32(filter_t, nullptr);
     }
+
     auto filter_dims = filter_t->dims();
     bool has_bias = with_bn || with_conv_bias;
     // Create conv_fusion_bias (conv bias) variable
@@ -515,7 +518,6 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
     Node* filter_max = nullptr;
     PrepareWeight<int16_t>(
        graph, scope, block, conv_filter, &filter_int16, &filter_max, false);
-    bool has_branch = with_branch_x || with_branch_y;
     // output && output max
     std::string conv2d_xpu_out_name;
     if (!act_type.empty()) {
@@ -590,15 +592,8 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
     conv2d_xpu_op_desc.SetAttr(
         "strides",
         PADDLE_GET_CONST(std::vector<int>, conv->Op()->GetAttr("strides")));
-    conv2d_xpu_op_desc.SetAttr("conv_bias", conv_bias);
-    conv2d_xpu_op_desc.SetAttr("op_type", std::vector<int>{0});
-    conv2d_xpu_op_desc.SetAttr("place_x", std::vector<int>{0});
-    conv2d_xpu_op_desc.SetAttr("place_y", std::vector<int>{9});
-    conv2d_xpu_op_desc.SetAttr("place_z", std::vector<int>{10});
     conv2d_xpu_op_desc.SetAttr("paddings", conv_paddings);
-    conv2d_xpu_op_desc.SetAttr("block_lod", std::vector<int>{1});
-    conv2d_xpu_op_desc.SetAttr("has_branch", has_branch);
-
conv2d_xpu_op_desc.SetAttr("has_bias", has_bias); + conv2d_xpu_op_desc.SetAttr("out_dtype", out_dtype); auto* conv2d_xpu = graph->CreateOpNode(&conv2d_xpu_op_desc); IR_NODE_LINK_TO(input, conv2d_xpu); diff --git a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc index 18a573db0c7..59e7f4f7de5 100644 --- a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc @@ -313,9 +313,11 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph, auto* filter_t = scope->FindVar(mul_w->Name())->GetMutable(); - // filter fp16 --> fp32 - auto tensor_type = filter_t->dtype(); - if (tensor_type == phi::DataType::FLOAT16) { + // weight fp16 --> fp32 + auto filter_dtype = filter_t->dtype(); + int out_dtype = proto::VarType::Type::VarType_Type_FP32; + if (filter_dtype == phi::DataType::FLOAT16) { + out_dtype = proto::VarType::Type::VarType_Type_FP16; CastToFp32(filter_t, nullptr); } auto filter_dims = filter_t->dims(); @@ -435,6 +437,7 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph, "act_alpha", PADDLE_GET_CONST(float, act->Op()->GetAttr("slope"))); } } + fc_xpu_op_desc.SetAttr("out_dtype", out_dtype); fc_xpu_op_desc.SetOutput("out", {fc_out_name}); fc_xpu_op_desc.SetOutput("out_max", {fc_out_max_name}); auto* fc_xpu = graph->CreateOpNode(&fc_xpu_op_desc); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 958ee89af0c..c985eecd20d 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -544,8 +544,10 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "add_layernorm_xpu_fuse_pass", "yolo_box_xpu_fuse_pass", "link_xpu_op_max_pass", - "inplace_op_var_pass", "delete_isolated_node_pass", + // "auto_mixed_precision_pass", + "cast_mixed_precision_op_fuse_pass", + "inplace_op_var_pass", }); use_xpu_ = true; } diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 5ed1bf3576b..1ca4ba24332 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -34,7 +34,7 @@ optional : bias, x_max - op : conv2d_xpu - args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param) + args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, int act_type, float act_param, DataType out_dtype) output : Tensor(out), Tensor(out_max) infer_meta : func : Conv2dXPUInferMeta @@ -54,7 +54,7 @@ optional : mask, seq_lod, max_seq_len - op : fc_xpu - args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha) + args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha, DataType out_dtype) output : Tensor(out), Tensor(out_max) infer_meta : func : FcXPUInferMeta diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index d9b1b08aead..ecf37532f14 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -151,10 +151,9 @@ void Conv2dXPUInferMeta(const MetaTensor& x, 
const std::vector& strides, const std::string& padding_algorithm, int groups, - bool has_bias, - bool has_branch, int act_type, float act_param, + DataType out_dtype, MetaTensor* out, MetaTensor* out_max) { auto in_dims = x.dims(); @@ -264,6 +263,7 @@ void Conv2dXPUInferMeta(const MetaTensor& x, // set output and output max dims out->set_dims(DDim(out_shape.data(), out_shape.size())); out_max->set_dims(phi::make_ddim({6})); + out->set_dtype(out_dtype); } void EmbeddingWithEltwiseAddXPUInferMeta( @@ -302,6 +302,7 @@ void FcXPUInferMeta(const MetaTensor& x, float beta, int act_type, float act_alpha, + DataType out_dtype, MetaTensor* out, MetaTensor* out_max) { std::vector out_shape(in_num_col_dims + 1); @@ -310,7 +311,7 @@ void FcXPUInferMeta(const MetaTensor& x, } out_shape[in_num_col_dims] = w.dims()[0]; out->set_dims(DDim(out_shape.data(), out_shape.size())); - out->set_dtype(x.dtype()); + out->set_dtype(out_dtype); out->set_layout(x.layout()); out_max->set_dims(phi::make_ddim({6})); out_max->set_dtype(x.dtype()); diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 921b5b6a021..9f9fff0f36c 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -54,10 +54,9 @@ void Conv2dXPUInferMeta(const MetaTensor& x, const std::vector& strides, const std::string& padding_algorithm, int groups, - bool has_bias, - bool has_branch, int act_type, float act_param, + DataType out_dtype, MetaTensor* out, MetaTensor* out_max); @@ -80,6 +79,7 @@ void FcXPUInferMeta(const MetaTensor& x, float beta, int act_type, float act_alpha, + DataType out_dtype, MetaTensor* out, MetaTensor* out_max); diff --git a/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc index f82d9fdd9fd..43caa13698b 100644 --- a/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc @@ -19,27 +19,31 @@ namespace phi { namespace fusion { -template -void Conv2dXPUKernel(const Context& ctx, - const DenseTensor& x, - const paddle::optional& x_max, - const DenseTensor& filter, - const DenseTensor& filter_max, - const paddle::optional& bias, - const paddle::optional& branch, - const paddle::optional& branch_max, - const std::vector& paddings, - const std::vector& dilations, - const std::vector& strides, - const std::string& padding_algorithm, - int groups, - bool has_bias, - bool has_branch, - int act_type, - float act_param, - DenseTensor* out, - DenseTensor* out_max) { - using XPUType = typename XPUTypeTrait::Type; +template +void Conv2dXPUKernelImpl(const Context& ctx, + const DenseTensor& x, + const paddle::optional& x_max, + const DenseTensor& filter, + const DenseTensor& filter_max, + const paddle::optional& bias, + const paddle::optional& branch, + const paddle::optional& branch_max, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const std::string& padding_algorithm, + int groups, + int act_type, + float act_param, + DenseTensor* out, + DenseTensor* out_max) { + using XPUTypeX = typename XPUTypeTrait::Type; + using XPUTypeW = typename XPUTypeTrait::Type; + using XPUTypeOut = typename XPUTypeTrait::Type; auto input_dims = x.dims(); auto filter_dims = filter.dims(); // update paddings and dilations accoring to padding_algorithm @@ -63,30 +67,51 @@ void Conv2dXPUKernel(const Context& ctx, int win_h = static_cast(filter_dims[2]); int win_w = static_cast(filter_dims[3]); - auto* input_data = reinterpret_cast(x.data()); + auto* input_data = 
reinterpret_cast(x.data()); const float* input_max_data = x_max.get_ptr() == nullptr ? nullptr : x_max.get_ptr()->data(); - auto* branch_data = - branch.get_ptr() == nullptr - ? nullptr - : reinterpret_cast(branch.get_ptr()->data()); + auto* filter_data = reinterpret_cast(filter.data()); + auto* filter_max_data = filter_max.data(); + + const XPUTypeOut* branch_data = nullptr; + auto* branch_tensor = branch.get_ptr(); + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + if (branch_tensor != nullptr) { + if (branch_tensor->dtype() == out->dtype()) { + branch_data = + reinterpret_cast(branch_tensor->data()); + } else { + auto branch_data_temp = + RAII_GUARD.alloc_l3_or_gm(branch_tensor->numel()); + int r = xpu::cast( + ctx.x_context(), + reinterpret_cast(branch_tensor->data()), + branch_data_temp, + branch_tensor->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); + branch_data = branch_data_temp; + } + } const float* branch_max_data = branch_max.get_ptr() == nullptr ? nullptr : branch_max.get_ptr()->data(); const float* bias_data = bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data(); - auto* out_data = reinterpret_cast(ctx.template Alloc(out)); + auto* out_data = + reinterpret_cast(ctx.template Alloc(out)); + auto* out_max_data = ctx.template Alloc(out_max); xpu::Activation_t act(static_cast(act_type)); if (act_type == xpu::Activation_t::LEAKY_RELU) { act.leaky_alpha = act_param; } else if (act_type == xpu::Activation_t::HARD_SIGMOID) { act.hard_sigmoid_slope = act_param; } - int r = - xpu::conv2d_fusion( // TX/TW/TY/TGEMM + + int r = xpu:: + conv2d_fusion( // TX/TW/TY/TGEMM /* baidu::xpu::api::Context* ctx */ ctx.x_context(), /* const TX* input */ input_data, - /* const TW* filter */ filter.data(), + /* const TW* filter */ filter_data, /* TY* output */ out_data, /* int64_t n */ batch, /* int64_t ic */ in_c, @@ -99,8 +124,8 @@ void Conv2dXPUKernel(const Context& ctx, /* const std::vector& dilations */ dilations_vec, /* int64_t groups */ groups, /* const float* in_maxptr */ input_max_data, - /* const float* filter_maxptr */ filter_max.data(), - /* float* out_maxptr */ ctx.template Alloc(out_max), + /* const float* filter_maxptr */ filter_max_data, + /* float* out_maxptr */ out_max_data, /* bool is_nchw */ true, /* const float* bias */ bias_data, /* const TY* branch */ branch_data, @@ -110,6 +135,55 @@ void Conv2dXPUKernel(const Context& ctx, PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_xpu"); } +#define CONV2D_XPU_KERNEL_IMPL(x_dtype_, w_dtype_, out_dtype_, gemm_dtype_) \ + Conv2dXPUKernelImpl( \ + ctx, \ + x, \ + x_max, \ + filter, \ + filter_max, \ + bias, \ + branch, \ + branch_max, \ + paddings, \ + dilations, \ + strides, \ + padding_algorithm, \ + groups, \ + act_type, \ + act_param, \ + out, \ + out_max); + +template +void Conv2dXPUKernel(const Context& ctx, + const DenseTensor& x, + const paddle::optional& x_max, + const DenseTensor& filter, + const DenseTensor& filter_max, + const paddle::optional& bias, + const paddle::optional& branch, + const paddle::optional& branch_max, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const std::string& padding_algorithm, + int groups, + int act_type, + float act_param, + DataType out_dtype, + DenseTensor* out, + DenseTensor* out_max) { + if (out_dtype == DataType::FLOAT32) { + CONV2D_XPU_KERNEL_IMPL(T, int16_t, float, int16_t); + } else if (out_dtype == DataType::FLOAT16) { + CONV2D_XPU_KERNEL_IMPL(T, int16_t, dtype::float16, int16_t); + } else { + PADDLE_THROW(phi::errors::Unimplemented("Not support 
out_dtype is %s.", + DataTypeToString(out_dtype))); + } +} + } // namespace fusion } // namespace phi diff --git a/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc index 68715ca76e1..6a6721194e9 100644 --- a/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc @@ -18,32 +18,42 @@ namespace phi { namespace fusion { -template -void FcXPUKernel(const Context& ctx, - const DenseTensor& x, - const paddle::optional& x_max, - const DenseTensor& w, - const DenseTensor& w_max, - const paddle::optional& bias, - int in_num_col_dims, - bool transpose_x, - float alpha, - float beta, - int act_type, - float act_alpha, - DenseTensor* out, - DenseTensor* out_max) { - using XPUType = typename XPUTypeTrait::Type; +template +void FcXPUKernelImpl(const Context& ctx, + const DenseTensor& x, + const paddle::optional& x_max, + const DenseTensor& w, + const DenseTensor& w_max, + const paddle::optional& bias, + int in_num_col_dims, + bool transpose_x, + float alpha, + float beta, + int act_type, + float act_alpha, + DenseTensor* out, + DenseTensor* out_max) { + using XPUTypeX = typename XPUTypeTrait::Type; + using XPUTypeW = typename XPUTypeTrait::Type; + using XPUTypeOut = typename XPUTypeTrait::Type; auto in_mat_dims = flatten_to_2d(x.dims(), in_num_col_dims); int m = in_mat_dims[0]; int k = in_mat_dims[1]; int n = w.dims()[0]; - auto* x_data = reinterpret_cast(x.data()); + auto* x_data = reinterpret_cast(x.data()); const float* x_max_data = x_max.get_ptr() == nullptr ? nullptr : x_max.get_ptr()->data(); + auto* w_data = reinterpret_cast(w.data()); + auto* w_max_data = w_max.data(); const float* bias_data = bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data(); - auto* out_data = reinterpret_cast(ctx.template Alloc(out)); + auto* out_data = + reinterpret_cast(ctx.template Alloc(out)); + auto* out_max_data = ctx.template Alloc(out_max); xpu::Activation_t act(static_cast(act_type)); if (act_type == xpu::Activation_t::LEAKY_RELU) { act.leaky_alpha = act_alpha; @@ -51,29 +61,72 @@ void FcXPUKernel(const Context& ctx, act.hard_sigmoid_slope = act_alpha; } int r = - xpu::fc_fusion( // TX, TW. TY, TGEMM - ctx.x_context(), // ctx - x_data, // x - w.data(), // w - out_data, // y - m, // m - n, // n - k, // k - transpose_x, // x_trans - true, // w_trans - x_max_data, // x_maxptr - w_max.data(), // w_maxptr - ctx.template Alloc(out_max), // y_maxptr - transpose_x ? m : k, // ldx - k, // ldw - n, // ldy - alpha, // alpha - beta, // beta - bias_data, // bias + xpu::fc_fusion( // TX/TW/TY/TGEMM + ctx.x_context(), // ctx + x_data, // x + w_data, // w + out_data, // y + m, // m + n, // n + k, // k + transpose_x, // x_trans + true, // w_trans + x_max_data, // x_maxptr + w_max_data, // w_maxptr + out_max_data, // y_maxptr + transpose_x ? 
m : k,  // ldx
+          k,                                        // ldw
+          n,                                        // ldy
+          alpha,                                    // alpha
+          beta,                                     // beta
+          bias_data,                                // bias
           act);
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_xpu");
 }
 
+#define FC_XPU_KERNEL_IMPL(x_dtype_, w_dtype_, out_dtype_, gemm_dtype_) \
+  FcXPUKernelImpl(                                                      \
+      ctx,                                                              \
+      x,                                                                \
+      x_max,                                                            \
+      w,                                                                \
+      w_max,                                                            \
+      bias,                                                             \
+      in_num_col_dims,                                                  \
+      transpose_x,                                                      \
+      alpha,                                                            \
+      beta,                                                             \
+      act_type,                                                         \
+      act_alpha,                                                        \
+      out,                                                              \
+      out_max);
+
+template <typename T, typename Context>
+void FcXPUKernel(const Context& ctx,
+                 const DenseTensor& x,
+                 const paddle::optional<DenseTensor>& x_max,
+                 const DenseTensor& w,
+                 const DenseTensor& w_max,
+                 const paddle::optional<DenseTensor>& bias,
+                 int in_num_col_dims,
+                 bool transpose_x,
+                 float alpha,
+                 float beta,
+                 int act_type,
+                 float act_alpha,
+                 DataType out_dtype,
+                 DenseTensor* out,
+                 DenseTensor* out_max) {
+  if (out_dtype == DataType::FLOAT32) {
+    FC_XPU_KERNEL_IMPL(T, int16_t, float, int16_t);
+  } else if (out_dtype == DataType::FLOAT16) {
+    FC_XPU_KERNEL_IMPL(T, int16_t, dtype::float16, int16_t);
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented("Unsupported out_dtype: %s.",
+                                            DataTypeToString(out_dtype)));
+  }
+}
+
+}  // namespace fusion
+}  // namespace phi
-- 
GitLab