Commit 9443bc19 authored by D DesmonDay

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into support-0D-sort

......@@ -195,6 +195,7 @@ function(create_dummy_static_lib TARGET_NAME)
# the dummy target would consist of limited-size libraries
set(limit ${merge_LIMIT})
list(LENGTH merge_LIBS libs_len)
message("libs_len ${libs_len}")
foreach(lib ${merge_LIBS})
list(APPEND merge_list ${lib})
list(LENGTH merge_list listlen)
......
......@@ -739,6 +739,14 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.backward_returns_list,
) = ParseYamlBackward(backward_args_str, backward_returns_str)
        # Remove outputs that are marked as intermediate
if 'intermediate' in grad_api_contents:
backward_returns_list_new = []
for return_item in self.backward_returns_list:
if return_item[0] not in grad_api_contents['intermediate']:
backward_returns_list_new.append(return_item)
self.backward_returns_list = backward_returns_list_new
def CollectForwardInfoFromBackwardContents(self):
backward_forward_str = self.backward_forward_str
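A hedged worked example of the intermediate filter added above; the entry layout below is illustrative only, not the exact schema produced by ParseYamlBackward:

```python
# Illustrative entries: real parsed returns may carry more fields per item.
grad_api_contents = {'intermediate': 'xshape'}
backward_returns_list = [('x_grad', 'Tensor'), ('xshape', 'Tensor')]

# Keep only returns whose name is not listed as intermediate, as the new code does.
filtered = [ret for ret in backward_returns_list
            if ret[0] not in grad_api_contents['intermediate']]
print(filtered)  # [('x_grad', 'Tensor')]
```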
......@@ -1979,7 +1987,6 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
fill_zero_str += f"{indent}egr::EagerUtils::FillZeroForEmptyGradInput(&grads[{fwd_position}], input_metas[{fwd_position}]);\n"
inplace_grad_input_str = ""
inplaced_tensor_wrapper = False
inplace_check_str = ""
optional_inplace_var_name = []
# Grad Ins from TensorWrappers
......
......@@ -105,6 +105,7 @@ pass_library(delete_fill_constant_op_pass inference)
pass_library(constant_folding_pass inference)
pass_library(auto_mixed_precision_pass inference)
pass_library(conv2d_fusion_layout_transfer_pass inference)
pass_library(silu_fuse_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base)
......@@ -429,10 +430,6 @@ if(WITH_MKLDNN)
test_conv_batch_norm_mkldnn_fuse_pass
SRCS mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc
DEPS ${TEST_CONV_BN_PASS_DEPS})
cc_test(
test_scale_matmul_fuse_pass
SRCS mkldnn/scale_matmul_fuse_pass_tester.cc
DEPS scale_matmul_fuse_pass)
cc_test(
test_mkldnn_placement_pass
SRCS mkldnn/mkldnn_placement_pass_tester.cc
......
......@@ -143,10 +143,16 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
static_cast<phi::DataType>(Get<int>("model_precision")) ==
phi::DataType::FLOAT16 ||
Get<bool>("enable_gpu_mixed");
bool cutlass_enable = false;
bool cutlass_enable = Get<bool>("use_cutlass");
#ifdef PADDLE_WITH_CUTLASS
cutlass_enable = true;
const auto &prop = platform::GetDeviceProperties(Get<int>("gpu_device_id"));
int sm_version = prop.major * 10 + prop.minor;
// Now we only implement the cutlass kernel on SM75.
if (sm_version != 75) {
  cutlass_enable = false;
}
#endif
if (!(is_fp16_precision && cutlass_enable)) return;
......@@ -184,10 +190,21 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
auto filter_names = op_node->Op()->Input("Filter");
auto act_type = op_node->Op()->GetAttrIfExists<std::string>("activation");
constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
std::unordered_set<std::string> cutlass_act_set = {
// conv2d_fusion has two forms: conv + bias + act, and conv + bias +
// elementwise_add + act.
std::unordered_set<std::string> cutlass_cba_act_set = {
"relu", "swish", "identity", "leaky_relu"};
if (!cutlass_act_set.count(act_type)) {
return false;
std::unordered_set<std::string> cutlass_cbaa_act_set = {"relu"};
bool is_residual = op_node->Op()->Input("ResidualData").size() >= 1UL;
if (is_residual) {
if (!cutlass_cbaa_act_set.count(act_type)) {
return false;
}
} else {
if (!cutlass_cba_act_set.count(act_type)) {
return false;
}
}
// If the filter's channel count is not a multiple of 8, conv2d_fusion does not run in NHWC.
......
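A small plain-Python sketch (names hypothetical) of the activation gating described in the comments above, i.e. which conv2d_fusion forms the cutlass path accepts:

```python
# Mirrors cutlass_cba_act_set / cutlass_cbaa_act_set in the pass above.
CBA_ACTS = {"relu", "swish", "identity", "leaky_relu"}  # conv + bias + act
CBAA_ACTS = {"relu"}                                    # conv + bias + elementwise_add + act

def cutlass_supports(act_type, has_residual):
    # A residual (elementwise_add) input restricts the allowed activations.
    return act_type in (CBAA_ACTS if has_residual else CBA_ACTS)

assert cutlass_supports("swish", has_residual=False)
assert not cutlass_supports("swish", has_residual=True)
```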
......@@ -32,7 +32,11 @@ void AddVarToScope(Scope* param_scope,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
auto* data = tensor->mutable_data<float>(platform::CPUPlace());
int64_t numel = tensor->numel();
for (int64_t i = 0; i < numel; ++i) {
data[i] = 0;
}
}
Scope* CreateParamScope() {
......
......@@ -167,14 +167,19 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
phi::DataType::FLOAT16 ||
Get<bool>("enable_gpu_mixed");
constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
if (is_fp16_precision) {
bool cutlass_enable = Get<bool>("use_cutlass");
if (is_fp16_precision && cutlass_enable) {
#ifdef PADDLE_WITH_CUTLASS
// cutlass now support these activations
// cutlass_act_set.insert("swish");
// cutlass_act_set.insert("relu");
// cutlass_act_set.insert("identity");
// cutlass_act_set.insert("leaky_relu");
const auto& prop = platform::GetDeviceProperties(Get<int>("gpu_device_id"));
int sm_version = prop.major * 10 + prop.minor;
// Now we only implement the cutlass kernel on SM75.
if (sm_version == 75) {
// Cutlass now supports these cba activations.
cutlass_act_set.insert("swish");
cutlass_act_set.insert("relu");
cutlass_act_set.insert("identity");
cutlass_act_set.insert("leaky_relu");
}
all_act_set.insert(cutlass_act_set.begin(), cutlass_act_set.end());
#endif
}
......@@ -198,8 +203,8 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
auto* filter_var = scope->FindLocalVar(conv_filter->Name());
auto* filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
CHECK_EQ(filter_tensor->dims().size() == 4UL, true);
// when this conv2d_fusion problem size is not supported by cutlass and not
// supported by cuDNN, we should not apply this pass
// When this conv2d_fusion problem size is not supported by cutlass and not
// supported by cuDNN, we should not apply this pass.
int oc = filter_tensor->dims()[0];
int ic = filter_tensor->dims()[1];
bool cutlass_can_fuse = oc % CUTLASS_NHWC_ALIGNMENT == 0 &&
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/scale_matmul_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog,
const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
float scale = 1.0f,
float bias = 0.0f) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (type == "scale") {
op->SetInput("X", {inputs[0]});
op->SetAttr("scale", scale);
op->SetAttr("bias", bias);
} else if (type == "matmul") {
op->SetAttr("transpose_X", false);
op->SetAttr("transpose_Y", false);
op->SetInput("X", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
op->SetAttr("alpha", scale);
} else {
FAIL() << "Unexpected operator type.";
}
op->SetOutput("Out", {outputs[0]});
}
// a->scale->b
// (b,c)->matmul->d
ProgramDesc BuildProgramDesc(float scale, float bias, float alpha) {
ProgramDesc prog;
for (auto& v : std::vector<std::string>({"a", "b", "c", "d"})) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "scale", {"a"}, {"b"}, scale, bias);
SetOp(&prog, "matmul", {"b", "c"}, {"d"}, alpha);
return prog;
}
void MainTest(const ProgramDesc& prog,
int removed_nodes_count,
const std::vector<std::string> scale_in_out,
const std::vector<std::string> matmul_in_out,
float alpha) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num = graph->Nodes().size();
auto pass = PassRegistry::Instance().Get("scale_matmul_fuse_pass");
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == "scale") {
EXPECT_EQ(op->Input("X")[0], scale_in_out[0]);
EXPECT_EQ(op->Output("Out")[0], scale_in_out[1]);
} else if (op->Type() == "matmul") {
EXPECT_EQ(op->Input("X")[0], matmul_in_out[0]);
EXPECT_EQ(op->Input("Y")[0], matmul_in_out[1]);
EXPECT_EQ(op->Output("Out")[0], matmul_in_out[2]);
EXPECT_EQ(op->GetAttrIfExists<float>("alpha"), alpha);
}
}
}
EXPECT_EQ(original_nodes_num - removed_nodes_count, current_nodes_num);
}
TEST(ScaleMatmulFusePass, scale_matmul_with_no_bias) {
auto bias = 0.0f;
auto scale = 2.34f;
auto alpha = 3.45f;
int removed_nodes_count = 2;
MainTest(BuildProgramDesc(scale, bias, alpha),
removed_nodes_count,
{},
{"a", "c", "d"},
scale * alpha);
}
TEST(ScaleMatmulFusePass, scale_matmul_with_bias) {
auto bias = 1.0f;
auto scale = 2.34f;
auto alpha = 3.45f;
int removed_nodes_count = 0;
MainTest(BuildProgramDesc(scale, bias, alpha),
removed_nodes_count,
{"a", "b"},
{"b", "c", "d"},
alpha);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(scale_matmul_fuse_pass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/silu_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
void SiluFusePass::ApplyImpl(ir::Graph* graph) const {
// This pass is used for cutlass, because cutlass can fuse conv + bias + silu
bool cutlass_enable = Get<bool>("use_cutlass");
if (!cutlass_enable) {
return;
}
const std::string pattern_name = "silu_fuse";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* sigmoid_in = gpd.mutable_pattern()->NewNode("sigmoid_in");
auto sigmoid_op =
gpd.mutable_pattern()->NewNode("sigmoid_op")->assert_is_op("sigmoid");
auto sigmoid_out = gpd.mutable_pattern()
->NewNode("sigmoid_out")
->assert_is_op_output("sigmoid")
->AsIntermediate();
auto elementwise_mul_op = gpd.mutable_pattern()
->NewNode("elementwise_mul_op")
->assert_is_op("elementwise_mul");
auto elementwise_mul_out = gpd.mutable_pattern()
->NewNode("elementwise_mul_out")
->assert_is_op_output("elementwise_mul")
->AsOutput();
sigmoid_op->LinksFrom({sigmoid_in}).LinksTo({sigmoid_out});
elementwise_mul_op->LinksFrom({sigmoid_in, sigmoid_out})
.LinksTo({elementwise_mul_out});
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
Node* sigmoid_in_node = subgraph.at(sigmoid_in);
Node* sigmoid_op_node = subgraph.at(sigmoid_op);
Node* elementwise_mul_op_node = subgraph.at(elementwise_mul_op);
Node* elementwise_mul_out_node = subgraph.at(elementwise_mul_out);
OpDesc new_desc;
new_desc.SetType("swish");
new_desc.SetAttr("beta", 1.f);
new_desc.SetInput("X", {sigmoid_in_node->Name()});
new_desc.SetOutput("Out", {elementwise_mul_out_node->Name()});
new_desc.Flush();
std::unordered_set<const Node*> del_node_set;
del_node_set.insert(sigmoid_op_node);
del_node_set.insert(elementwise_mul_op_node);
GraphSafeRemoveNodes(graph, del_node_set);
auto fused_node = graph->CreateOpNode(&new_desc);
IR_NODE_LINK_TO(sigmoid_in_node, fused_node);
IR_NODE_LINK_TO(fused_node, elementwise_mul_out_node);
};
gpd(graph, handler);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(silu_fuse_pass, paddle::framework::ir::SiluFusePass);
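For reference, a short numerical check of the rewrite this pass performs: sigmoid(x) * x is silu, which equals swish with beta = 1. This is a sketch using the public Python API, not part of the pass itself:

```python
import paddle

x = paddle.rand([1, 8, 16, 16])
y_pattern = paddle.nn.functional.sigmoid(x) * x  # the sigmoid + elementwise_mul subgraph
y_fused = paddle.nn.functional.swish(x)          # the swish op the pass emits (beta = 1)
assert bool(paddle.allclose(y_pattern, y_fused))
```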
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -13,23 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
namespace phi {
class Graph;
KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
}
class SiluFusePass : public FusePassBase {
public:
virtual ~SiluFusePass() {}
KernelSignature SqueezeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"squeeze_grad", {"XShape", "Out@GRAD"}, {"axes"}, {"X@GRAD"});
}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(squeeze2, squeeze);
PD_REGISTER_BASE_KERNEL_NAME(squeeze2_grad, squeeze_grad);
PD_REGISTER_ARG_MAPPING_FN(squeeze2, phi::SqueezeOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(squeeze2_grad, phi::SqueezeGradOpArgumentMapping);
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -1603,11 +1603,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
#endif
auto exe_ctx = ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx);
// using cache
if (kernel_type_.get()) {
dev_ctx = pool.Get(kernel_type_->place_);
}
auto exe_ctx = ExecutionContext(*this, scope, *dev_ctx, *runtime_ctx);
// TODO(Liu-xiandong): Now we are using too much if-else and hard code in XPU
// device, it's ugly, and we will refactor in the future.
......@@ -2716,22 +2716,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = dafault_data_type;
auto in_name_list = ctx.InNameList();
if (Info().HasOpProtoAndChecker()) {
for (auto& attr : Info().Proto().attrs()) {
auto it =
std::find_if(in_name_list.begin(),
in_name_list.end(),
[&attr](const std::string* name) {
return attr.support_tensor() && *name == attr.name();
});
if (it != in_name_list.end()) {
in_name_list.erase(it);
}
}
}
for (auto* name : in_name_list) {
for (auto* name : ctx.InNameList()) {
if (ctx.InputSize(*name) == 1UL) {
ParseInputDataType(ctx.InputVar(*name), *name, &data_type);
} else {
......
......@@ -202,6 +202,7 @@ struct Argument {
// Passed from config.
DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
DECL_ARGUMENT_FIELD(use_cutlass, UseCutlass, bool);
DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
......
......@@ -52,6 +52,7 @@ void IRPassManager::CreatePasses(Argument *argument,
for (const std::string &pass_name : passes) {
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen()));
pass->Set("use_cutlass", new bool(argument->use_cutlass()));
pass->Set("with_interleaved",
new bool(argument->tensorrt_with_interleaved()));
pass->Set("tensorrt_transformer_posid",
......@@ -80,6 +81,10 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("optim_shape_tensor",
new std::map<std::string, std::vector<int>>());
// This gpu_device_id is used by some fp16 precision passes, so move it
// here.
pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
// tuned trt dynamic_shape
pass->Set("trt_tuned_dynamic_shape",
new bool(argument->tensorrt_tuned_dynamic_shape()));
......@@ -198,7 +203,6 @@ void IRPassManager::CreatePasses(Argument *argument,
"model_opt_cache_dir",
new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
}
pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
pass->Set("use_static_engine", new bool(use_static_engine));
pass->Set("model_from_memory", new bool(argument->model_from_memory()));
pass->Set("use_inspector", new bool(argument->tensorrt_use_inspector()));
......
......@@ -222,6 +222,51 @@ void MakeSimpleReusePlan(
}
}
// Remove inplace operations from the plan because they do not support memory
// reuse.
void DelInplaceOpFromPlan(
Graph* graph,
std::unordered_map<std::string, std::string>* node2cluster,
int sort_kind) {
auto topo_nodes = TopologyVarientSort(
*graph, static_cast<framework::ir::SortKind>(sort_kind));
for (auto* op_node : topo_nodes) {
if (!op_node->IsOp()) continue;
auto input_tensors = op_node->inputs;
auto output_tensors = op_node->outputs;
std::unordered_set<std::string> in_names;
for (const Node* node : input_tensors) {
if (!node->Var()) continue;
if (node->Var()->Persistable()) continue;
std::string var = node->Name();
in_names.insert(var);
}
for (const Node* node : output_tensors) {
if (!node->Var()) continue;
if (node->Var()->Persistable()) continue;
std::string var = node->Name();
if (in_names.find(var) != in_names.end()) {
// delete key
if (node2cluster->count(var)) {
node2cluster->erase(var);
}
// delete value
std::string tmp_name = "";
for (auto it = node2cluster->begin(); it != node2cluster->end(); ++it) {
if (it->second == var) {
if (tmp_name == "") {
tmp_name = it->first;
}
it->second = tmp_name;
}
}
}
}
}
}
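A toy Python sketch of what DelInplaceOpFromPlan does to the reuse plan, assuming one op whose non-persistable input "x" is also one of its outputs (variable and cluster names hypothetical):

```python
node2cluster = {"x": "a", "y": "x", "z": "x"}
inplace_var = "x"

node2cluster.pop(inplace_var, None)            # delete the key for the inplace var
survivor = ""
for name in list(node2cluster):
    if node2cluster[name] == inplace_var:      # remap values that pointed at it
        survivor = survivor or name            # first match becomes the new cluster head
        node2cluster[name] = survivor
print(node2cluster)  # {'y': 'y', 'z': 'y'}
```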
// NOTE The optimized opdesc doesn't match ir::Graph.
void UpdateOpDescsByReuse(
Graph* graph,
......@@ -324,6 +369,7 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
CollectLifeCycle(graph, &lifecycles, sort_kind);
CollectVarMemorySize(graph, &space_table);
MakeSimpleReusePlan(lifecycles, space_table, &node2cluster, &cluster_size);
DelInplaceOpFromPlan(graph, &node2cluster, sort_kind);
auto* pass_res_info = PassResultInfoForRuntime::Instance();
pass_res_info->Set(
......
......@@ -115,6 +115,17 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
Update();
}
void AnalysisConfig::Exp_EnableUseCutlass() {
#if defined(PADDLE_WITH_CUTLASS)
use_cutlass_ = true;
#else
LOG(ERROR) << "Please compile with cutlass to EnableUseCutlass()";
use_cutlass_ = false;
#endif
Update();
}
void AnalysisConfig::SetExecStream(void *stream) {
PADDLE_ENFORCE_NOT_NULL(
stream,
......@@ -389,6 +400,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(use_fc_padding_);
// GPU related.
CP_MEMBER(use_gpu_);
CP_MEMBER(use_cutlass_);
CP_MEMBER(use_external_stream_);
CP_MEMBER(exec_stream_);
CP_MEMBER(use_cudnn_);
......@@ -1249,6 +1261,7 @@ std::string AnalysisConfig::Summary() {
// gpu info
os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
if (use_gpu_) {
os.InsertRow({"use_cutlass", use_cutlass_ ? "true" : "false"});
os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
os.InsertRow({"enable_gpu_mixed", std::to_string(enable_gpu_mixed_)});
os.InsertRow({"memory_pool_init_size",
......
......@@ -1088,6 +1088,7 @@ void AnalysisPredictor::PrepareArgument() {
// Init std::unique_ptr argument_.
argument_.reset(new Argument);
argument_->SetUseGPU(config_.use_gpu());
argument_->SetUseCutlass(config_.use_cutlass_);
argument_->SetUseFcPadding(config_.use_fc_padding());
argument_->SetGPUDeviceId(config_.gpu_device_id());
argument_->SetEnableIrOptim(config_.enable_ir_optim_);
......@@ -2396,6 +2397,7 @@ USE_TRT_CONVERTER(cast)
USE_TRT_CONVERTER(recover_padding)
USE_TRT_CONVERTER(remove_padding)
USE_TRT_CONVERTER(equal);
USE_TRT_CONVERTER(not_equal);
USE_TRT_CONVERTER(top_k)
USE_TRT_CONVERTER(top_k_v2)
USE_TRT_CONVERTER(range)
......
......@@ -395,6 +395,12 @@ struct PD_INFER_DECL AnalysisConfig {
///
bool use_gpu() const { return use_gpu_; }
///
/// \brief When running an fp16 model on an Nvidia GPU, you can also try
/// running your model with cutlass kernels.
///
void Exp_EnableUseCutlass();
///
///
/// \brief A boolean state telling whether the XPU is turned on.
///
/// \return bool Whether the XPU is turned on.
......@@ -1047,6 +1053,7 @@ struct PD_INFER_DECL AnalysisConfig {
// GPU related.
bool use_gpu_{false};
bool use_cutlass_{false};
int gpu_device_id_{0};
uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB.
bool enable_gpu_mixed_{false};
......
......@@ -164,6 +164,7 @@ const std::vector<std::string> kLiteSubgraphPasses({
const std::vector<std::string> kGpuLowerPrecisionPasses{
"identity_scale_op_clean_pass",
"simplify_with_basic_ops_pass",
"silu_fuse_pass",
"delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_pass",
"map_depthwise_conv_to_conv_pass",
......@@ -172,6 +173,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"conv_elementwise_add_act_fuse_pass",
"conv_elementwise_add2_act_fuse_pass",
"conv_elementwise_add_fuse_pass",
"conv2d_fusion_layout_transfer_pass",
"multihead_matmul_fuse_pass_v2",
"fused_multi_transformer_encoder_pass",
"fused_multi_transformer_decoder_pass",
......@@ -216,6 +218,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"delete_weight_dequant_linear_op_pass", //
"map_depthwise_conv_to_conv_pass", //
"constant_folding_pass", //
"silu_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
......@@ -250,7 +253,8 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
#endif //
"transpose_flatten_concat_fuse_pass", //
"constant_folding_pass", //
"auto_mixed_precision_pass", //
"conv2d_fusion_layout_transfer_pass", //
"auto_mixed_precision_pass"
});
use_gpu_ = true;
......
......@@ -142,7 +142,8 @@ void ConvertConv2d(TensorRTEngine* engine,
layer,
platform::errors::Fatal("TensorRT create conv2d/conv2d_transpose"
" layer failed."));
layer->setStride(nv_strides);
layer->setStrideNd(nv_strides);
layer->setPrePadding(nv_pre_paddings);
if (output_padding.size() > 0) {
nv_post_paddings.d[0] -= output_padding[0];
......@@ -189,7 +190,7 @@ class Conv2dOpConverter : public OpConverter {
TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) -> nvinfer1::IConvolutionLayer* {
auto* layer = TRT_ENGINE_ADD_LAYER(engine_,
Convolution,
ConvolutionNd,
*inputs,
n_output,
ksize,
......
......@@ -35,7 +35,6 @@ class EqualOpConverter : public OpConverter {
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
#if IS_TRT_VERSION_GE(8000)
framework::OpDesc op_desc(op, nullptr);
nvinfer1::ILayer* layer = nullptr;
......@@ -79,11 +78,62 @@ class EqualOpConverter : public OpConverter {
layer = TRT_ENGINE_ADD_LAYER(
engine_, ElementWise, *X, *Y, nvinfer1::ElementWiseOperation::kEQUAL);
RreplenishLayerAndOutput(layer, "equal", {output_name}, test_mode);
#else
PADDLE_THROW(
platform::errors::Fatal("ElementWise Equal Operation is only supported "
"on TRT 8 or higher version."));
#endif
}
};
class NotEqualOpConverter : public OpConverter {
public:
NotEqualOpConverter() {}
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
framework::OpDesc op_desc(op, nullptr);
nvinfer1::ILayer* layer = nullptr;
auto* X = engine_->GetITensor(op_desc.Input("X").front());
auto* Y = engine_->GetITensor(op_desc.Input("Y").front());
nvinfer1::Dims dims_x = X->getDimensions();
nvinfer1::Dims dims_y = Y->getDimensions();
int axis = PADDLE_GET_CONST(int, op_desc.GetAttr("axis"));
if (axis < 0) {
axis = std::abs(dims_x.nbDims - dims_y.nbDims);
}
auto output_name = op_desc.Output("Out")[0];
nvinfer1::IShuffleLayer* expand_layer = nullptr;
if (dims_x.nbDims > dims_y.nbDims) {
nvinfer1::Dims expand_shape;
expand_shape.nbDims = dims_x.nbDims;
for (int i = 0; i < expand_shape.nbDims; i++) {
expand_shape.d[i] = 1;
}
for (int i = 0; i < dims_y.nbDims; i++) {
expand_shape.d[i + axis] = dims_y.d[i];
}
expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *Y);
expand_layer->setReshapeDimensions(expand_shape);
Y = expand_layer->getOutput(0);
} else if (dims_x.nbDims < dims_y.nbDims) {
nvinfer1::Dims expand_shape;
expand_shape.nbDims = dims_y.nbDims;
for (int i = 0; i < expand_shape.nbDims; i++) {
expand_shape.d[i] = 1;
}
for (int i = 0; i < dims_x.nbDims; i++) {
expand_shape.d[i + axis] = dims_x.d[i];
}
expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X);
expand_layer->setReshapeDimensions(expand_shape);
X = expand_layer->getOutput(0);
}
layer = TRT_ENGINE_ADD_LAYER(
engine_, ElementWise, *X, *Y, nvinfer1::ElementWiseOperation::kEQUAL);
layer = TRT_ENGINE_ADD_LAYER(
engine_, Unary, *layer->getOutput(0), nvinfer1::UnaryOperation::kNOT);
RreplenishLayerAndOutput(layer, "not_equal", {output_name}, test_mode);
}
};
......@@ -92,3 +142,4 @@ class EqualOpConverter : public OpConverter {
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(equal, EqualOpConverter);
REGISTER_TRT_OP_CONVERTER(not_equal, NotEqualOpConverter);
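The converter above builds not_equal as an equal layer followed by a logical NOT. The same identity expressed with the public Python API, as a sanity sketch:

```python
import paddle

x = paddle.to_tensor([1, 2, 3])
y = paddle.to_tensor([1, 0, 3])
lhs = paddle.not_equal(x, y)
rhs = paddle.logical_not(paddle.equal(x, y))
assert bool(paddle.all(lhs == rhs))
```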
......@@ -119,24 +119,21 @@ struct SimpleOpTypeSetTeller : public Teller {
#endif
}
// In static shape mode in TRT, we can't allow an op's input to be a
// 1D tensor, so we filter it here. Some ops (like elementwise) have a "Y"
// input too, but that is handled in the specific op; here we only cover the
// common case.
// In static shape mode in Paddle-TRT, we can't allow an op to take a
// 1D intermediate tensor as input.
if (!with_dynamic_shape) {
std::string X_name;
auto inputs = desc.Inputs();
if (inputs.count("X") && !desc.Input("X").empty()) {
X_name = desc.Input("X")[0];
} else if (inputs.count("Input") && !desc.Input("Input").empty()) {
X_name = desc.Input("Input")[0];
}
auto* block = desc.Block();
if (block) {
auto* x_var_desc = block->FindVar(X_name);
// Can't get feed op's TensorDesc
if (op_type != "feed" && x_var_desc && !x_var_desc->Persistable()) {
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) return false;
for (auto iter : inputs) {
for (auto var_name : iter.second) {
auto* block = desc.Block();
if (block) {
auto* var_desc = block->FindVar(var_name);
// Can't get feed op's TensorDesc
if (op_type != "feed" && var_desc && !var_desc->Persistable()) {
const auto shape = var_desc->GetShape();
if (shape.size() == 1) return false;
}
}
}
}
}
......@@ -2341,7 +2338,7 @@ struct SimpleOpTypeSetTeller : public Teller {
}
#endif
if (op_type == "equal") {
if (op_type == "equal" || op_type == "not_equal") {
#if !IS_TRT_VERSION_GE(8000)
VLOG(3) << "compare is not supported when TensorRT < 8.0";
return false;
......@@ -2493,6 +2490,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"elementwise_max",
"elementwise_floordiv",
"equal",
"not_equal",
"less_than",
"greater_than",
"logical_or",
......@@ -2639,6 +2637,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"elementwise_max",
"elementwise_floordiv",
"equal",
"not_equal",
"less_than",
"greater_than",
"logical_or",
......
......@@ -330,3 +330,13 @@ REGISTER_OPERATOR(
ops::ConvOpInferVarType,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
// This op is used by cutlass; conv2d_fusion_cutlass is an intermediate op
// produced by conv2d_fusion_layout_transfer_pass.
REGISTER_OPERATOR(
conv2d_fusion_cutlass,
ops::Conv2DFusionOp,
ops::Conv2DFusionOpMaker,
ops::ConvOpInferVarType,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......@@ -131,9 +131,10 @@ def process_int_array(op_item, int_array_configs):
)
if attr_item['is_support_tensor']:
attr_item['typename'] = (
data_type_map[int_array_config['data_type']]
'int[]'
if 'data_type' in int_array_config
else 'std::vector<int64_t>'
and int_array_config['data_type'] == 'int'
else 'int64_t[]'
)
else:
attr_item['data_type'] = (
......@@ -153,21 +154,95 @@ def process_int_array(op_item, int_array_configs):
# replace name of op and params for OpMaker
def replace_compat_name(op_op_map, forward_op_dict, backward_op_dict):
def get_op_and_op_name(op_item):
def replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict):
def get_phi_and_fluid_op_name(op_item):
names = op_item.split('(')
if len(names) == 1:
return names[0].strip(), names[0].strip()
else:
return names[0].strip(), names[1].split(')')[0].strip()
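A quick worked example of the helper above, restated standalone so it can be run directly; the compat spellings come from op_compat.yaml entries such as "squeeze (squeeze2)":

```python
def get_phi_and_fluid_op_name(op_item):
    names = op_item.split('(')
    if len(names) == 1:
        return names[0].strip(), names[0].strip()
    return names[0].strip(), names[1].split(')')[0].strip()

# Entries without parentheses use the same name for both phi and fluid.
assert get_phi_and_fluid_op_name("stack") == ("stack", "stack")
assert get_phi_and_fluid_op_name("squeeze (squeeze2)") == ("squeeze", "squeeze2")
```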
def update_op_attr_name(attrs, attrs_alias_map):
for attr_item in attrs:
if attr_item['name'] in attrs_alias_map:
attr_item['name'] = attrs_alias_map[attr_item['name']]
def update_op_param_name(op_args, args_alias_map):
for item in op_args:
if item['name'] in args_alias_map:
item['name'] = args_alias_map[item['name']]
def update_grad_args_name(op_args, args_alias_map):
for item in op_args:
if (
item['name'].endswith('_grad')
and item['name'][:-5] in args_alias_map
):
args_alias_map[item['name']] = (
args_alias_map[item['name'][:-5]] + '_grad'
)
item['name'] = args_alias_map[item['name'][:-5]] + '_grad'
def get_param_list_alias(param_list, args_map):
return [
args_map[param] if param in args_map else param
for param in param_list
]
for op_args in op_op_map:
new_op_name, op_name = get_op_and_op_name(op_args['op'])
def update_common_params_name(
op_item, args_name_map, scalar_configs, int_array_configs
):
if 'inplace' in op_item and op_item['inplace']:
inplace_map = {}
for key, val in op_item['inplace'].items():
if key in args_map:
key = args_map[key]
if val in args_map:
val = args_map[val]
inplace_map[key] = val
op_item['inplace'] = inplace_map
if 'no_need_buffer' in op_item and op_item['no_need_buffer']:
op_item['no_need_buffer'] = get_param_list_alias(
op_item['no_need_buffer'], args_map
)
process_scalar(op_item, scalar_configs)
process_int_array(op_item, int_array_configs)
if 'invoke' in op_item:
op_item['invoke']['args'] = [
args_map[param.strip()]
if param.strip() in args_map
else param.strip()
for param in op_item['invoke']['args'].split(',')
]
return
op_item['infer_meta']['param'] = get_param_list_alias(
op_item['infer_meta']['param'], args_name_map
)
op_item['kernel']['param'] = get_param_list_alias(
op_item['kernel']['param'], args_name_map
)
if op_item['kernel']['data_type']:
op_item['kernel']['data_type']['candidates'] = get_param_list_alias(
op_item['kernel']['data_type']['candidates'], args_name_map
)
if op_item['kernel']['backend']:
op_item['kernel']['backend']['candidates'] = get_param_list_alias(
op_item['kernel']['backend']['candidates'], args_name_map
)
if op_item['kernel']['layout']:
op_item['kernel']['layout']['candidates'] = get_param_list_alias(
op_item['kernel']['layout']['candidates'], args_name_map
)
def update_grad_op_compat_name(grad_op_item, args_name_map):
update_op_param_name(grad_op_item['inputs'], args_name_map)
update_op_param_name(grad_op_item['outputs'], args_name_map)
update_op_param_name(grad_op_item['attrs'], args_name_map)
update_op_param_name(grad_op_item['forward']['inputs'], args_name_map)
update_op_param_name(grad_op_item['forward']['outputs'], args_name_map)
update_op_param_name(grad_op_item['forward']['attrs'], args_name_map)
update_grad_args_name(grad_op_item['inputs'], args_map)
update_grad_args_name(grad_op_item['outputs'], args_map)
for op_args in op_fluid_map_list:
new_op_name, op_name = get_phi_and_fluid_op_name(op_args['op'])
if new_op_name not in forward_op_dict:
continue
forward_op_item = forward_op_dict[new_op_name]
......@@ -179,189 +254,102 @@ def replace_compat_name(op_op_map, forward_op_dict, backward_op_dict):
scalar_configs = None
int_array_configs = None
if 'scalar' in op_args:
scalar_configs = op_args['scalar']
if 'int_array' in op_args:
int_array_configs = op_args['int_array']
if 'extra' in op_args and 'outputs' in op_args['extra']:
for out_item in forward_op_item['outputs']:
if out_item['name'] in op_args['extra']['outputs']:
out_item['is_extra'] = True
process_scalar(forward_op_item, scalar_configs)
process_int_array(forward_op_item, int_array_configs)
key_set = ['inputs', 'attrs', 'outputs']
args_map = {}
for key in key_set:
if key in op_args:
args_map.update(op_args[key])
for args_item in forward_op_item[key]:
if args_item['name'] in op_args[key]:
if (
scalar_configs
and args_item['name'] in scalar_configs
):
scalar_configs[
op_args[key][args_item['name']]
] = scalar_configs[args_item['name']]
if (
int_array_configs
and args_item['name'] in int_array_configs
):
int_array_configs[
op_args[key][args_item['name']]
] = int_array_configs[args_item['name']]
args_item['name'] = op_args[key][args_item['name']]
if has_backward:
for args_item in backward_op_item['forward'][key]:
if args_item['name'] in op_args[key]:
args_item['name'] = op_args[key][args_item['name']]
forward_op_item["attr_dict"] = to_named_dict(forward_op_item["attrs"])
update_common_params_name(
forward_op_item, args_map, scalar_configs, int_array_configs
)
if has_backward:
update_grad_op_compat_name(backward_op_item, args_map)
update_common_params_name(
backward_op_item, args_map, scalar_configs, int_array_configs
)
backward_op_item["attr_dict"] = to_named_dict(
backward_op_item["attrs"]
)
if 'backward' not in op_args:
continue
if 'backward' in op_args and has_backward:
backward_op_list = op_args['backward'].split(',')
_, bw_op_name = get_op_and_op_name(backward_op_list[0])
_, bw_op_name = get_phi_and_fluid_op_name(backward_op_list[0])
forward_op_item['backward'] = bw_op_name
backward_op_item['op_name'] = bw_op_name
process_scalar(backward_op_item, scalar_configs)
process_int_array(backward_op_item, int_array_configs)
# for double grad
if len(backward_op_list) > 1:
(
new_double_grad_op_name,
phi_double_grad_op_name,
double_grad_op_name,
) = get_op_and_op_name(backward_op_list[1])
double_grad_item = backward_op_dict[new_double_grad_op_name]
) = get_phi_and_fluid_op_name(backward_op_list[1])
double_grad_item = backward_op_dict[phi_double_grad_op_name]
backward_op_item['backward'] = double_grad_op_name
double_grad_item['op_name'] = double_grad_op_name
if 'attrs' in op_args:
update_op_attr_name(
double_grad_item['attrs'], op_args['attrs']
)
update_op_attr_name(
double_grad_item['forward']['attrs'], op_args['attrs']
)
process_scalar(double_grad_item, scalar_configs)
process_int_array(double_grad_item, int_array_configs)
update_grad_op_compat_name(double_grad_item, args_map)
update_common_params_name(
double_grad_item,
args_map,
scalar_configs,
int_array_configs,
)
double_grad_item["attr_dict"] = to_named_dict(
double_grad_item["attrs"]
)
# for triple grad
if len(backward_op_list) > 2:
(
new_triple_grad_op_name,
phi_triple_grad_op_name,
triple_grad_op_name,
) = get_op_and_op_name(backward_op_list[2])
triple_grad_item = backward_op_dict[new_triple_grad_op_name]
) = get_phi_and_fluid_op_name(backward_op_list[2])
triple_grad_item = backward_op_dict[phi_triple_grad_op_name]
double_grad_item['backward'] = triple_grad_op_name
triple_grad_item['op_name'] = triple_grad_op_name
if 'attrs' in op_args:
update_op_attr_name(
triple_grad_item['attrs'], op_args['attrs']
)
update_op_attr_name(
triple_grad_item['forward']['attrs'],
op_args['attrs'],
)
process_scalar(triple_grad_item, scalar_configs)
process_int_array(triple_grad_item, int_array_configs)
key_set = ['inputs', 'attrs', 'outputs']
args_map = {}
for key in key_set:
if key in op_args:
args_map.update(op_args[key])
for args_item in forward_op_item[key]:
if args_item['name'] in op_args[key]:
args_item['name'] = op_args[key][args_item['name']]
if has_backward:
for args_item in backward_op_item['forward'][key]:
if args_item['name'] in op_args[key]:
args_item['name'] = op_args[key][args_item['name']]
forward_op_item['infer_meta']['param'] = [
args_map[param] if param in args_map else param
for param in forward_op_item['infer_meta']['param']
]
forward_op_item['kernel']['param'] = [
args_map[param] if param in args_map else param
for param in forward_op_item['kernel']['param']
]
if forward_op_item['kernel']['data_type']:
forward_op_item['kernel']['data_type']['candidates'] = [
args_map[param] if param in args_map else param
for param in forward_op_item['kernel']['data_type'][
'candidates'
]
]
if forward_op_item['kernel']['backend']:
forward_op_item['kernel']['backend']['candidates'] = [
args_map[param] if param in args_map else param
for param in forward_op_item['kernel']['backend']['candidates']
]
if forward_op_item['kernel']['layout']:
forward_op_item['kernel']['layout']['candidates'] = [
args_map[param] if param in args_map else param
for param in forward_op_item['kernel']['layout']['candidates']
]
if forward_op_item['inplace']:
inplace_map = {}
for key, val in forward_op_item['inplace'].items():
if key in args_map:
key = args_map[key]
if val in args_map:
val = args_map[val]
inplace_map[key] = val
forward_op_item['inplace'] = inplace_map
if has_backward:
for args_item in backward_op_item['inputs']:
if args_item['name'] in args_map:
args_item['name'] = args_map[args_item['name']]
elif (
args_item['name'].endswith('_grad')
and args_item['name'][:-5] in args_map
):
args_map[args_item['name']] = (
args_map[args_item['name'][:-5]] + '_grad'
update_grad_op_compat_name(triple_grad_item, args_map)
update_common_params_name(
triple_grad_item,
args_map,
scalar_configs,
int_array_configs,
)
args_item['name'] = args_map[args_item['name']]
for args_item in backward_op_item['attrs']:
if args_item['name'] in args_map:
args_item['name'] = args_map[args_item['name']]
for args_item in backward_op_item['outputs']:
if (
args_item['name'].endswith('_grad')
and args_item['name'][:-5] in args_map
):
args_map[args_item['name']] = (
args_map[args_item['name'][:-5]] + '_grad'
triple_grad_item["attr_dict"] = to_named_dict(
triple_grad_item["attrs"]
)
args_item['name'] = args_map[args_item['name']]
if 'invoke' in backward_op_item:
backward_op_item['invoke']['args'] = [
args_map[param.strip()]
if param.strip() in args_map
else param.strip()
for param in backward_op_item['invoke']['args'].split(',')
]
continue
backward_op_item['infer_meta']['param'] = [
args_map[param] if param in args_map else param
for param in backward_op_item['infer_meta']['param']
]
backward_op_item['kernel']['param'] = [
args_map[param] if param in args_map else param
for param in backward_op_item['kernel']['param']
]
if backward_op_item['kernel']['data_type']:
backward_op_item['kernel']['data_type']['candidates'] = [
args_map[param] if param in args_map else param
for param in backward_op_item['kernel']['data_type'][
'candidates'
]
]
if backward_op_item['kernel']['backend']:
backward_op_item['kernel']['backend']['candidates'] = [
args_map[param] if param in args_map else param
for param in backward_op_item['kernel']['backend'][
'candidates'
]
]
if backward_op_item['kernel']['layout']:
backward_op_item['kernel']['layout']['candidates'] = [
args_map[param] if param in args_map else param
for param in backward_op_item['kernel']['layout'][
'candidates'
]
]
if backward_op_item['no_need_buffer']:
backward_op_item['no_need_buffer'] = [
args_map[param] if param in args_map else param
for param in backward_op_item['no_need_buffer']
]
if backward_op_item['inplace']:
inplace_map = {}
for key, val in backward_op_item['inplace'].items():
if key in args_map:
key = args_map[key]
if val in args_map:
val = args_map[val]
inplace_map[key] = val
backward_op_item['inplace'] = inplace_map
def process_invoke_op(forward_op_dict, backward_op_dict):
......@@ -372,6 +360,7 @@ def process_invoke_op(forward_op_dict, backward_op_dict):
args_index = 0
if invoke_op in forward_op_dict:
reuse_op = forward_op_dict[invoke_op]
bw_op['invoke']['func'] = reuse_op['op_name']
bw_op['invoke']['inputs'] = []
bw_op['invoke']['attrs'] = []
bw_op['invoke']['outputs'] = []
......@@ -430,14 +419,14 @@ def main(
forward_op_dict[op_version['op']]['version'] = op_version['version']
with open(op_compat_yaml_path, "rt") as f:
op_op_map = yaml.safe_load(f)
op_fluid_map_list = yaml.safe_load(f)
for op in ops:
op['op_name'] = op['name']
for bw_op in backward_ops:
bw_op['op_name'] = bw_op['name']
replace_compat_name(op_op_map, forward_op_dict, backward_op_dict)
replace_compat_name(op_fluid_map_list, forward_op_dict, backward_op_dict)
# prepare for invoke case
process_invoke_op(forward_op_dict, backward_op_dict)
......
......@@ -54,6 +54,10 @@ AddOutput({{name | to_opmaker_name}}, "({{typename}}), output {{i}} of {{op_name
.AsIntermediate()
{%- endif %}
{%- if "is_extra" in output and output["is_extra"] %}
.AsExtra()
{%- endif %}
{%- endmacro %}
{# add attribute, and process default value if needed #}
......@@ -115,7 +119,7 @@ KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const Argum
paddle::small_vector<const char*> attrs;
{% for attr in op["attrs"]%}
{% filter indent(2)%}
{{get_an_attr(attr)}}
{{get_an_attr(attr, kernel_args)}}
{% endfilter %}
{% endfor %}
{{get_output_list(op["outputs"], kernel_args)}};
......@@ -170,7 +174,7 @@ KernelSignature {{op["op_name"] | to_pascal_case }}OpArgumentMapping(const Argum
paddle::small_vector<const char*> attrs;
{% for attr in op["attrs"]%}
{% filter indent(2)%}
{{get_an_attr(attr)}}
{{get_an_attr(attr, kernel_args)}}
{% endfilter %}
{% endfor %}
{{get_output_list(op["outputs"], kernel_args)}};
......@@ -209,8 +213,9 @@ paddle::small_vector<const char*> inputs {
}
{%- endmacro %}
{% macro get_an_attr(attr) %}{# inline #}
{% macro get_an_attr(attr, kernel_args) %}{# inline #}
{% set typename = attr["typename"] %}
{%- if attr["name"] in kernel_args %}
{% set name = attr["name"] %}
{% if typename is scalar %}{# scalar corresponds to a dispensable input and an attr in opmaker #}
attrs.emplace_back(ctx.HasInput("{{attr | to_scalar_tensor_name}}") ? "{{attr | to_scalar_tensor_name}}" : "{{name}}");
......@@ -236,6 +241,7 @@ attrs.emplace_back(
{%- else %}
attrs.emplace_back("{{name}}");
{%- endif %}
{%- endif %}
{%- endmacro %}
{% macro get_output_list(outputs, kernel_args) %}{# inline #}
......@@ -502,10 +508,9 @@ OutputGrad({{name_in_forward_orig | to_opmaker_name}})
{% set name_in_forward = name[:-5] %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad({{name_in_forward_orig | to_opmaker_name}})
{%- elif (name | to_input_name) in input_names %}
{% set name_in_forward = name | to_input_name %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad({{name | to_input_name | to_opmaker_name}})
{%- elif (name) in input_names %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name)]%}
Input({{name | to_opmaker_name}})
{%- endif %}
{%- endmacro %}
......
......@@ -30,6 +30,13 @@ class PadOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Pad");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Pad");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class PadOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -98,6 +105,14 @@ class PadOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(x_grad_name, dout_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
template <typename T>
......
......@@ -114,11 +114,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
return;
}
PADDLE_ENFORCE_EQ(!shape.empty(),
true,
platform::errors::InvalidArgument(
"The parameter 'shape' in ReshapeOp must be set. "
"But received 'shape' is empty."));
auto x_dims = ctx->GetInputDim("X");
auto out_dims = ValidateShape(shape, x_dims);
ctx->SetOutputDim("Out", out_dims);
......
......@@ -195,17 +195,6 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class Squeeze2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T>
class SqueezeGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -220,32 +209,6 @@ class SqueezeGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
class Squeeze2GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(
context->HasInput("XShape"), "Input", "XShape", "Squeeze2Grad");
OP_INOUT_CHECK(context->HasInput(framework::GradVarName("Out")),
"Input",
framework::GradVarName("Out"),
"Squeeze2Grad");
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T>
class SqueezeDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -259,82 +222,6 @@ class SqueezeDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
// FIXME(zcd): squeeze2 adds an intermediate output (XShape) based on squeeze;
// the XShape is used to carry the shape and lod of X, which will be used in
// squeeze_grad. In this way, the framework can reuse the memory of X
// immediately after the squeeze2_op is finished.
// Considering compatibility issues, we cannot fix squeeze2_op.
class Squeeze2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor). The input tensor of squeeze operator.");
AddOutput("Out", "(Tensor). The output tensor of squeeze operator.");
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in SqueezeGradOp.")
.AsIntermediate()
.AsExtra();
AddAttr<std::vector<int>>("axes",
"(std::vector<int>). List of integers,"
" indicating the dimensions to squeeze.")
.SetDefault({})
.SupportTensor();
AddComment(R"DOC(
Squeeze2 Operator.
Remove single-dimensional entries from the shape of a tensor.
Takes a parameter axes with a list of axes to squeeze.
If axes is not provided, all the single dimensions will be removed from the shape.
If an axis is selected with shape entry not equal to one, an error is raised.
Examples:
Case 1:
Given
X.shape = (1, 3, 1, 5)
and
axes = [0]
we get:
Out.shape = (3, 1, 5)
Case 2:
Given
X.shape = (1, 3, 1, 5)
and
axes = []
we get:
Out.shape = (3, 5)
)DOC");
}
};
template <typename T>
class Squeeze2GradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("squeeze2_grad");
grad_op->SetInput("XShape", this->Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
template <typename T>
class Squeeze2DoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("squeeze2");
grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
grad_op->SetOutput("XShape", this->Input("XShape"));
grad_op->SetAttrMap(this->Attrs());
}
};
DECLARE_INPLACE_OP_INFERER(SqueezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(SqueezeGradInplaceInferer,
{framework::GradVarName("Out"),
......@@ -345,10 +232,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(SqueezeGradNoNeedBufferVarsInferer, "X");
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(squeeze2,
SqueezeInferShapeFunctor,
PD_INFER_META(phi::SqueezeWithXShapeInferMeta));
REGISTER_OPERATOR(squeeze,
ops::SqueezeOp,
ops::SqueezeOpMaker,
......@@ -360,19 +243,6 @@ REGISTER_OPERATOR(squeeze_grad,
ops::SqueezeDoubleGradOpMaker<paddle::imperative::OpBase>,
ops::SqueezeGradNoNeedBufferVarsInferer);
REGISTER_OPERATOR(squeeze2,
ops::Squeeze2Op,
ops::Squeeze2OpMaker,
ops::Squeeze2GradOpMaker<paddle::framework::OpDesc>,
ops::Squeeze2GradOpMaker<paddle::imperative::OpBase>,
ops::SqueezeInplaceInferer,
SqueezeInferShapeFunctor);
REGISTER_OPERATOR(squeeze2_grad,
ops::Squeeze2GradOp,
ops::Squeeze2DoubleGradOpMaker<paddle::framework::OpDesc>,
ops::Squeeze2DoubleGradOpMaker<paddle::imperative::OpBase>,
ops::SqueezeGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
squeeze,
ops::SqueezeKernel<phi::CPUContext, float>,
......
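The removed Squeeze2OpMaker doc above gives the Case 1/Case 2 shapes; the same behavior checked through the public Python API, as a sketch (XShape is an internal intermediate and never surfaces to users):

```python
import paddle

x = paddle.rand([1, 3, 1, 5])
print(paddle.squeeze(x, axis=[0]).shape)  # [3, 1, 5]  (Case 1)
print(paddle.squeeze(x).shape)            # [3, 5]     (Case 2: all size-1 dims removed)
print(paddle.unsqueeze(paddle.rand([3, 5]), axis=[0, 2]).shape)  # [1, 3, 1, 5]
```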
......@@ -260,83 +260,6 @@ class UnsqueezeDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
// FIXME(zcd): unsqueeze2 adds an intermediate output (XShape) based on
// unsqueeze; the XShape is used to carry the shape and lod of X, which
// will be used in unsqueeze_grad. In this way, the framework can reuse
// the memory of X immediately after the unsqueeze2_op is finished.
// Considering compatibility issues, we cannot fix unsqueeze2_op.
class Unsqueeze2Op : public UnsqueezeOp {
public:
using UnsqueezeOp::UnsqueezeOp;
};
class Unsqueeze2OpMaker : public UnsqueezeOpMaker {
public:
void Make() override {
UnsqueezeOpMaker::Make();
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in UnsqueezeGradOp.")
.AsIntermediate()
.AsExtra();
}
};
template <typename T>
class Unsqueeze2GradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("unsqueeze2_grad");
grad_op->SetInput("XShape", this->Output("XShape"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
class Unsqueeze2GradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(
context->HasInput("XShape"),
true,
platform::errors::InvalidArgument("Input(XShape) shouldn't be null."));
PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")),
true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) shouldn't be null."));
auto xshape_dims = context->GetInputDim("XShape");
auto x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
context->SetOutputDim(framework::GradVarName("X"), x_dims);
context->ShareLoD("XShape", framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
template <typename T>
class Unsqueeze2DoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("unsqueeze2");
grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
grad_op->SetOutput("XShape", this->Input("XShape"));
grad_op->SetAttrMap(this->Attrs());
}
};
DECLARE_INPLACE_OP_INFERER(UnsqueezeInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(UnsqueezeGradInplaceInferer,
{framework::GradVarName("Out"),
......@@ -345,10 +268,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X");
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(unsqueeze2,
Unsqueeze2InferShapeFunctor,
PD_INFER_META(phi::UnsqueezeWithXShapeInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(unsqueeze,
ops::UnsqueezeOp,
......@@ -362,20 +281,6 @@ REGISTER_OPERATOR(unsqueeze_grad,
ops::UnsqueezeDoubleGradOpMaker<paddle::imperative::OpBase>,
ops::UnsqueezeGradOpNoNeedBufferVarInferer);
REGISTER_OPERATOR(unsqueeze2,
ops::Unsqueeze2Op,
ops::Unsqueeze2OpMaker,
ops::Unsqueeze2GradOpMaker<paddle::framework::OpDesc>,
ops::Unsqueeze2GradOpMaker<paddle::imperative::OpBase>,
Unsqueeze2InferShapeFunctor,
ops::UnsqueezeInplaceInferer);
REGISTER_OPERATOR(unsqueeze2_grad,
ops::Unsqueeze2GradOp,
ops::Unsqueeze2DoubleGradOpMaker<paddle::framework::OpDesc>,
ops::Unsqueeze2DoubleGradOpMaker<paddle::imperative::OpBase>,
ops::UnsqueezeGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
unsqueeze,
ops::UnsqueezeKernel<phi::CPUContext, float>,
......
......@@ -646,6 +646,7 @@ void BindAnalysisConfig(py::module *m) {
py::arg("memory_pool_init_size_mb"),
py::arg("device_id") = 0,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
.def("exp_enable_use_cutlass", &AnalysisConfig::Exp_EnableUseCutlass)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("set_exec_stream",
[](AnalysisConfig &self, phi::CUDAStream &stream) {
......
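A minimal usage sketch of the new binding (model paths are hypothetical; requires a build with PADDLE_WITH_CUTLASS, otherwise the call only logs an error):

```python
import paddle.inference as paddle_infer

config = paddle_infer.Config("model.pdmodel", "model.pdiparams")  # hypothetical files
config.enable_use_gpu(100, 0)       # 100 MB initial pool on GPU 0
config.exp_enable_use_cutlass()     # new binding added in this PR
predictor = paddle_infer.create_predictor(config)
```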
......@@ -44,11 +44,7 @@ set(PHI_DEPS
get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS)
set(PHI_DEPS ${PHI_DEPS} ${phi_kernels})
if(APPLE AND WITH_ARM)
cc_library(phi DEPS ${PHI_DEPS})
else()
create_dummy_static_lib(phi LIBS ${PHI_DEPS} LIMIT 100)
endif()
cc_library(phi DEPS ${PHI_DEPS})
set(phi_extension_header_file
${CMAKE_CURRENT_SOURCE_DIR}/extension.h
......
......@@ -19,7 +19,7 @@ limitations under the License. */
// Note(chenweihang): In order to be compatible with the original custom
// operator Tensor interface, only available to external users, the file
// cannot be includeed in paddle
// cannot be included in paddle
namespace paddle {
using Tensor = experimental::Tensor;
......
......@@ -1186,6 +1186,26 @@
backward : square_double_grad
inplace : (out_grad -> x_grad)
- backward_op : squeeze_double_grad
forward : squeeze_grad(Tensor xshape, Tensor grad_out, IntArray axis) -> Tensor(grad_x)
args : (Tensor grad_x_grad, IntArray axis)
output : Tensor(grad_out_grad), Tensor(xshape)
invoke: squeeze(grad_x_grad, axis)
intermediate : xshape
- backward_op : squeeze_grad
forward : squeeze(Tensor x, IntArray axis) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad, IntArray axis)
output : Tensor(x_grad)
infer_meta :
func : KernelWithXShapeInferMeta
param: [xshape]
kernel :
func : squeeze_grad
data_type : out_grad
inplace : (out_grad -> x_grad)
backward: squeeze_double_grad
- backward_op : svd_grad
forward : svd (Tensor x, bool full_matrices = false) -> Tensor(u), Tensor(s), Tensor(vh)
args : (Tensor x, Tensor u, Tensor vh, Tensor s, Tensor u_grad, Tensor vh_grad, Tensor s_grad, bool full_matrices)
......@@ -1321,6 +1341,27 @@
data_type : out_grad
no_need_buffer : x
- backward_op : unsqueeze_double_grad
forward : unsqueeze_grad(Tensor xshape, Tensor grad_out, IntArray axes) -> Tensor(grad_x)
args : (Tensor grad_x_grad, IntArray axes)
output : Tensor(grad_out_grad), Tensor(xshape)
invoke : unsqueeze(grad_x_grad, axes)
intermediate : xshape
- backward_op : unsqueeze_grad
forward : unsqueeze(Tensor x, IntArray axes) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad, IntArray axes)
output : Tensor(x_grad)
infer_meta :
func : KernelWithXShapeInferMeta
param: [xshape]
kernel :
func : unsqueeze_grad
param : [xshape, out_grad]
data_type : out_grad
inplace : (out_grad -> x_grad)
backward : unsqueeze_double_grad
- backward_op : unstack_grad
forward : unstack (Tensor x, int axis=0, int num=0) -> Tensor[](out)
args : (Tensor[] out_grad, int axis)
......
......@@ -1363,24 +1363,6 @@
kernel :
func : squared_l2_norm_grad
- backward_op : squeeze_double_grad
forward : squeeze_grad(Tensor xshape, Tensor grad_out, IntArray axis) -> Tensor(grad_x)
args : (Tensor grad_x_grad, IntArray axis)
output : Tensor(grad_out_grad)
invoke: squeeze(grad_x_grad, axis)
- backward_op : squeeze_grad
forward : squeeze(Tensor x, IntArray axis) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad, IntArray axis)
output : Tensor(x_grad)
infer_meta :
func : KernelWithXShapeInferMeta
param: [xshape]
kernel :
func : squeeze_grad
inplace : (out_grad -> x_grad)
backward: squeeze_double_grad
- backward_op : stack_grad
forward : stack (Tensor[] x, int axis) -> Tensor(out)
args : (Tensor[] x, Tensor out_grad, int axis)
......@@ -1574,25 +1556,6 @@
func : uniform_inplace_grad
inplace : (out_grad -> x_grad)
- backward_op : unsqueeze_double_grad
forward : unsqueeze_grad(Tensor xshape, Tensor grad_out, IntArray axes) -> Tensor(grad_x)
args : (Tensor grad_x_grad, IntArray axes)
output : Tensor(grad_out_grad)
invoke : unsqueeze(grad_x_grad, axes)
- backward_op : unsqueeze_grad
forward : unsqueeze(Tensor x, IntArray axes) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad, IntArray axes)
output : Tensor(x_grad)
infer_meta :
func : KernelWithXShapeInferMeta
param: [xshape]
kernel :
func : unsqueeze_grad
param: [xshape, out_grad]
inplace : (out_grad -> x_grad)
backward : unsqueeze_double_grad
- backward_op : warpctc_grad
forward : warpctc (Tensor logits, Tensor label, Tensor logits_length, Tensor labels_length, int blank, bool norm_by_times) -> Tensor(loss), Tensor(warpctcgrad)
args : (Tensor logits, Tensor logits_length, Tensor warpctcgrad, Tensor loss_grad, int blank, bool norm_by_times)
......
......@@ -1777,18 +1777,6 @@
func : squared_l2_norm
backward : squared_l2_norm_grad
- op : squeeze
args : (Tensor x, IntArray axis)
output : Tensor(out), Tensor(xshape)
infer_meta :
func : SqueezeWithXShapeInferMeta
kernel :
func : squeeze_with_xshape
inplace : (x -> out)
view: (x -> out)
intermediate : xshape
backward : squeeze_grad
- op : stack
args : (Tensor[] x, int axis)
output : Tensor
......@@ -2022,18 +2010,6 @@
data_type: x
backward: unpool3d_grad
- op : unsqueeze
args : (Tensor x, IntArray axis)
output : Tensor(out), Tensor(xshape)
infer_meta :
func : UnsqueezeWithXShapeInferMeta
kernel :
func : unsqueeze_with_xshape
inplace : (x -> out)
view: (x -> out)
intermediate : xshape
backward : unsqueeze_grad
- op : update_loss_scaling_
args : (Tensor[] x, Tensor found_infinite, Tensor prev_loss_scaling, Tensor in_good_steps, Tensor in_bad_steps, int incr_every_n_steps, int decr_every_n_nan_or_inf, float incr_ratio, float decr_ratio, Scalar stop_update)
output : Tensor[](out){x.size()}, Tensor(loss_scaling), Tensor(out_good_steps), Tensor(out_bad_steps)
......
......@@ -1270,9 +1270,20 @@
attrs : [bool use_mkldnn = false, bool use_cudnn = false]
- op : squeeze (squeeze2)
backward : squeeze_grad (squeeze2_grad)
backward : squeeze_grad (squeeze2_grad), squeeze_double_grad(squeeze2_double_grad)
inputs :
x : X
attrs :
axis : axes
outputs :
{out : Out, xshape : XShape}
int_array:
axis :
data_type : int
support_tensor : true
extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32"]
outputs : [xshape]
- op : stack
backward : stack_grad
......@@ -1389,6 +1400,22 @@
outputs :
out : Y
- op : unsqueeze (unsqueeze2)
backward : unsqueeze_grad (unsqueeze2_grad), unsqueeze_double_grad(unsqueeze2_double_grad)
inputs :
x : X
attrs :
axis : axes
outputs :
{out : Out, xshape : XShape}
int_array:
axis :
data_type : int
tensor_name : AxesTensor
tensors_name : AxesTensorList
extra :
outputs : [xshape]
- op : unstack
backward : unstack_grad
inputs :
......
......@@ -1054,6 +1054,19 @@
square_sr {selected_rows -> selected_rows}
backward : square_grad
- op : squeeze
args : (Tensor x, IntArray axis={})
output : Tensor(out), Tensor(xshape)
infer_meta :
func : SqueezeWithXShapeInferMeta
kernel :
func : squeeze_with_xshape
data_type : x
inplace : (x -> out)
view: (x -> out)
intermediate : xshape
backward : squeeze_grad
- op : svd
args : (Tensor x, bool full_matrices = false)
output : Tensor(u), Tensor(s), Tensor(vh)
......@@ -1149,6 +1162,19 @@
func : unfold
backward : unfold_grad
- op : unsqueeze
args : (Tensor x, IntArray axis = {})
output : Tensor(out), Tensor(xshape)
infer_meta :
func : UnsqueezeWithXShapeInferMeta
kernel :
func : unsqueeze_with_xshape
data_type : x
inplace : (x -> out)
view: (x -> out)
intermediate : xshape
backward : unsqueeze_grad
- op : unstack
args : (Tensor x, int axis=0, int num=0)
output : Tensor[](out){num}
......
......@@ -917,9 +917,6 @@ void ExpandInferMeta(const MetaTensor& x,
auto out_rank =
std::max(static_cast<size_t>(x_dims.size()), expand_shape.size());
std::vector<int64_t> out_shape(out_rank);
auto x_dim_vec = phi::vectorize<int>(x_dims);
auto diff = expand_shape.size() - x_dim_vec.size();
x_dim_vec.insert(x_dim_vec.begin(), diff, -1);
for (size_t i = 0; i < expand_shape.size(); ++i) {
if (x_dims[i] == -1) {
out_shape[i] = -1;
......
......@@ -106,8 +106,7 @@ file(
"fusion/gpu/*.cu")
if(WITH_CUTLASS)
file(GLOB cutlass_cu "fusion/cutlass/default_moe_fc_traits.h"
"fusion/cutlass/linear_combination_ft_gelu.h" "fusion/cutlass/moe*")
file(GLOB cutlass_cu "fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu")
list(APPEND kernel_cu ${cutlass_cu})
endif()
......
......@@ -1023,15 +1023,20 @@ void BroadcastKernel(const KPDevice &ctx,
std::vector<DenseTensor *> *outs,
int axis,
Functor func) {
std::vector<int> dims_size;
dims_size.reserve(ins.size());
// When there are multiple inputs, the output's rank should be equal to the
// maximum rank of all inputs.
int max_rank = 0;
int min_rank = phi::DDim::kMaxRank;
for (auto *in : ins) {
dims_size.emplace_back(in->dims().size());
max_rank = std::max(max_rank, in->dims().size());
min_rank = std::min(min_rank, in->dims().size());
}
axis = axis == -1 ? *std::max_element(dims_size.begin(), dims_size.end()) -
*std::min_element(dims_size.begin(), dims_size.end())
: axis;
if (ins.size() == 1) {
// When there is only 1 input, the input's rank may be less than the output's
// rank.
max_rank = std::max(max_rank, (*outs)[0]->dims().size());
}
axis = axis == -1 ? max_rank - min_rank : axis;
BroadcastKernelForDifferentVecSize<ET, InT, OutT, Functor, NumOuts>(
ctx, ins, outs, axis, func);
}
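As a reading aid, here is a standalone sketch (not Paddle code) of the axis inference above: with axis == -1 the kernel uses max_rank - min_rank, so a lower-rank operand is aligned with the trailing dimensions of the broadcasted output; the single-input special case additionally folds the output's rank into max_rank. InferBroadcastAxis below is a made-up name.
#include <algorithm>
#include <cstdio>
#include <vector>

// Minimal sketch of the axis inference used above (ignoring the
// single-input special case).
int InferBroadcastAxis(const std::vector<int> &ranks, int axis) {
  int max_rank = *std::max_element(ranks.begin(), ranks.end());
  int min_rank = *std::min_element(ranks.begin(), ranks.end());
  return axis == -1 ? max_rank - min_rank : axis;
}

int main() {
  // x with shape [2, 3, 4, 5] and y with shape [5]: inferred axis is 4 - 1 = 3.
  std::printf("axis = %d\n", InferBroadcastAxis({4, 1}, -1));
  return 0;
}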
......
......@@ -25,8 +25,8 @@ struct BroadcastDimsSimplifier {
typedef void (*MergeFunctor)(
bool &, std::vector<DimVector> &, DimVector &, int, int);
int64_t N;
int64_t rank;
int N;
int rank;
DimVector out_dims;
std::vector<DimVector> in_dims;
......@@ -103,41 +103,43 @@ struct BroadcastDimsSimplifier {
// Pad input tensors whose rank is lower than the output's, aligning their dimensions starting at axis.
void ExtendInputDimensions(int N, int axis) {
for (auto &in_dim : in_dims) {
int64_t in_idx = 0;
if (in_dim.size() < rank) {
DimVector tmp_dim(rank, 1);
for (; in_idx < in_dim.size();) {
if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
tmp_dim[axis] = in_dim[in_idx];
in_idx++;
axis++;
DimVector extended_in_dim(rank, 1);
int out_idx = axis;
for (int in_idx = 0; in_idx < in_dim.size(); in_idx++) {
if (in_dim[in_idx] == out_dims[out_idx] || in_dim[in_idx] == 1) {
extended_in_dim[out_idx] = in_dim[in_idx];
out_idx++;
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The %d-th dimension of input tensor is expected to be equal "
"with the %d-th dimension of output tensor %d or 1, but "
"received %d.",
in_idx + 1,
axis + 1,
"received %d. The input's shape is {%s}, the output's shape is "
"{%s}.",
in_idx,
out_idx,
out_dims[axis],
in_dim[in_idx]));
in_dim[in_idx],
phi::make_ddim(in_dim),
phi::make_ddim(out_dims)));
}
}
in_dim.resize(rank);
std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
std::copy(
extended_in_dim.begin(), extended_in_dim.end(), in_dim.begin());
} else {
for (; in_idx < rank;) {
if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
in_idx++;
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The %d-th dimension of input tensor is expected to be equal "
"with the %d-th dimension of output tensor %d or 1, but "
"received %d.",
in_idx + 1,
in_idx + 1,
out_dims[in_idx],
in_dim[in_idx]));
}
for (int in_idx = 0; in_idx < rank; in_idx++) {
PADDLE_ENFORCE_EQ(
in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1,
true,
phi::errors::InvalidArgument(
"The %d-th dimension of input tensor is expected to be equal "
"with the %d-th dimension of output tensor %d or 1, but "
"received %d.",
in_idx,
in_idx,
out_dims[in_idx],
in_dim[in_idx]));
}
}
std::reverse(in_dim.begin(), in_dim.end());
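// Example of the extension above (before the trailing std::reverse): with
// out_dims = {2, 3, 4, 5}, an input of shape {3, 4} and axis = 1, the loop
// matches 3 and 4 against out_dims[1] and out_dims[2], so the input is
// extended to {1, 3, 4, 1}; a mismatch that is not 1 raises InvalidArgument.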
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasImpl(ConvAllParams params) {
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ThreadblockShape = TShape;
using WarpShape = WShape;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using SwizzleThreadBlock =
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
constexpr int NumStages = 2;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kOptimized;
using EpilogueOp =
cutlass::epilogue::thread::LinearCombination<ElementOutput,
Alignment,
float,
ElementComputeEpilogue>;
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
mode,
1);
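// Note: the TensorRef strides below are packed NHWC strides, {C, C*W, C*W*H}
// for the input/output and {C, C*KW, C*KW*KH} for the filter; the bias is
// passed with zero strides so the length-oc vector is broadcast over N*OH*OW.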
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f}};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
paddle::memory::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(
ConvAllParams);
// config 1
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(
ConvAllParams);
// config 2
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(
ConvAllParams);
// config 3
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(
ConvAllParams);
// config 4
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(
ConvAllParams);
// config 5
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(
ConvAllParams);
// config 6
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(
ConvAllParams);
// config 7
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(
ConvAllParams);
// config 8
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(
ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_all_func = {
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias;
std::mutex conv2d_bias_mutex;
void Conv2dBias(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias.count(problem_size)) {
conv2d_bias_all_func[map_problem_conv2d_bias.at(problem_size)](params);
return;
}
int best_config_index =
ProfileToGetBestConfig(conv2d_bias_all_func, params, CONV2D_BIAS);
std::lock_guard<std::mutex> guard(conv2d_bias_mutex);
map_problem_conv2d_bias[problem_size] = best_config_index;
conv2d_bias_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
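Each fused variant in these files follows the same dispatch pattern: look the problem size up in a per-op map, and only on a miss profile every template instantiation once and cache the winner. Below is a minimal sketch of that pattern with placeholder names (Problem, Kernel, Dispatch and the profile_best callback are not Paddle or CUTLASS symbols).
#include <functional>
#include <map>
#include <mutex>
#include <vector>

using Problem = std::vector<int>;                     // e.g. {n, ic, ih, iw, ...}
using Kernel = std::function<void(const Problem &)>;  // one templated config

// Profile-once-then-cache dispatch; unlike the code above, the lookup is also
// taken under the lock, purely to keep the sketch simple.
void Dispatch(const Problem &problem,
              const std::vector<Kernel> &kernels,
              std::map<Problem, int> *cache,
              std::mutex *mu,
              const std::function<int(const Problem &)> &profile_best) {
  std::unique_lock<std::mutex> guard(*mu);
  auto it = cache->find(problem);
  if (it == cache->end()) {
    guard.unlock();
    int best = profile_best(problem);  // time every candidate once
    guard.lock();
    it = cache->emplace(problem, best).first;
  }
  kernels[it->second](problem);
}
Keying the cache on the full problem size means the profiling cost is paid once per distinct shape, which suits inference workloads with a fixed set of shapes.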
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h"
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasAddReluImpl(ConvAllParams params) {
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock<
cutlass::half_t,
float,
float,
cutlass::half_t,
Alignment,
cutlass::epilogue::thread::Identity,
cutlass::plus,
cutlass::epilogue::thread::ReLu>;
using Conv2dFpropKernel =
typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast<
cutlass::half_t,
cutlass::layout::TensorNHWC,
cutlass::half_t,
cutlass::layout::TensorNHWC,
cutlass::half_t,
cutlass::layout::TensorNHWC,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
TShape,
WShape,
cutlass::gemm::GemmShape<16, 8, 8>,
EpilogueOp,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>,
2,
cutlass::arch::OpMultiplyAdd,
cutlass::conv::IteratorAlgorithm::kOptimized,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
const half *residual = params.residual;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Conv2dProblemSize problem_size(
{batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
cutlass::conv::Mode::kCrossCorrelation,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)input, {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)weight, {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)residual, {oc, oc * ow, oc * ow * oh}},
{(cutlass::half_t *)output, {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f},
cutlass::conv::SplitKMode::kSerial,
(cutlass::half_t *)(bias),
nullptr,
0,
oc};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
paddle::memory::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 1
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 2
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 3
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 4
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
// config 5
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
// config 6
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 7
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 8
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
// config 9
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 10
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 11
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<256, 64, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 12
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<256, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_add_relu_all_func = {
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<256, 64, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<256, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias_add_relu;
std::mutex conv2d_bias_add_relu_mutex;
void Conv2dBiasAddRelu(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias_add_relu.count(problem_size)) {
conv2d_bias_add_relu_all_func[map_problem_conv2d_bias_add_relu.at(
problem_size)](params);
return;
}
std::lock_guard<std::mutex> guard(conv2d_bias_add_relu_mutex);
// Config 6 shows a large numerical diff against the reference, so disable it.
conv2d_bias_add_relu_all_func[6] = nullptr;
int best_config_index = ProfileToGetBestConfig(
conv2d_bias_add_relu_all_func, params, CONV2D_BIAS_ADD_RELU);
map_problem_conv2d_bias_add_relu[problem_size] = best_config_index;
conv2d_bias_add_relu_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
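For reference, with the LinearCombinationResidualBlock epilogue instantiated above (Identity activation, plus as the binary op, ReLu as the unary op) and the {1.f, 1.f} scalars, the fused computation per output element is roughly out = ReLU(conv(input, weight) + bias + residual), with float accumulation and half-precision storage.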
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_leaky_relu.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasLeakyReluImpl(ConvAllParams params) {
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ThreadblockShape = TShape;
using WarpShape = WShape;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using SwizzleThreadBlock =
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
constexpr int NumStages = 2;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kOptimized;
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationLeakyRelu<
ElementOutput,
Alignment,
float,
ElementComputeEpilogue>;
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
float alpha = params.alpha;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
mode,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f, alpha}};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
paddle::memory::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 1
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 2
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 3
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 4
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
// config 5
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
// config 6
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 7
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 8
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_leaky_relu_all_func = {
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias_leaky_relu;
std::mutex conv2d_bias_leaky_relu_mutex;
void Conv2dBiasLeakyRelu(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias_leaky_relu.count(problem_size)) {
conv2d_bias_leaky_relu_all_func[map_problem_conv2d_bias_leaky_relu.at(
problem_size)](params);
return;
}
int best_config_index = ProfileToGetBestConfig(
conv2d_bias_leaky_relu_all_func, params, CONV2D_BIAS_LEAKY_RELU);
std::lock_guard<std::mutex> guard(conv2d_bias_leaky_relu_mutex);
map_problem_conv2d_bias_leaky_relu[problem_size] = best_config_index;
conv2d_bias_leaky_relu_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasReluImpl(ConvAllParams params) {
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ThreadblockShape = TShape;
using WarpShape = WShape;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using SwizzleThreadBlock =
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
constexpr int NumStages = 2;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kOptimized;
using EpilogueOp =
cutlass::epilogue::thread::LinearCombinationRelu<ElementOutput,
Alignment,
float,
ElementComputeEpilogue>;
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
mode,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f}};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
paddle::memory::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 1
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 2
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 3
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 4
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
// config 5
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
// config 6
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 7
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 8
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_relu_all_func = {
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias_relu;
std::mutex conv2d_bias_relu_mutex;
void Conv2dBiasRelu(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias_relu.count(problem_size)) {
conv2d_bias_relu_all_func[map_problem_conv2d_bias_relu.at(problem_size)](
params);
return;
}
int best_config_index = ProfileToGetBestConfig(
conv2d_bias_relu_all_func, params, CONV2D_BIAS_RELU);
std::lock_guard<std::mutex> guard(conv2d_bias_relu_mutex);
map_problem_conv2d_bias_relu[problem_size] = best_config_index;
conv2d_bias_relu_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 1>
cutlass::Status Conv2dBiasReluFewChannelsImpl(ConvAllParams params) {
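// Unlike the other variants, this one defaults to Alignment = 1 and selects
// IteratorAlgorithm::kFewChannels, which CUTLASS provides for convolutions
// whose input channel count is too small for the usual 8-element aligned
// vector accesses.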
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ThreadblockShape = TShape;
using WarpShape = WShape;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using SwizzleThreadBlock =
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
constexpr int NumStages = 2;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kFewChannels;
using EpilogueOp =
cutlass::epilogue::thread::LinearCombinationRelu<ElementOutput,
Alignment,
float,
ElementComputeEpilogue>;
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w1;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
mode,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f}};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
paddle::memory::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 1
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 2
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 3
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 4
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
// config 5
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
// config 6
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 7
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 8
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_relu_few_channels_all_func = {
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias_relu_few_channels;
void Conv2dBiasReluFewChannels(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w1;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias_relu_few_channels.count(problem_size)) {
conv2d_bias_relu_few_channels_all_func
[map_problem_conv2d_bias_relu_few_channels.at(problem_size)](params);
return;
}
//
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasSiluImpl(ConvAllParams params) {
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ThreadblockShape = TShape;
using WarpShape = WShape;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using SwizzleThreadBlock =
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
constexpr int NumStages = 2;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kOptimized;
using EpilogueOp =
cutlass::epilogue::thread::LinearCombinationSilu<ElementOutput,
Alignment,
float,
ElementComputeEpilogue>;
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
mode,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f}};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
paddle::memory::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 1
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 2
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 3
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 4
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
// config 5
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
// config 6
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 7
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 8
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_silu_all_func = {
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias_silu;
std::mutex conv2d_bias_silu_mutex;
void Conv2dBiasSilu(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias_silu.count(problem_size)) {
conv2d_bias_silu_all_func[map_problem_conv2d_bias_silu.at(problem_size)](
params);
return;
}
int best_config_index = ProfileToGetBestConfig(
conv2d_bias_silu_all_func, params, CONV2D_BIAS_SILU);
std::lock_guard<std::mutex> guard(conv2d_bias_silu_mutex);
map_problem_conv2d_bias_silu[problem_size] = best_config_index;
conv2d_bias_silu_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -12,36 +11,51 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_fp16.h>
#include <glog/logging.h>
#include <map>
#include <vector>
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.InputSize("AxesTensorList") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensorList";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"AxesTensorList"}, {"Out", "XShape"});
} else if (ctx.InputSize("AxesTensor") > 0) {
VLOG(2) << "unsqueeze2 in AxesTensor";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"AxesTensor"}, {"Out", "XShape"});
} else {
VLOG(2) << "unsqueeze2 in axes";
return KernelSignature(
"unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"});
}
}
typedef struct {
const half *input;
const half *weight;
const half *bias;
const half *residual;
half *output;
int batch;
int ic;
int ih;
int iw;
int kh;
int kw;
int oc;
int pad_h0;
int pad_h1;
int pad_w0;
int pad_w1;
int stride_h;
int stride_w;
int dilation_h;
int dilation_w;
int oh;
int ow;
const phi::GPUContext *ctx;
float alpha; // used by leaky_relu
} ConvAllParams;
KernelSignature UnsqueezeGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"unsqueeze_grad", {"XShape", "Out@GRAD"}, {}, {"X@GRAD"});
}
// The functions below are implemented with CUTLASS and are called from phi.
void Conv2dBiasAddRelu(ConvAllParams params);
void Conv2dBiasRelu(ConvAllParams params);
void Conv2dBiasLeakyRelu(ConvAllParams params);
void Conv2dBiasSilu(ConvAllParams params);
void Conv2dBias(ConvAllParams params);
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(unsqueeze2, unsqueeze);
PD_REGISTER_BASE_KERNEL_NAME(unsqueeze2_grad, unsqueeze_grad);
PD_REGISTER_ARG_MAPPING_FN(unsqueeze2, phi::UnsqueezeOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(unsqueeze2_grad,
phi::UnsqueezeGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
struct logical_coord {
int n;
int c;
int h;
int w;
};
float diff(const half *c, const float *c_baseline, int n) {
float max_diff = -1.;
for (int i = 0; i < n; i++) {
float c_value = __half2float(c[i]);
if (std::abs(c_baseline[i] - c_value) > max_diff) {
max_diff = std::abs(c_baseline[i] - c_value);
}
}
return max_diff;
}
__device__ int gpu_nhwc(struct logical_coord shape,
struct logical_coord index) {
return index.n * shape.h * shape.w * shape.c + index.h * shape.w * shape.c +
index.w * shape.c + index.c;
}
__global__ void naive_conv2d_kernel(const half *input,
const half *weight,
const half *bias,
float *output,
int batch,
int ic,
int ih,
int iw,
int kh,
int kw,
int oc,
int pad_h,
int pad_w,
int stride_h,
int stride_w,
int dilation_h,
int dilation_w,
int oh,
int ow,
const half *residual,
float alpha, // for leaky_relu
OpType op_type) {
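// Reference kernel used only for correctness checking: each thread computes
// one output element as a dot product over ic * kh * kw values (an implicit
// GEMM with M = batch * oh * ow, N = oc, K = ic * kh * kw), reading
// NHWC-packed data and accumulating in float.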
int M = batch * oh * ow;
int N = oc;
int K = ic * kh * kw;
int m_i = threadIdx.x + blockIdx.x * blockDim.x;
int n_i = threadIdx.y + blockIdx.y * blockDim.y;
if (m_i >= M || n_i >= N) return;
int batch_i = m_i / (oh * ow);
int oh_i = (m_i % (oh * ow)) / ow;
int ow_i = (m_i % (oh * ow)) % ow;
int oc_i = n_i;
struct logical_coord weight_shape = {oc, ic, kh, kw};
struct logical_coord input_shape = {batch, ic, ih, iw};
int out_offset = m_i * N + n_i;
float *out_ptr = output + out_offset;
float sum = 0.f;
for (int k_i = 0; k_i < K; k_i++) {
int ic_i = k_i / (kh * kw);
int kh_i = (k_i % (kh * kw)) / kw;
int kw_i = (k_i % (kh * kw)) % kw;
struct logical_coord weight_index = {oc_i, ic_i, kh_i, kw_i};
int ih_i = oh_i * stride_h - pad_h + kh_i * dilation_h;
int iw_i = ow_i * stride_w - pad_w + kw_i * dilation_w;
if (ih_i < 0 || ih_i >= ih) continue;
if (iw_i < 0 || iw_i >= iw) continue;
struct logical_coord input_index = {batch_i, ic_i, ih_i, iw_i};
const half *weight_ptr = weight + gpu_nhwc(weight_shape, weight_index);
const half *in_ptr = input + gpu_nhwc(input_shape, input_index);
sum += __half2float(*in_ptr) * __half2float(*weight_ptr);
}
sum += __half2float(*(bias + oc_i));
float x = sum;
switch (op_type) {
case CONV2D_BIAS:
*out_ptr = x;
break;
case CONV2D_BIAS_RELU:
*out_ptr = x > 0 ? x : 0;
break;
case CONV2D_BIAS_SILU:
*out_ptr = x * (1.f / (1 + exp(-x)));
break;
case CONV2D_BIAS_ADD_RELU:
x += __half2float(*(residual + out_offset));
*out_ptr = x > 0 ? x : 0;
break;
case CONV2D_BIAS_LEAKY_RELU:
*out_ptr = x > 0 ? x : (x * alpha);
break;
default:
break;
}
}
float conv2d_diff_gpu(ConvAllParams params, OpType op_type) {
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h = params.pad_h0;
int pad_w = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
const half *residual = params.residual;
int oh = params.oh;
int ow = params.ow;
int M = batch * oh * ow;
int N = oc;
constexpr int blockM = 16;
constexpr int blockN = 16;
uint3 grid = {(M + blockM - 1) / blockM, (N + blockN - 1) / blockN, 1};
uint3 block = {blockM, blockN, 1};
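  // Each thread computes one (m, n) output element; the grid tiles the M x N
  // output matrix with 16 x 16 thread blocks.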
int output_size = batch * oc * oh * ow;
half *output_from_cutlass =
reinterpret_cast<half *>(malloc(sizeof(half) * output_size));
cudaMemcpy(output_from_cutlass,
output,
output_size * sizeof(half),
cudaMemcpyDeviceToHost);
float *gpu_output;
cudaMalloc(&gpu_output, output_size * sizeof(float));
naive_conv2d_kernel<<<grid, block>>>(input,
weight,
bias,
gpu_output,
batch,
ic,
ih,
iw,
kh,
kw,
oc,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
oh,
ow,
residual,
params.alpha,
op_type);
float *output_from_gpu =
reinterpret_cast<float *>(malloc(sizeof(float) * output_size));
cudaMemcpy(output_from_gpu,
gpu_output,
output_size * sizeof(float),
cudaMemcpyDeviceToHost);
float max_diff = diff(output_from_cutlass, output_from_gpu, output_size);
free(output_from_cutlass);
free(output_from_gpu);
cudaFree(gpu_output);
return max_diff;
}
std::string OpType2String(OpType op_type) {
switch (op_type) {
case CONV2D_BIAS:
return "conv2d_bias";
break;
case CONV2D_BIAS_RELU:
return "conv2d_bias_relu";
break;
case CONV2D_BIAS_SILU:
return "conv2d_bias_add_silu";
break;
case CONV2D_BIAS_ADD_RELU:
return "conv2d_bias_add_relu";
break;
case CONV2D_BIAS_LEAKY_RELU:
return "conv2d_bias_leaky_relu";
default:
break;
}
return "unnamed_op";
}
int ProfileToGetBestConfig(
const std::vector<std::function<cutlass::Status(ConvAllParams)>> &all_func,
ConvAllParams params,
OpType op_type) {
constexpr int WARMUP = 10;
constexpr int REPEAT = 100;
float min_time = 100000.f;
int min_time_index = -1;
for (int i = 0; i < all_func.size(); i++) {
cutlass::Status status;
auto func = all_func[i];
    // Funcs that produced a large diff have been set to nullptr; skip them.
if (!func) continue;
for (int ii = 0; ii < WARMUP; ii++) {
status = func(params);
}
cudaEvent_t beg, end;
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&beg));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&end));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(beg));
for (int ii = 0; ii < REPEAT; ii++) {
status = func(params);
}
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(end));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(end));
float elapsed_time;
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventElapsedTime(&elapsed_time, beg, end));
if (elapsed_time < min_time && status == cutlass::Status::kSuccess) {
min_time = elapsed_time;
min_time_index = i;
}
// debug code
VLOG(3) << OpType2String(op_type) << ": tactic " << i << " has max diff "
<< conv2d_diff_gpu(params, op_type) << " compared with baseline.";
}
if (min_time_index < 0) {
PADDLE_THROW(
phi::errors::NotFound("Can't find any cutlass config for this %s op.",
OpType2String(op_type).c_str()));
}
return min_time_index;
}
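// Typical call pattern (a sketch of how the generated conv2d entry points are
// expected to use this helper): build the candidate kernel list once, profile
// it for the given problem size, then launch the winner, e.g.
//
//   int best = ProfileToGetBestConfig(all_func, params, CONV2D_BIAS_RELU);
//   all_func[best](params);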
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_fp16.h>
#include <vector>
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/enforce.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
#define CUTLASS_CHECK(status) \
if (status != cutlass::Status::kSuccess) { \
    VLOG(3)                                                                    \
        << "Cutlass cannot handle this problem size; skipping this kernel!";  \
return status; \
}
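// CUTLASS_CHECK is intended for the kernel wrappers that return
// cutlass::Status. An illustrative sketch (ImplicitGemm stands for a concrete
// cutlass::conv::device::ImplicitGemmConvolution instantiation assumed by this
// example):
//
//   ImplicitGemm op;
//   cutlass::Status status = op.can_implement(arguments);
//   CUTLASS_CHECK(status);
//   status = op.initialize(arguments, workspace);
//   CUTLASS_CHECK(status);
//   status = op();
//   CUTLASS_CHECK(status);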
typedef enum {
CONV2D_BIAS,
CONV2D_BIAS_RELU,
CONV2D_BIAS_ADD_RELU,
CONV2D_BIAS_SILU,
CONV2D_BIAS_LEAKY_RELU
} OpType;
// conv2d_diff_gpu calculates the diff between the cutlass output and the
// baseline output; you can use it for debugging. The return value is the max
// diff between cutlass and the baseline.
float conv2d_diff_gpu(ConvAllParams params, OpType op_type);
int ProfileToGetBestConfig(
const std::vector<std::function<cutlass::Status(ConvAllParams)>>& all_func,
ConvAllParams params,
OpType op_type);
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename T, typename Context>
void Conv2dFusionKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& filter,
const DenseTensor& bias,
const paddle::optional<DenseTensor>& residual,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
const std::string& activation,
float fuse_alpha,
DenseTensor* output) {
ctx.template Alloc<T>(output);
auto in_dims = x.dims();
auto filter_dims = filter.dims();
auto out_dims = output->dims();
CHECK_EQ(in_dims.size() == 4UL, true);
CHECK_EQ(filter_dims.size() == 4UL, true);
CHECK_EQ(strides.size() == 2UL, true);
CHECK_EQ(dilations.size() == 2UL, true);
CHECK_EQ(groups == 1, true);
CHECK_EQ(padding_algorithm == "EXPLICIT", true);
const int batch = in_dims[0];
const int ic = in_dims[3];
const int ih = in_dims[1];
const int iw = in_dims[2];
int pad_h0 = 0;
int pad_h1 = 0;
int pad_w0 = 0;
int pad_w1 = 0;
if (paddings.size() == 2UL) {
pad_h0 = paddings[0];
pad_h1 = paddings[0];
pad_w0 = paddings[1];
pad_w1 = paddings[1];
} else if (paddings.size() == 4UL) {
pad_h0 = paddings[0];
pad_h1 = paddings[1];
pad_w0 = paddings[2];
pad_w1 = paddings[3];
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Attr paddins in conv2d_fusion must have 2 or 4 elements, but now have "
"%u elements.",
paddings.size()));
}
const int stride_h = strides[0];
const int stride_w = strides[1];
const int dilation_h = dilations[0];
const int dilation_w = dilations[1];
const int oc = filter_dims[0];
const int kh = filter_dims[1];
const int kw = filter_dims[2];
CHECK_EQ(out_dims.size() == 4UL, true);
const int oh = out_dims[1];
const int ow = out_dims[2];
ConvAllParams params = {reinterpret_cast<const half*>(x.data<T>()),
reinterpret_cast<const half*>(filter.data<T>()),
reinterpret_cast<const half*>(bias.data<T>()),
nullptr,
reinterpret_cast<half*>(output->data<T>()),
batch,
ic,
ih,
iw,
kh,
kw,
oc,
pad_h0,
pad_h1,
pad_w0,
pad_w1,
stride_h,
stride_w,
dilation_h,
dilation_w,
oh,
ow,
&ctx};
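  // params.residual stays nullptr and params.alpha keeps its default value;
  // the branches below fill them in only for the residual
  // (conv + bias + add + relu) and leaky_relu cases, respectively.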
if (residual) {
if (activation == "relu") {
params.residual = reinterpret_cast<const half*>(residual->data<T>());
Conv2dBiasAddRelu(params);
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Cutlass now only support relu activation in a residual block"));
}
} else if (activation == "relu") {
Conv2dBiasRelu(params);
} else if (activation == "swish") {
Conv2dBiasSilu(params);
} else if (activation == "identity") {
Conv2dBias(params);
} else if (activation == "leaky_relu") {
params.alpha = fuse_alpha;
Conv2dBiasLeakyRelu(params);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Cutlass does not support this activation: %s.", activation.c_str()));
}
output->set_layout(DataLayout::NHWC);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(conv2d_fusion_cutlass,
GPU,
ALL_LAYOUT,
phi::fusion::cutlass_internal::Conv2dFusionKernel,
float,
phi::dtype::float16) {}
......@@ -17,7 +17,28 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/expand_grad_kernel_impl.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
namespace phi {
template <typename T, typename Context>
void ExpandGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const IntArray& shape,
DenseTensor* x_grad) {
ctx.template Alloc<T>(x_grad);
if (x_grad->dims() == out_grad.dims()) {
phi::Copy(ctx, out_grad, ctx.GetPlace(), false, x_grad);
} else {
std::vector<int> reduce_dims =
funcs::GetReduceDim(x_grad->dims(), out_grad.dims(), -1);
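    // e.g. for x of shape [3, 1] expanded to an out_grad of shape [2, 3, 4],
    // the broadcast axes are summed out here and reduce_dims would be {0, 2}.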
funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
ctx, out_grad, x_grad, kps::IdentityFunctor<T>(), reduce_dims);
}
}
} // namespace phi
PD_REGISTER_KERNEL(expand_grad,
GPU,
......@@ -26,5 +47,6 @@ PD_REGISTER_KERNEL(expand_grad,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {}
......@@ -18,7 +18,66 @@
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/expand_kernel_impl.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
namespace phi {
template <typename T, typename Context>
void ExpandKernel(const Context& ctx,
const DenseTensor& x,
const IntArray& shape,
DenseTensor* out) {
auto expand_shape = shape.GetData();
auto diff = expand_shape.size() - x.dims().size();
auto out_shape = phi::vectorize<int64_t>(x.dims());
out_shape.insert(out_shape.begin(), diff, 1);
for (size_t i = 0; i < out_shape.size(); ++i) {
PADDLE_ENFORCE_NE(
expand_shape[i],
0,
phi::errors::InvalidArgument("The expanded size cannot be zero."));
if (i < diff) {
PADDLE_ENFORCE_GT(
expand_shape[i],
0,
phi::errors::InvalidArgument(
"The expanded size (%d) for non-existing dimensions must be "
"positive for expand kernel.",
expand_shape[i]));
out_shape[i] = expand_shape[i];
} else if (expand_shape[i] > 0) {
if (out_shape[i] != 1) {
PADDLE_ENFORCE_EQ(
out_shape[i],
expand_shape[i],
phi::errors::InvalidArgument(
"The value (%d) of the non-singleton dimension does not match"
" the corresponding value (%d) in shape for expand kernel.",
out_shape[i],
expand_shape[i]));
} else {
out_shape[i] = expand_shape[i];
}
} else {
PADDLE_ENFORCE_EQ(
expand_shape[i],
-1,
phi::errors::InvalidArgument(
"When the value in shape is negative for expand_v2 op, "
"only -1 is supported, but the value received is %d.",
expand_shape[i]));
}
}
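  // Worked example: x.dims() = [3, 1] with shape = [2, 3, 4] gives diff = 1;
  // out_shape starts as [1, 3, 1] and becomes [2, 3, 4] after the loop above.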
out->Resize(phi::make_ddim(out_shape));
ctx.template Alloc<T>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
phi::funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
ctx, ins, &outs, -1, kps::IdentityFunctor<T>());
}
} // namespace phi
PD_REGISTER_KERNEL(expand,
GPU,
......@@ -27,6 +86,7 @@ PD_REGISTER_KERNEL(expand,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t,
bool) {}
......@@ -101,6 +101,9 @@ void FlipKernel(const Context& dev_ctx,
DenseTensor* out) {
const size_t total_dims = x.dims().size();
switch (total_dims) {
case 0:
LaunchFlipCudaKernel<T, Context, 0>(dev_ctx, x, axis, out);
break;
case 1:
LaunchFlipCudaKernel<T, Context, 1>(dev_ctx, x, axis, out);
break;
......
......@@ -53,9 +53,24 @@ KernelSignature Conv2dDoubleGradOpArgumentMapping(
{"DInput", "DFilter", "DDOutput"});
}
KernelSignature Conv2dFusionArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d_fusion_cutlass",
{"Input", "Filter", "Bias", "ResidualData"},
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format",
"activation",
"fuse_alpha"},
{"Output"});
}
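// The attribute names listed above are forwarded, in this order, to the
// strides / paddings / padding_algorithm / groups / dilations / data_format /
// activation / fuse_alpha parameters of the conv2d_fusion_cutlass kernel
// registered earlier in this change.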
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(conv2d, phi::Conv2dOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv2d_fusion_cutlass,
phi::Conv2dFusionArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv2d_grad, phi::Conv2dGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv2d_grad_grad,
phi::Conv2dDoubleGradOpArgumentMapping);
......@@ -1843,6 +1843,11 @@ function precise_card_test_single {
for case in $(echo $testcases | tr "$|^" "\n" | awk '!/^$/')
do
cd ${PADDLE_ROOT}/build
find paddle/fluid -name *.gcda | xargs rm -f
find paddle/phi -name *.gcda | xargs rm -f
find paddle/utils -name *.gcda | xargs rm -f
precise_card_test "^${case}$" $num
        # if a test fails, continue; if it succeeds, go on
......@@ -1876,9 +1881,6 @@ function precise_card_test_single {
fi
mv python-coverage.data.* ${PADDLE_ROOT}/build/pytest/$case
fi
find paddle/fluid -name *.gcda | xargs rm -f
find paddle/phi -name *.gcda | xargs rm -f
find paddle/utils -name *.gcda | xargs rm -f
done
}
......@@ -1988,6 +1990,10 @@ set +x
fi
read testcase <<< $(echo "$line"|grep -oEi "\w+$")
if [[ "$testcase" == "simple_precision_test" ]]; then
continue
fi
if [[ "$is_multicard" == "" ]]; then
            # trick: treat all test cases with the prefix "test_dist" as distributed cases, which will run on 2 GPUs
read is_multicard <<< $(echo "$testcase"|grep -oEi "test_dist_")
......@@ -2032,6 +2038,8 @@ set -x
mkdir -p ${PADDLE_ROOT}/build/ut_map
mkdir -p ${PADDLE_ROOT}/build/pytest
    # run all unit tests to get the coverage information of .c and .h files
precise_card_test_single "^simple_precision_test$" 1
wait;
precise_card_test_single "$single_card_tests" 1
precise_card_test_single "$single_card_tests_1" 1
precise_card_test_single "$multiple_card_tests" 2
......
......@@ -20,11 +20,11 @@ __all__ = []
import paddle
from paddle.common_ops_import import LayerHelper
from paddle.fluid.clip import GradientClipByNorm, append_gradient_clip_ops
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.optimizer import Momentum, Optimizer
from paddle.framework import core
from paddle.nn.clip import ClipGradByNorm, append_gradient_clip_ops
from paddle.static import create_global_var
......@@ -76,9 +76,9 @@ class DGCMomentumOptimizer(Optimizer):
self._dgc_clip_norm = None
if grad_clip is not None:
if not isinstance(grad_clip, GradientClipByNorm):
if not isinstance(grad_clip, ClipGradByNorm):
raise TypeError(
"The type of grad_clip should be 'GradientClipByNorm', because DGCMomentumOptimizer only support GradientClipByNorm"
"The type of grad_clip should be 'ClipGradByNorm', because DGCMomentumOptimizer only support ClipGradByNorm"
)
assert isinstance(num_trainers, int), (
"The type of num_trainers should be 'int', but received %s"
......
......@@ -15,9 +15,8 @@
import paddle
from paddle import framework
from paddle.autograd import no_grad
from paddle.fluid import layers
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.framework import core
from paddle.nn import ClipGradByGlobalNorm, clip
from ...base.topology import ParallelMode
from ...utils.hybrid_parallel_util import (
......@@ -62,8 +61,8 @@ class HybridParallelClipGrad:
continue
merge_grad = g
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
merge_grad = clip.merge_selected_rows(g)
merge_grad = clip.get_tensor_from_selected_rows(merge_grad)
square = paddle.square(merge_grad)
sum_square = paddle.sum(square)
......
......@@ -30,7 +30,7 @@ import paddle
import paddle.distributed as dist
from paddle.distributed import ParallelMode, fleet
from paddle.fluid import core
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.nn import ClipGradByGlobalNorm
from paddle.optimizer import Optimizer
HybridParallelClipGrad = (
......
......@@ -25,8 +25,8 @@ import paddle.fluid.framework as framework
from paddle import nn
from paddle.autograd import PyLayer
from paddle.distributed import collective
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.fluid.framework import EagerParamBase
from paddle.nn import ClipGradByGlobalNorm
from .group_sharded_storage import GradStorage
from .group_sharded_utils import GroupShardedClipGrad, Type, device_guard
......
......@@ -23,6 +23,7 @@ from paddle import _legacy_C_ops
from paddle.fluid import core, layers
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import dygraph_only
from paddle.nn import clip
class Taskflow:
......@@ -65,8 +66,8 @@ class GroupShardedClipGrad:
merge_grad = g
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.get_tensor_from_selected_rows(
layers.merge_selected_rows(g)
merge_grad = clip.get_tensor_from_selected_rows(
clip.merge_selected_rows(g)
)
square = paddle.square(merge_grad)
sum_square = paddle.sum(square)
......
......@@ -159,7 +159,7 @@ def auc(stat_pos, stat_neg, scope=None, util=None):
.. code-block:: python
# in model.py
similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(output, min=-15.0, max=15.0))
similarity_norm = fluid.layers.sigmoid(paddle.clip(output, min=-15.0, max=15.0))
binary_predict = fluid.layers.concat(
input=[paddle.subtract(fluid.layers.ceil(similarity_norm), similarity_norm), similarity_norm], axis=1)
self.auc, batch_auc, [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg] =
......
......@@ -90,7 +90,6 @@ from .transpiler import (
DistributeTranspilerConfig,
)
from .lod_tensor import create_lod_tensor, create_random_int_lodtensor
from . import clip
from . import profiler
from . import unique_name
from . import parallel_executor
......@@ -99,7 +98,6 @@ from . import compiler
from .compiler import *
from paddle.fluid.layers.math_op_patch import monkey_patch_variable
from . import install_check
from .dygraph.nn import *
from .dygraph.layers import *
from .dygraph.base import enable_dygraph, disable_dygraph
from .io import save, load, load_program_state, set_program_state
......@@ -165,7 +163,6 @@ __all__ = (
'ParamAttr',
'WeightNormParamAttr',
'DataFeeder',
'clip',
'profiler',
'unique_name',
'Scope',
......
This diff has been collapsed.
......@@ -21,9 +21,6 @@ from .layers import *
from . import container
from .container import *
from . import nn
from .nn import *
from . import tracer
from .tracer import *
......@@ -45,7 +42,6 @@ __all__ = []
__all__ += layers.__all__
__all__ += base.__all__
__all__ += container.__all__
__all__ += nn.__all__
__all__ += parallel.__all__
__all__ += checkpoint.__all__
__all__ += learning_rate_scheduler.__all__
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from .. import core
from ..layers import utils
from ..layers import nn as F
from .. import dygraph_utils
from . import layers
from ..framework import (
Variable,
OpProtoHolder,
Parameter,
_dygraph_tracer,
_varbase_creator,
default_main_program,
_global_flags,
in_dygraph_mode,
)
from ..data_feeder import (
convert_dtype,
check_variable_and_dtype,
check_type,
check_dtype,
)
from ..param_attr import ParamAttr
from ..initializer import Normal, Constant, NumpyArrayInitializer
from .. import unique_name
from .layer_object_helper import LayerObjectHelper
from ..data_feeder import check_variable_and_dtype, check_type
import numpy as np
import numbers
import logging
import os
import paddle.utils.deprecated as deprecated
from paddle import _C_ops, _legacy_C_ops
__all__ = []
class BatchNorm(layers.Layer):
r"""
This interface is used to construct a callable object of the ``BatchNorm`` class.
For more details, refer to code examples.
It implements the function of the Batch Normalization Layer and can be used
as a normalizer function for conv2d and fully connected operations.
The data is normalized by the mean and variance of the channel based on the current batch data.
Refer to `Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift <https://arxiv.org/pdf/1502.03167.pdf>`_
for more details.
When use_global_stats = False, the :math:`\mu_{\beta}`
and :math:`\sigma_{\beta}^{2}` are the statistics of one mini-batch.
Calculated as follows:
.. math::
\mu_{\beta} &\gets \frac{1}{m} \sum_{i=1}^{m} x_i \qquad &
//\ mini-batch\ mean \\
\sigma_{\beta}^{2} &\gets \frac{1}{m} \sum_{i=1}^{m}(x_i - \mu_{\beta})^2 \qquad &
//\ mini-batch\ variance \\
- :math:`x` : mini-batch data
- :math:`m` : the size of the mini-batch data
    When use_global_stats = True, the :math:`\mu_{\beta}`
    and :math:`\sigma_{\beta}^{2}` are not the statistics of one mini-batch.
    They are global or running statistics (moving_mean and moving_variance),
    usually obtained from a pre-trained model. Calculated as follows:
.. math::
moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global mean \\
moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global variance \\
The normalization function formula is as follows:
.. math::
\hat{x_i} &\gets \frac{x_i - \mu_\beta} {\sqrt{\
\sigma_{\beta}^{2} + \epsilon}} \qquad &//\ normalize \\
y_i &\gets \gamma \hat{x_i} + \beta \qquad &//\ scale\ and\ shift
- :math:`\epsilon` : add a smaller value to the variance to prevent division by zero
- :math:`\gamma` : trainable proportional parameter
- :math:`\beta` : trainable deviation parameter
Parameters:
num_channels(int): Indicate the number of channels of the input ``Tensor``.
act(str, optional): Activation to be applied to the output of batch normalization. Default: None.
        is_test (bool, optional): A flag indicating whether it is in the test phase or not.
This flag only has effect on static graph mode. For dygraph mode, please use ``eval()``.
Default: False.
momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
param_attr(ParamAttr, optional): The parameter attribute for Parameter `scale`
of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr(ParamAttr, optional): The parameter attribute for the bias of batch_norm.
If it is set to None or one attribute of ParamAttr, batch_norm
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
dtype(str, optional): Indicate the data type of the input ``Tensor``,
which can be float32 or float64. Default: float32.
data_layout(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC", where `N` is batch size, `C` is the number of the feature map, `H` is the height of the feature map, `W` is the width of the feature map. Default: NCHW.
in_place(bool, optional): Make the input and output of batch norm reuse memory. Default: False.
        moving_mean_name(str, optional): The name of moving_mean, which stores the global mean. Default: None.
        moving_variance_name(str, optional): The name of moving_variance, which stores the global variance. Default: None.
do_model_average_for_mean_and_var(bool, optional): Whether parameter mean and variance should do model
average when model average is enabled. Default: True.
use_global_stats(bool, optional): Whether to use global mean and
variance. In inference or test mode, set use_global_stats to true
or is_test to true, and the behavior is equivalent.
            In train mode, when use_global_stats is set to True, the global mean
            and variance are also used during training. Default: False.
        trainable_statistics(bool, optional): Whether to calculate the mean and variance in eval mode. In eval mode, when
            trainable_statistics is set to True, the mean and variance will be calculated from the current batch statistics.
Default: False.
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
x = paddle.rand([3, 10, 3, 7], 'float32')
with fluid.dygraph.guard():
x = to_variable(x)
batch_norm = fluid.BatchNorm(10)
hidden1 = batch_norm(x)
"""
def __init__(
self,
num_channels,
act=None,
is_test=False,
momentum=0.9,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
dtype='float32',
data_layout='NCHW',
in_place=False,
moving_mean_name=None,
moving_variance_name=None,
do_model_average_for_mean_and_var=True,
use_global_stats=False,
trainable_statistics=False,
):
super().__init__()
self._param_attr = param_attr
self._bias_attr = bias_attr
self._act = act
self._use_mkldnn = _global_flags()["FLAGS_use_mkldnn"]
assert (
bias_attr is not False
), "bias_attr should not be False in batch_norm."
if dtype == "float16":
self._dtype = "float32"
else:
self._dtype = dtype
param_shape = [num_channels]
# create parameter
self.weight = self.create_parameter(
attr=self._param_attr,
shape=param_shape,
dtype=self._dtype,
default_initializer=Constant(1.0),
)
self.weight.stop_gradient = (
use_global_stats and self._param_attr.learning_rate == 0.0
)
self.bias = self.create_parameter(
attr=self._bias_attr,
shape=param_shape,
dtype=self._dtype,
is_bias=True,
)
self.bias.stop_gradient = (
use_global_stats and self._param_attr.learning_rate == 0.0
)
self._mean = self.create_parameter(
attr=ParamAttr(
name=moving_mean_name,
initializer=Constant(0.0),
trainable=False,
do_model_average=do_model_average_for_mean_and_var,
),
shape=param_shape,
dtype=self._dtype,
)
self._mean.stop_gradient = True
self._variance = self.create_parameter(
attr=ParamAttr(
name=moving_variance_name,
initializer=Constant(1.0),
trainable=False,
do_model_average=do_model_average_for_mean_and_var,
),
shape=param_shape,
dtype=self._dtype,
)
self._variance.stop_gradient = True
self._in_place = in_place
self._data_layout = data_layout
self._momentum = momentum
self._epsilon = epsilon
self._is_test = is_test
self._fuse_with_relu = False
self._use_global_stats = use_global_stats
self._trainable_statistics = trainable_statistics
def forward(self, input):
# create output
# mean and mean_out share the same memory
mean_out = self._mean
# variance and variance out share the same memory
variance_out = self._variance
if in_dygraph_mode():
batch_norm_out, t1, t2, t3, t4, _ = _C_ops.batch_norm(
input,
self._mean,
self._variance,
self.weight,
self.bias,
not self.training,
self._momentum,
self._epsilon,
self._data_layout,
self._use_global_stats,
self._trainable_statistics,
)
return dygraph_utils._append_activation_in_dygraph(
batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
)
else:
check_variable_and_dtype(
input, 'input', ['float16', 'float32', 'float64'], 'BatchNorm'
)
attrs = {
"momentum": self._momentum,
"epsilon": self._epsilon,
"is_test": self._is_test,
"data_layout": self._data_layout,
"use_mkldnn": False,
"fuse_with_relu": self._fuse_with_relu,
"use_global_stats": self._use_global_stats,
"trainable_statistics": self._trainable_statistics,
}
inputs = {
"X": [input],
"Scale": [self.weight],
"Bias": [self.bias],
"Mean": [self._mean],
"Variance": [self._variance],
}
saved_mean = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True
)
saved_variance = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True
)
reserve_space = self._helper.create_variable_for_type_inference(
dtype=self._helper.input_dtype(input), stop_gradient=True
)
batch_norm_out = (
input
if self._in_place
else self._helper.create_variable_for_type_inference(
self._dtype
)
)
outputs = {
"Y": [batch_norm_out],
"MeanOut": [mean_out],
"VarianceOut": [variance_out],
"SavedMean": [saved_mean],
"SavedVariance": [saved_variance],
}
if reserve_space is not None:
outputs["ReserveSpace"] = [reserve_space]
self._helper.append_op(
type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs
)
# Currently, we don't support inplace in dygraph mode
return self._helper.append_activation(batch_norm_out, self._act)
......@@ -185,7 +185,7 @@ class FleetUtil:
# below is part of model
emb = my_slot_net(slots, label) # emb can be fc layer of size 1
similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(\
similarity_norm = fluid.layers.sigmoid(paddle.clip(\
emb, min=-15.0, max=15.0), name="similarity_norm")\
binary_predict = fluid.layers.concat(input=[\
paddle.subtract(\
......@@ -1374,7 +1374,7 @@ class FleetUtil:
label = fluid.layers.data(name="click", shape=[-1, 1],\
dtype="int64", lod_level=0, append_batch_size=False)
emb = my_slot_net(slots, label) # emb can be fc layer of size 1
similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(\
similarity_norm = fluid.layers.sigmoid(paddle.clip(\
emb, min=-15.0, max=15.0), name="similarity_norm")\
binary_predict = fluid.layers.concat(input=[\
paddle.subtract(\
......@@ -1574,7 +1574,7 @@ class FleetUtil:
label = fluid.layers.data(name="click", shape=[-1, 1],\
dtype="int64", lod_level=0, append_batch_size=False)
emb = my_slot_net(slots, label) # emb can be fc layer of size 1
similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(\
similarity_norm = fluid.layers.sigmoid(paddle.clip(\
emb, min=-15.0, max=15.0), name="similarity_norm")\
binary_predict = fluid.layers.concat(input=[\
paddle.subtract(\
......
......@@ -25,7 +25,7 @@ from .param_attr import ParamAttr
from .initializer import Constant
from . import layers
from . import backward
from .dygraph import Layer, nn
from .dygraph import Layer
from . import executor
from . import optimizer
from . import core
......
......@@ -63,10 +63,6 @@ __all__ = [
'fc',
'embedding',
'autoincreased_step_counter',
'clip',
'clip_by_norm',
'merge_selected_rows',
'get_tensor_from_selected_rows',
]
OP_NAMEMAPPING = {
......@@ -997,199 +993,3 @@ def _logical_op(op_name, x, y, out=None, name=None, binary_op=True):
)
return out
@templatedoc()
def clip(x, min, max, name=None):
"""
:old_api: paddle.fluid.layers.clip
${comment}
Args:
x(${x_type}): ${x_comment}
min(float): ${min_comment}
max(float): ${max_comment}
name(str, optional): The default value is None.
                             Normally there is no need for the user to set this property.
For more information, please refer to :ref:`api_guide_Name`
Returns:
${out_comment}
Return Type:
${out_type}
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.data(
name='data', shape=[1], dtype='float32')
reward = fluid.layers.clip(x=input, min=-1.0, max=1.0)
"""
helper = LayerHelper("clip", **locals())
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'clip')
if name is None:
name = unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])
)
out = helper.create_variable(
type=x.type, name=name, dtype=x.dtype, persistable=False
)
helper.append_op(
type="clip",
inputs={"X": x},
attrs={"min": min, "max": max},
outputs={"Out": out},
)
return out
@templatedoc()
def clip_by_norm(x, max_norm, name=None):
"""
${comment}
Args:
x(${x_type}): ${x_comment}
max_norm(${max_norm_type}): ${max_norm_comment}
name(str, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually this parameter does not need to be
            set, and it is None by default.
Returns:
Tensor:
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
input = paddle.to_tensor([[2.0, 2.0], [2.0, 2.0]], dtype='float32')
reward = fluid.layers.clip_by_norm(x=input, max_norm=1.0)
# [[0.5, 0.5], [0.5, 0.5]]
"""
if in_dygraph_mode():
return _C_ops.clip_by_norm(x, max_norm)
else:
helper = LayerHelper("clip_by_norm", **locals())
check_variable_and_dtype(x, 'X', ['float32', 'float16'], 'clip_by_norm')
check_type(max_norm, 'max_norm', (float), 'clip_by_norm')
if name is None:
name = unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])
)
out = helper.create_variable(
type=x.type, name=name, dtype=x.dtype, persistable=False
)
helper.append_op(
type="clip_by_norm",
inputs={"X": x},
attrs={"max_norm": max_norm},
outputs={"Out": out},
)
return out
@templatedoc()
def merge_selected_rows(x, name=None):
"""
${comment}
Args:
x(${x_type}): ${x_comment}
name(basestring|None): Name of the output.
Returns:
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
import paddle.fluid as fluid
b = fluid.default_main_program().global_block()
var = b.create_var(
name="X", dtype="float32", persistable=True,
type=fluid.core.VarDesc.VarType.SELECTED_ROWS)
y = fluid.layers.merge_selected_rows(var)
"""
if in_dygraph_mode():
return _C_ops.merge_selected_rows(x)
else:
helper = LayerHelper("merge_selected_rows", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="merge_selected_rows",
inputs={"X": x},
attrs={},
outputs={"Out": out},
)
return out
@templatedoc()
def get_tensor_from_selected_rows(x, name=None):
"""
This operator gets tensor data from input with SelectedRows type, and outputs a LoDTensor.
.. code-block:: text
input x is SelectedRows:
x.rows = [0, 5, 5, 4, 19]
x.height = 20
x.value = [[1, 1] [2, 2] [2, 2] [3, 3] [6, 6]]
Output is LoDTensor:
out.shape = [5, 2]
out.data = [[1, 1],
[2, 2],
[2, 2],
[3, 3],
[6, 6]]
Args:
x(SelectedRows): Input with SelectedRows type. The data type is float32, float64, int32 or int64.
name(str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
Variable: LoDTensor transformed from SelectedRows. The data type is same with input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
b = fluid.default_main_program().global_block()
input = b.create_var(name="X", dtype="float32", persistable=True, type=fluid.core.VarDesc.VarType.SELECTED_ROWS)
out = fluid.layers.get_tensor_from_selected_rows(input)
"""
check_type(x, 'x', Variable, 'get_tensor_from_selected_rows')
if x.type != core.VarDesc.VarType.SELECTED_ROWS:
raise TypeError(
"The type of 'x' in get_tensor_from_selected_rows must be SELECTED_ROWS."
)
helper = LayerHelper('get_tensor_from_selected_rows', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='get_tensor_from_selected_rows',
inputs={'X': x},
outputs={'Out': out},
attrs={},
)
return out
......@@ -38,13 +38,6 @@ from .backward import (
_append_grad_suffix_,
_get_no_grad_set_name,
)
from .clip import (
GradientClipBase,
GradientClipByNorm,
error_clip_callback,
append_gradient_clip_ops,
ClipGradByGlobalNorm,
)
from .framework import program_guard
from .initializer import Constant
from .layer_helper import LayerHelper
......@@ -160,7 +153,7 @@ class Optimizer:
)
if grad_clip is not None:
if not isinstance(grad_clip, GradientClipBase):
if not isinstance(grad_clip, paddle.nn.clip.GradientClipBase):
raise TypeError(
"'grad_clip' should be an instance of GradientClipBase's derived class"
)
......@@ -1030,7 +1023,7 @@ class Optimizer:
params_grads.append((param, grad_var))
else:
if callbacks is None:
callbacks = [error_clip_callback]
callbacks = [paddle.nn.clip.error_clip_callback]
else:
assert isinstance(callbacks, list)
program = loss.block.program
......@@ -1260,7 +1253,7 @@ class Optimizer:
# NOTE(zhiqiu): currently, only support ClipGradByGlobalNorm and without regularization.
if self._flatten_param_grads and self.regularization is None:
if self._grad_clip is None or isinstance(
self._grad_clip, ClipGradByGlobalNorm
self._grad_clip, paddle.nn.ClipGradByGlobalNorm
):
params_grads = self.flatten_param_grads(params_grads)
......@@ -1268,7 +1261,7 @@ class Optimizer:
if self._grad_clip is not None:
params_grads = self._grad_clip(params_grads)
else:
params_grads = append_gradient_clip_ops(params_grads)
params_grads = paddle.nn.clip.append_gradient_clip_ops(params_grads)
# Add regularization if any
params_grads = self.append_regularization_ops(
......
......@@ -28,4 +28,5 @@ if(WITH_CUSTOM_DEVICE AND NOT WITH_GPU)
set_tests_properties(test_custom_cpu_profiler_plugin PROPERTIES TIMEOUT 120)
set_tests_properties(test_fleet_launch_custom_device PROPERTIES TIMEOUT 120)
set_tests_properties(test_custom_cpu_to_static PROPERTIES TIMEOUT 120)
set_tests_properties(test_custom_device_relu_setup PROPERTIES TIMEOUT 120)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "paddle/extension.h"
#define CHECK_CPU_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.")
#define CHECK_CUSTOM_INPUT(x) \
PD_CHECK(x.is_custom_device(), #x " must be a custom Tensor.")
template <typename data_t>
void relu_cpu_forward_kernel(const data_t* x_data,
data_t* out_data,
int64_t x_numel) {
PD_CHECK(x_data != nullptr, "x_data is nullptr.");
PD_CHECK(out_data != nullptr, "out_data is nullptr.");
for (int64_t i = 0; i < x_numel; ++i) {
out_data[i] = std::max(static_cast<data_t>(0.), x_data[i]);
}
}
template <typename data_t>
void relu_cpu_backward_kernel(const data_t* grad_out_data,
const data_t* out_data,
data_t* grad_x_data,
int64_t out_numel) {
for (int64_t i = 0; i < out_numel; ++i) {
grad_x_data[i] =
grad_out_data[i] * (out_data[i] > static_cast<data_t>(0) ? 1. : 0.);
}
}
template <typename data_t>
void relu_cpu_double_backward_kernel(const data_t* out_data,
const data_t* ddx_data,
data_t* ddout_data,
int64_t ddout_numel) {
for (int64_t i = 0; i < ddout_numel; ++i) {
ddout_data[i] =
ddx_data[i] * (out_data[i] > static_cast<data_t>(0) ? 1. : 0.);
}
}
std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
CHECK_CPU_INPUT(x);
auto out = paddle::empty_like(x);
PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cpu_forward", ([&] {
relu_cpu_forward_kernel<data_t>(
x.data<data_t>(), out.data<data_t>(), x.numel());
}));
return {out};
}
std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
auto grad_x = paddle::empty_like(x);
PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
relu_cpu_backward_kernel<data_t>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x.data<data_t>(),
out.size());
}));
return {grad_x};
}
std::vector<paddle::Tensor> relu_cpu_double_backward(
const paddle::Tensor& out, const paddle::Tensor& ddx) {
CHECK_CPU_INPUT(out);
CHECK_CPU_INPUT(ddx);
auto ddout = paddle::empty(out.shape(), out.dtype(), out.place());
PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_double_backward", ([&] {
relu_cpu_double_backward_kernel<data_t>(
out.data<data_t>(),
ddx.data<data_t>(),
ddout.mutable_data<data_t>(out.place()),
ddout.size());
}));
return {ddout};
}
std::vector<paddle::Tensor> relu_custom_forward(const paddle::Tensor& x) {
CHECK_CUSTOM_INPUT(x);
auto out = paddle::relu(x);
return {out};
}
std::vector<paddle::Tensor> relu_custom_backward(
const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
CHECK_CUSTOM_INPUT(x);
CHECK_CUSTOM_INPUT(out);
auto grad_x = paddle::empty_like(x, x.dtype(), x.place());
auto ones = paddle::experimental::full_like(x, 1.0, x.dtype(), x.place());
auto zeros = paddle::experimental::full_like(x, 0.0, x.dtype(), x.place());
auto condition = paddle::experimental::greater_than(x, zeros);
grad_x = paddle::multiply(grad_out, paddle::where(condition, ones, zeros));
return {grad_x};
}
std::vector<paddle::Tensor> relu_custom_double_backward(
const paddle::Tensor& out, const paddle::Tensor& ddx) {
CHECK_CUSTOM_INPUT(out);
auto ddout = paddle::empty(out.shape(), out.dtype(), out.place());
auto ones =
paddle::experimental::full_like(out, 1.0, out.dtype(), out.place());
auto zeros =
paddle::experimental::full_like(out, 0.0, out.dtype(), out.place());
auto condition = paddle::experimental::greater_than(out, zeros);
ddout = paddle::multiply(ddx, paddle::where(condition, ones, zeros));
return {ddout};
}
std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
if (x.is_cpu()) {
return relu_cpu_forward(x);
} else if (x.is_custom_device()) {
return relu_custom_forward(x);
} else {
PD_THROW("Not implemented.");
}
}
std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
if (x.is_cpu()) {
return relu_cpu_backward(x, out, grad_out);
} else if (x.is_custom_device()) {
return relu_custom_backward(x, out, grad_out);
} else {
PD_THROW("Not implemented.");
}
}
std::vector<paddle::Tensor> ReluDoubleBackward(const paddle::Tensor& out,
const paddle::Tensor& ddx) {
if (out.is_cpu()) {
return relu_cpu_double_backward(out, ddx);
} else if (out.is_custom_device()) {
return relu_custom_double_backward(out, ddx);
} else {
PD_THROW("Not implemented.");
}
}
std::vector<std::vector<int64_t>> ReluDoubleBackwardInferShape(
const std::vector<int64_t>& out_shape,
const std::vector<int64_t>& ddx_shape) {
return {out_shape};
}
PD_BUILD_OP(custom_relu)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward));
PD_BUILD_GRAD_OP(custom_relu)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ReluBackward));
PD_BUILD_DOUBLE_GRAD_OP(custom_relu)
.Inputs({"Out", paddle::Grad(paddle::Grad("X"))})
.Outputs({paddle::Grad(paddle::Grad("Out"))})
.SetKernelFn(PD_KERNEL(ReluDoubleBackward))
.SetInferShapeFn(PD_INFER_SHAPE(ReluDoubleBackwardInferShape));
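// The registrations above wire up custom_relu's forward (X -> Out), backward
// (X, Out, Out@GRAD -> X@GRAD) and double-backward (Out, DD(X) -> DD(Out))
// kernels; ReluForward / ReluBackward / ReluDoubleBackward dispatch to the CPU
// or custom-device implementations above based on the input tensor's place.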
......@@ -38,13 +38,13 @@ with fluid.program_guard(main_program=prog):
prog_clip = prog.clone()
prog_clip.block(0).var(hidden1.name)._set_error_clip(
fluid.clip.ErrorClipByValue(max=CLIP_MAX, min=CLIP_MIN)
paddle.nn.clip.ErrorClipByValue(max=CLIP_MAX, min=CLIP_MIN)
)
avg_cost_clip = prog_clip.block(0).var(avg_cost.name)
fluid.backward.append_backward(loss=avg_cost)
fluid.backward.append_backward(
loss=avg_cost_clip, callbacks=[fluid.clip.error_clip_callback]
loss=avg_cost_clip, callbacks=[paddle.nn.clip.error_clip_callback]
)
hidden1_grad = prog.block(0).var(hidden1.name + "@GRAD")
......
......@@ -122,7 +122,7 @@ class TestDistMnist2x2(TestDistRunnerBase):
opt = paddle.optimizer.AdamW(
learning_rate=lr_val,
grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0),
grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0),
)
acc_steps = 2 # accumulated steps for pipeline
......
......@@ -122,7 +122,7 @@ class TestDistMnist2x2(TestDistRunnerBase):
opt = fluid.optimizer.Momentum(
learning_rate=lr_val,
momentum=0.9,
grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0),
grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0),
)
acc_steps = 2 # accumulated steps for pipeline
......
......@@ -354,7 +354,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
}
strategy.fuse_all_reduce_ops = True
strategy.fuse_grad_size_in_MB = 32
clip = paddle.fluid.clip.GradientClipByGlobalNorm(1.0)
clip = paddle.nn.ClipGradByGlobalNorm(1.0)
self.optimizer(
avg_cost, strategy, train_prog, startup_prog, grad_clip=clip
......@@ -552,7 +552,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
strategy.fuse_all_reduce_ops = True
strategy.fuse_grad_size_in_MB = 32
strategy.fuse_grad_merge = True
clip = paddle.fluid.clip.GradientClipByGlobalNorm(1.0)
clip = paddle.nn.ClipGradByGlobalNorm(1.0)
self.optimizer(
avg_cost, strategy, train_prog, startup_prog, grad_clip=clip
......@@ -940,7 +940,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
}
strategy.fuse_all_reduce_ops = True
strategy.fuse_grad_size_in_MB = 32
clip = paddle.fluid.clip.GradientClipByGlobalNorm(1.0)
clip = paddle.nn.ClipGradByGlobalNorm(1.0)
self.optimizer(
avg_cost, strategy, train_prog, startup_prog, grad_clip=clip
......@@ -1044,7 +1044,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
}
strategy.fuse_all_reduce_ops = True
strategy.fuse_grad_size_in_MB = 32
clip = paddle.fluid.clip.GradientClipByGlobalNorm(1.0)
clip = paddle.nn.ClipGradByGlobalNorm(1.0)
self.optimizer(
avg_cost, strategy, train_prog, startup_prog, grad_clip=clip
......
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.