Unverified · Commit 4df00939 authored by zhupengyang, committed by GitHub

[XPU] fuse cast to conv2d/fc in mixed precision model (#54493)

Parent 4f307a7e
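
In effect, the new pass removes the float32<->float16 cast ops that sit next to conv2d_xpu / fc_xpu in a mixed-precision model and records the intended output dtype on the fused op instead. A simplified sketch of the two rewrites (illustration only, not part of the diff; the dtype codes follow proto::VarType, where FP16 = 4 and FP32 = 5):

    cast-before:  x(fp32) -> cast(in_dtype=5, out_dtype=4) -> conv2d_xpu / fc_xpu
              =>  x(fp32) -> conv2d_xpu / fc_xpu

    cast-after:   conv2d_xpu / fc_xpu -> y(fp16) -> cast(in_dtype=4, out_dtype=5) -> z(fp32)
              =>  conv2d_xpu / fc_xpu (out_dtype = FP32) -> z(fp32)
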
......@@ -236,6 +236,8 @@ if(WITH_XPU)
SRCS xpu/pass_utils.cc
DEPS pass xpu_quant_utils)
set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
pass_library(cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(redundant_onnx_ops_elimination_pass inference DIR xpu DEPS
......@@ -550,6 +552,10 @@ if(WITH_MKLDNN)
endif()
if(WITH_XPU)
cc_test(
test_cast_mixed_precision_op_fuse_pass
SRCS xpu/cast_mixed_precision_op_fuse_pass_test.cc
DEPS cast_mixed_precision_op_fuse_pass)
cc_test(
test_delete_isolated_node_pass
SRCS xpu/delete_isolated_node_pass_test.cc
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct CastBeforePattern : public PatternBase {
CastBeforePattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& mixed_precision_op_type);
PATTERN_DECL_NODE(cast_in);
PATTERN_DECL_NODE(cast);
PATTERN_DECL_NODE(cast_out);
PATTERN_DECL_NODE(mixed_precision_op);
};
CastBeforePattern::CastBeforePattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& mixed_precision_op_type)
: PatternBase(pattern, name_scope, name_scope) {
auto* cast_in =
pattern->NewNode(cast_in_repr())->assert_is_op_input("cast", "X");
auto* cast = pattern->NewNode(cast_repr())
->assert_is_op("cast")
->assert_more([&](Node* node) {
auto* op_desc = node->Op();
return op_desc->GetAttrIfExists<int>("in_dtype") == 5 &&
op_desc->GetAttrIfExists<int>("out_dtype") == 4;
});
auto* cast_out = pattern->NewNode(cast_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input(mixed_precision_op_type, "x")
->assert_has_n_outputs(1);
auto* mixed_precision_op = pattern->NewNode(mixed_precision_op_repr())
->assert_is_op(mixed_precision_op_type);
cast->LinksFrom({cast_in}).LinksTo({cast_out});
mixed_precision_op->LinksFrom({cast_out});
}
struct CastAfterPattern : public PatternBase {
CastAfterPattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& mixed_precision_op_type);
PATTERN_DECL_NODE(mixed_precision_op);
PATTERN_DECL_NODE(cast_in);
PATTERN_DECL_NODE(cast);
PATTERN_DECL_NODE(cast_out);
};
CastAfterPattern::CastAfterPattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& mixed_precision_op_type)
: PatternBase(pattern, name_scope, name_scope) {
auto* mixed_precision_op = pattern->NewNode(mixed_precision_op_repr())
->assert_is_op(mixed_precision_op_type);
auto* cast_in = pattern->NewNode(cast_in_repr())
->assert_is_op_output(mixed_precision_op_type, "out")
->assert_is_op_input("cast", "X")
->assert_has_n_outputs(1);
auto* cast = pattern->NewNode(cast_repr())
->assert_is_op("cast")
->assert_more([&](Node* node) {
auto* op_desc = node->Op();
return op_desc->GetAttrIfExists<int>("in_dtype") == 4 &&
op_desc->GetAttrIfExists<int>("out_dtype") == 5;
});
auto* cast_out =
pattern->NewNode(cast_out_repr())->assert_is_op_output("cast", "Out");
mixed_precision_op->LinksTo({cast_in});
cast->LinksFrom({cast_in}).LinksTo({cast_out});
}
} // namespace patterns
class CastMixedPrecisionOpFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
int ApplyCastBeforePass(ir::Graph* graph,
const std::string& mixed_precision_op_type) const;
int ApplyCastAfterPass(ir::Graph* graph,
const std::string& mixed_precision_op_type) const;
const std::string name_scope_{"cast_mixed_precision_op_fuse_pass"};
};
int CastMixedPrecisionOpFusePass::ApplyCastBeforePass(
ir::Graph* graph, const std::string& mixed_precision_op_type) const {
GraphPatternDetector gpd;
patterns::CastBeforePattern pattern(
gpd.mutable_pattern(), name_scope_, mixed_precision_op_type);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ApplyCastBeforePass";
GET_IR_NODE(cast_in);
GET_IR_NODE(cast);
GET_IR_NODE(cast_out);
GET_IR_NODE(mixed_precision_op);
mixed_precision_op->Op()->RenameInput(cast_out->Name(), cast_in->Name());
IR_NODE_LINK_TO(cast_in, mixed_precision_op);
// delete useless node
std::unordered_set<const Node*> delete_nodes = {cast, cast_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
int CastMixedPrecisionOpFusePass::ApplyCastAfterPass(
ir::Graph* graph, const std::string& mixed_precision_op_type) const {
GraphPatternDetector gpd;
patterns::CastAfterPattern pattern(
gpd.mutable_pattern(), name_scope_, mixed_precision_op_type);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ApplyCastAfterPass";
GET_IR_NODE(mixed_precision_op);
GET_IR_NODE(cast_in);
GET_IR_NODE(cast);
GET_IR_NODE(cast_out);
mixed_precision_op->Op()->RenameOutput(cast_in->Name(), cast_out->Name());
int out_dtype = proto::VarType::Type::VarType_Type_FP32;
mixed_precision_op->Op()->SetAttr("out_dtype", out_dtype);
IR_NODE_LINK_TO(mixed_precision_op, cast_out);
// delete useless node
std::unordered_set<const Node*> delete_nodes = {cast_in, cast};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void CastMixedPrecisionOpFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
int count = 0;
for (auto op_type : {"conv2d_xpu", "fc_xpu"}) {
count += ApplyCastBeforePass(graph, op_type);
count += ApplyCastAfterPass(graph, op_type);
}
AddStatis(count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(cast_mixed_precision_op_fuse_pass,
paddle::framework::ir::CastMixedPrecisionOpFusePass);
REGISTER_PASS_CAPABILITY(cast_mixed_precision_op_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"cast", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(CastMixedPrecisionOpFusePass, cast_before) {
Layers layers;
auto* block = layers.Block();
auto* cast_in = layers.data("cast_in");
auto* cast_out = layers.cast(cast_in, 5, 4);
OpDesc* conv2d_xpu = block->AppendOp();
conv2d_xpu->SetType("conv2d_xpu");
conv2d_xpu->SetInput("x", {cast_out->Name()});
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("cast_mixed_precision_op_fuse_pass");
pass->Apply(graph.get());
auto num = GetNumOpNodes(graph, "cast");
PADDLE_ENFORCE_EQ(
num,
0,
platform::errors::PreconditionNotMet(
"cast op should be removed from graph, but graph still has %d ops.",
num));
}
TEST(CastMixedPrecisionOpFusePass, cast_after) {
Layers layers;
auto* block = layers.Block();
auto* cast_in = layers.data("cast_in");
OpDesc* conv2d_xpu = block->AppendOp();
conv2d_xpu->SetType("conv2d_xpu");
conv2d_xpu->SetOutput("out", {cast_in->Name()});
layers.cast(cast_in, 4, 5);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("cast_mixed_precision_op_fuse_pass");
pass->Apply(graph.get());
auto num = GetNumOpNodes(graph, "cast");
PADDLE_ENFORCE_EQ(
num,
0,
platform::errors::PreconditionNotMet(
"cast op should be removed from graph, but graph still has %d ops.",
num));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(cast_mixed_precision_op_fuse_pass);
......@@ -429,10 +429,13 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
auto* filter_t =
scope->FindVar(conv_filter->Name())->GetMutable<phi::DenseTensor>();
// conv_filter fp16 --> fp32
auto tensor_type = filter_t->dtype();
if (tensor_type == phi::DataType::FLOAT16) {
auto filter_dtype = filter_t->dtype();
int out_dtype = proto::VarType::Type::VarType_Type_FP32;
if (filter_dtype == phi::DataType::FLOAT16) {
out_dtype = proto::VarType::Type::VarType_Type_FP16;
CastToFp32(filter_t, nullptr);
}
auto filter_dims = filter_t->dims();
bool has_bias = with_bn || with_conv_bias;
// Create conv_fusion_bias (conv bias) variable
......@@ -515,7 +518,6 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
Node* filter_max = nullptr;
PrepareWeight<int16_t>(
graph, scope, block, conv_filter, &filter_int16, &filter_max, false);
bool has_branch = with_branch_x || with_branch_y;
// output && output max
std::string conv2d_xpu_out_name;
if (!act_type.empty()) {
......@@ -590,15 +592,8 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
conv2d_xpu_op_desc.SetAttr(
"strides",
PADDLE_GET_CONST(std::vector<int>, conv->Op()->GetAttr("strides")));
conv2d_xpu_op_desc.SetAttr("conv_bias", conv_bias);
conv2d_xpu_op_desc.SetAttr("op_type", std::vector<int>{0});
conv2d_xpu_op_desc.SetAttr("place_x", std::vector<int>{0});
conv2d_xpu_op_desc.SetAttr("place_y", std::vector<int>{9});
conv2d_xpu_op_desc.SetAttr("place_z", std::vector<int>{10});
conv2d_xpu_op_desc.SetAttr("paddings", conv_paddings);
conv2d_xpu_op_desc.SetAttr("block_lod", std::vector<int>{1});
conv2d_xpu_op_desc.SetAttr("has_branch", has_branch);
conv2d_xpu_op_desc.SetAttr("has_bias", has_bias);
conv2d_xpu_op_desc.SetAttr("out_dtype", out_dtype);
auto* conv2d_xpu = graph->CreateOpNode(&conv2d_xpu_op_desc);
IR_NODE_LINK_TO(input, conv2d_xpu);
......
......@@ -313,9 +313,11 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph,
auto* filter_t =
scope->FindVar(mul_w->Name())->GetMutable<phi::DenseTensor>();
// filter fp16 --> fp32
auto tensor_type = filter_t->dtype();
if (tensor_type == phi::DataType::FLOAT16) {
// weight fp16 --> fp32
auto filter_dtype = filter_t->dtype();
int out_dtype = proto::VarType::Type::VarType_Type_FP32;
if (filter_dtype == phi::DataType::FLOAT16) {
out_dtype = proto::VarType::Type::VarType_Type_FP16;
CastToFp32(filter_t, nullptr);
}
auto filter_dims = filter_t->dims();
......@@ -435,6 +437,7 @@ int FcXPUFusePass::ApplyImpl(ir::Graph* graph,
"act_alpha", PADDLE_GET_CONST(float, act->Op()->GetAttr("slope")));
}
}
fc_xpu_op_desc.SetAttr("out_dtype", out_dtype);
fc_xpu_op_desc.SetOutput("out", {fc_out_name});
fc_xpu_op_desc.SetOutput("out_max", {fc_out_max_name});
auto* fc_xpu = graph->CreateOpNode(&fc_xpu_op_desc);
......
......@@ -544,8 +544,10 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"add_layernorm_xpu_fuse_pass",
"yolo_box_xpu_fuse_pass",
"link_xpu_op_max_pass",
"inplace_op_var_pass",
"delete_isolated_node_pass",
// "auto_mixed_precision_pass",
"cast_mixed_precision_op_fuse_pass",
"inplace_op_var_pass",
});
use_xpu_ = true;
}
......
......@@ -34,7 +34,7 @@
optional : bias, x_max
- op : conv2d_xpu
args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param)
args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, int act_type, float act_param, DataType out_dtype)
output : Tensor(out), Tensor(out_max)
infer_meta :
func : Conv2dXPUInferMeta
......@@ -54,7 +54,7 @@
optional : mask, seq_lod, max_seq_len
- op : fc_xpu
args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha)
args : (Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha, DataType out_dtype)
output : Tensor(out), Tensor(out_max)
infer_meta :
func : FcXPUInferMeta
......
......@@ -151,10 +151,9 @@ void Conv2dXPUInferMeta(const MetaTensor& x,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
bool has_bias,
bool has_branch,
int act_type,
float act_param,
DataType out_dtype,
MetaTensor* out,
MetaTensor* out_max) {
auto in_dims = x.dims();
......@@ -264,6 +263,7 @@ void Conv2dXPUInferMeta(const MetaTensor& x,
// set output and output max dims
out->set_dims(DDim(out_shape.data(), out_shape.size()));
out_max->set_dims(phi::make_ddim({6}));
out->set_dtype(out_dtype);
}
void EmbeddingWithEltwiseAddXPUInferMeta(
......@@ -302,6 +302,7 @@ void FcXPUInferMeta(const MetaTensor& x,
float beta,
int act_type,
float act_alpha,
DataType out_dtype,
MetaTensor* out,
MetaTensor* out_max) {
std::vector<int> out_shape(in_num_col_dims + 1);
......@@ -310,7 +311,7 @@ void FcXPUInferMeta(const MetaTensor& x,
}
out_shape[in_num_col_dims] = w.dims()[0];
out->set_dims(DDim(out_shape.data(), out_shape.size()));
out->set_dtype(x.dtype());
out->set_dtype(out_dtype);
out->set_layout(x.layout());
out_max->set_dims(phi::make_ddim({6}));
out_max->set_dtype(x.dtype());
......
......@@ -54,10 +54,9 @@ void Conv2dXPUInferMeta(const MetaTensor& x,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
bool has_bias,
bool has_branch,
int act_type,
float act_param,
DataType out_dtype,
MetaTensor* out,
MetaTensor* out_max);
......@@ -80,6 +79,7 @@ void FcXPUInferMeta(const MetaTensor& x,
float beta,
int act_type,
float act_alpha,
DataType out_dtype,
MetaTensor* out,
MetaTensor* out_max);
......
......@@ -19,27 +19,31 @@
namespace phi {
namespace fusion {
template <typename T, typename Context>
void Conv2dXPUKernel(const Context& ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& x_max,
const DenseTensor& filter,
const DenseTensor& filter_max,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& branch,
const paddle::optional<DenseTensor>& branch_max,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
bool has_bias,
bool has_branch,
int act_type,
float act_param,
DenseTensor* out,
DenseTensor* out_max) {
using XPUType = typename XPUTypeTrait<T>::Type;
template <typename T_X,
typename T_W,
typename T_OUT,
typename T_GEMM,
typename Context>
void Conv2dXPUKernelImpl(const Context& ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& x_max,
const DenseTensor& filter,
const DenseTensor& filter_max,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& branch,
const paddle::optional<DenseTensor>& branch_max,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
int act_type,
float act_param,
DenseTensor* out,
DenseTensor* out_max) {
using XPUTypeX = typename XPUTypeTrait<T_X>::Type;
using XPUTypeW = typename XPUTypeTrait<T_W>::Type;
using XPUTypeOut = typename XPUTypeTrait<T_OUT>::Type;
auto input_dims = x.dims();
auto filter_dims = filter.dims();
// update paddings and dilations according to padding_algorithm
......@@ -63,30 +67,51 @@ void Conv2dXPUKernel(const Context& ctx,
int win_h = static_cast<int>(filter_dims[2]);
int win_w = static_cast<int>(filter_dims[3]);
auto* input_data = reinterpret_cast<const XPUType*>(x.data<T>());
auto* input_data = reinterpret_cast<const XPUTypeX*>(x.data<T_X>());
const float* input_max_data =
x_max.get_ptr() == nullptr ? nullptr : x_max.get_ptr()->data<float>();
auto* branch_data =
branch.get_ptr() == nullptr
? nullptr
: reinterpret_cast<const XPUType*>(branch.get_ptr()->data<T>());
auto* filter_data = reinterpret_cast<const XPUTypeW*>(filter.data<T_W>());
auto* filter_max_data = filter_max.data<float>();
const XPUTypeOut* branch_data = nullptr;
auto* branch_tensor = branch.get_ptr();
xpu::ctx_guard RAII_GUARD(ctx.x_context());
if (branch_tensor != nullptr) {
if (branch_tensor->dtype() == out->dtype()) {
branch_data =
reinterpret_cast<const XPUTypeOut*>(branch_tensor->data<T_OUT>());
} else {
auto branch_data_temp =
RAII_GUARD.alloc_l3_or_gm<XPUTypeOut>(branch_tensor->numel());
int r = xpu::cast<XPUTypeX, XPUTypeOut>(
ctx.x_context(),
reinterpret_cast<const XPUTypeX*>(branch_tensor->data<T_X>()),
branch_data_temp,
branch_tensor->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
branch_data = branch_data_temp;
}
}
const float* branch_max_data = branch_max.get_ptr() == nullptr
? nullptr
: branch_max.get_ptr()->data<float>();
const float* bias_data =
bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>();
auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
auto* out_data =
reinterpret_cast<XPUTypeOut*>(ctx.template Alloc<T_OUT>(out));
auto* out_max_data = ctx.template Alloc<float>(out_max);
xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type));
if (act_type == xpu::Activation_t::LEAKY_RELU) {
act.leaky_alpha = act_param;
} else if (act_type == xpu::Activation_t::HARD_SIGMOID) {
act.hard_sigmoid_slope = act_param;
}
int r =
xpu::conv2d_fusion<XPUType, int16_t, XPUType, int16_t>( // TX/TW/TY/TGEMM
int r = xpu::
conv2d_fusion<XPUTypeX, XPUTypeW, XPUTypeOut, T_GEMM>( // TX/TW/TY/TGEMM
/* baidu::xpu::api::Context* ctx */ ctx.x_context(),
/* const TX* input */ input_data,
/* const TW* filter */ filter.data<int16_t>(),
/* const TW* filter */ filter_data,
/* TY* output */ out_data,
/* int64_t n */ batch,
/* int64_t ic */ in_c,
......@@ -99,8 +124,8 @@ void Conv2dXPUKernel(const Context& ctx,
/* const std::vector<int>& dilations */ dilations_vec,
/* int64_t groups */ groups,
/* const float* in_maxptr */ input_max_data,
/* const float* filter_maxptr */ filter_max.data<float>(),
/* float* out_maxptr */ ctx.template Alloc<float>(out_max),
/* const float* filter_maxptr */ filter_max_data,
/* float* out_maxptr */ out_max_data,
/* bool is_nchw */ true,
/* const float* bias */ bias_data,
/* const TY* branch */ branch_data,
......@@ -110,6 +135,55 @@ void Conv2dXPUKernel(const Context& ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_xpu");
}
#define CONV2D_XPU_KERNEL_IMPL(x_dtype_, w_dtype_, out_dtype_, gemm_dtype_) \
Conv2dXPUKernelImpl<x_dtype_, w_dtype_, out_dtype_, gemm_dtype_, Context>( \
ctx, \
x, \
x_max, \
filter, \
filter_max, \
bias, \
branch, \
branch_max, \
paddings, \
dilations, \
strides, \
padding_algorithm, \
groups, \
act_type, \
act_param, \
out, \
out_max);
template <typename T, typename Context>
void Conv2dXPUKernel(const Context& ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& x_max,
const DenseTensor& filter,
const DenseTensor& filter_max,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& branch,
const paddle::optional<DenseTensor>& branch_max,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::string& padding_algorithm,
int groups,
int act_type,
float act_param,
DataType out_dtype,
DenseTensor* out,
DenseTensor* out_max) {
if (out_dtype == DataType::FLOAT32) {
CONV2D_XPU_KERNEL_IMPL(T, int16_t, float, int16_t);
} else if (out_dtype == DataType::FLOAT16) {
CONV2D_XPU_KERNEL_IMPL(T, int16_t, dtype::float16, int16_t);
} else {
    PADDLE_THROW(phi::errors::Unimplemented("Unsupported out_dtype: %s.",
DataTypeToString(out_dtype)));
}
}
} // namespace fusion
} // namespace phi
......
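A worked example of the out_dtype dispatch above (illustrative only; not part of the diff):

    // x is float16, out_dtype == DataType::FLOAT32:
    //   CONV2D_XPU_KERNEL_IMPL(T, int16_t, float, int16_t)
    //   -> Conv2dXPUKernelImpl<phi::dtype::float16, int16_t, float, int16_t, Context>
    //   -> xpu::conv2d_fusion<XPUTypeX, XPUTypeW, XPUTypeOut, int16_t>
    //      with XPUTypeX = the XPU half type, XPUTypeW = int16_t, XPUTypeOut = float

The fc_xpu kernel below dispatches the same way via FC_XPU_KERNEL_IMPL.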
......@@ -18,32 +18,42 @@
namespace phi {
namespace fusion {
template <typename T, typename Context>
void FcXPUKernel(const Context& ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& x_max,
const DenseTensor& w,
const DenseTensor& w_max,
const paddle::optional<DenseTensor>& bias,
int in_num_col_dims,
bool transpose_x,
float alpha,
float beta,
int act_type,
float act_alpha,
DenseTensor* out,
DenseTensor* out_max) {
using XPUType = typename XPUTypeTrait<T>::Type;
template <typename T_X,
typename T_W,
typename T_OUT,
typename T_GEMM,
typename Context>
void FcXPUKernelImpl(const Context& ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& x_max,
const DenseTensor& w,
const DenseTensor& w_max,
const paddle::optional<DenseTensor>& bias,
int in_num_col_dims,
bool transpose_x,
float alpha,
float beta,
int act_type,
float act_alpha,
DenseTensor* out,
DenseTensor* out_max) {
using XPUTypeX = typename XPUTypeTrait<T_X>::Type;
using XPUTypeW = typename XPUTypeTrait<T_W>::Type;
using XPUTypeOut = typename XPUTypeTrait<T_OUT>::Type;
auto in_mat_dims = flatten_to_2d(x.dims(), in_num_col_dims);
int m = in_mat_dims[0];
int k = in_mat_dims[1];
int n = w.dims()[0];
auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
auto* x_data = reinterpret_cast<const XPUTypeX*>(x.data<T_X>());
const float* x_max_data =
x_max.get_ptr() == nullptr ? nullptr : x_max.get_ptr()->data<float>();
auto* w_data = reinterpret_cast<const XPUTypeW*>(w.data<T_W>());
auto* w_max_data = w_max.data<float>();
const float* bias_data =
bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>();
auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
auto* out_data =
reinterpret_cast<XPUTypeOut*>(ctx.template Alloc<T_OUT>(out));
auto* out_max_data = ctx.template Alloc<float>(out_max);
xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type));
if (act_type == xpu::Activation_t::LEAKY_RELU) {
act.leaky_alpha = act_alpha;
......@@ -51,29 +61,72 @@ void FcXPUKernel(const Context& ctx,
act.hard_sigmoid_slope = act_alpha;
}
int r =
xpu::fc_fusion<XPUType, int16_t, XPUType, int16_t>( // TX, TW. TY, TGEMM
ctx.x_context(), // ctx
x_data, // x
w.data<int16_t>(), // w
out_data, // y
m, // m
n, // n
k, // k
transpose_x, // x_trans
true, // w_trans
x_max_data, // x_maxptr
w_max.data<float>(), // w_maxptr
ctx.template Alloc<float>(out_max), // y_maxptr
transpose_x ? m : k, // ldx
k, // ldw
n, // ldy
alpha, // alpha
beta, // beta
bias_data, // bias
xpu::fc_fusion<XPUTypeX, XPUTypeW, XPUTypeOut, T_GEMM>( // TX/TW/TY/TGEMM
ctx.x_context(), // ctx
x_data, // x
w_data, // w
out_data, // y
m, // m
n, // n
k, // k
transpose_x, // x_trans
true, // w_trans
x_max_data, // x_maxptr
w_max_data, // w_maxptr
out_max_data, // y_maxptr
transpose_x ? m : k, // ldx
k, // ldw
n, // ldy
alpha, // alpha
beta, // beta
bias_data, // bias
act);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_xpu");
}
#define FC_XPU_KERNEL_IMPL(x_dtype_, w_dtype_, out_dtype_, gemm_dtype_) \
FcXPUKernelImpl<x_dtype_, w_dtype_, out_dtype_, gemm_dtype_>( \
ctx, \
x, \
x_max, \
w, \
w_max, \
bias, \
in_num_col_dims, \
transpose_x, \
alpha, \
beta, \
act_type, \
act_alpha, \
out, \
out_max);
template <typename T, typename Context>
void FcXPUKernel(const Context& ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& x_max,
const DenseTensor& w,
const DenseTensor& w_max,
const paddle::optional<DenseTensor>& bias,
int in_num_col_dims,
bool transpose_x,
float alpha,
float beta,
int act_type,
float act_alpha,
DataType out_dtype,
DenseTensor* out,
DenseTensor* out_max) {
if (out_dtype == DataType::FLOAT32) {
FC_XPU_KERNEL_IMPL(T, int16_t, float, int16_t);
} else if (out_dtype == DataType::FLOAT16) {
FC_XPU_KERNEL_IMPL(T, int16_t, dtype::float16, int16_t);
} else {
    PADDLE_THROW(phi::errors::Unimplemented("Unsupported out_dtype: %s.",
DataTypeToString(out_dtype)));
}
}
} // namespace fusion
} // namespace phi
......