Unverified commit 011f97bc, authored by Huihuang Zheng, committed by GitHub

【CINN】Remove Remaining Old Schedule, Now We Completely Remove it. (#55566)

Remove the remaining old schedules.
Parent 669bcf54
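Every hunk below applies the same transformation: the DECLARE_bool(cinn_ir_schedule) declarations, the DEFINE_bool registration in flags.cc, and every branch guarded by FLAGS_cinn_ir_schedule are deleted, leaving the new IR-schedule path unconditional. A minimal sketch of the pattern (illustrative only; LowerSomething is a hypothetical name, and the real call sites are in the hunks below):

// Before: both schedule paths were kept behind a runtime flag.
DECLARE_bool(cinn_ir_schedule);
void LowerSomething() {
  if (FLAGS_cinn_ir_schedule) {
    // new IR-schedule path
  } else {
    // old schedule path, already dead in practice: the flag
    // defaulted to BoolFromEnv("FLAGS_cinn_ir_schedule", true)
  }
}

// After: the flag is gone and the new path runs unconditionally.
void LowerSomething() {
  // new IR-schedule path
}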
......@@ -32,7 +32,6 @@
#include "paddle/cinn/runtime/flags.h"
DECLARE_bool(auto_schedule_use_cost_model);
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace auto_schedule {
......@@ -70,8 +69,6 @@ class TestAutoTuner : public ::testing::Test {
void SetUp() override {
srand(0);
-// AutoTuner is combined with new IR Schedule
-FLAGS_cinn_ir_schedule = true;
std::unordered_set<std::string> fetch_ids;
auto program = CreateAddReluProgram();
auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
......
......@@ -27,8 +27,6 @@
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/runtime/flags.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace auto_schedule {
......@@ -55,7 +53,6 @@ class TestMeasurer : public ::testing::Test {
std::vector<MeasureInput> inputs;
void SetUp() override {
-FLAGS_cinn_ir_schedule = true;
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
......
......@@ -40,8 +40,6 @@
#include "paddle/cinn/utils/string.h"
#include "test/cpp/cinn/concrete_program_builder.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace auto_schedule {
......@@ -155,7 +153,6 @@ TEST(AutoInline, AddReluInline) {
frontend::Program program = builder.Build();
-FLAGS_cinn_ir_schedule = true;
auto graph = std::make_shared<Graph>(program, target);
hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
......
......@@ -29,7 +29,6 @@
#include "paddle/cinn/utils/type_defs.h"
DECLARE_bool(auto_schedule_use_cost_model);
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace auto_schedule {
......@@ -70,7 +69,6 @@ std::shared_ptr<hlir::framework::Graph> CreateAddProgram(
TEST(TestTaskRegistry, basic) {
FLAGS_auto_schedule_use_cost_model = true;
-FLAGS_cinn_ir_schedule = true;
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
......
......@@ -35,8 +35,6 @@
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace auto_schedule {
......@@ -59,8 +57,6 @@ Program CreateAddProgram() {
}
TEST(TuneTask, GraphToUnoptLoweredFunc_NoPass) {
-// Auto tuner is combined with IR schedule
-FLAGS_cinn_ir_schedule = true;
Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
......@@ -170,8 +166,6 @@ TEST(TuneTask, GraphToUnoptLoweredFunc_NoPass) {
}
TEST(TuneTask, GraphToUnoptLoweredFunc_ApplyPass) {
-// Auto tuner is combined with IR schedule
-FLAGS_cinn_ir_schedule = true;
Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
......
......@@ -30,8 +30,6 @@
#include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/utils/profiler.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace framework {
......
......@@ -19,7 +19,6 @@
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
-DECLARE_bool(cinn_ir_schedule);
DECLARE_bool(cinn_use_cuda_vectorize);
namespace cinn {
......@@ -52,32 +51,28 @@ std::vector<ir::LoweredFunc> OpLowerer::Lower(const GroupPtr& group,
<< " , Op Pattern : " << group->op_pattern_kind;
group->input_names.clear();
group->output_names.clear();
-if (FLAGS_cinn_ir_schedule) {
-switch (group->op_pattern_kind) {
-case framework::kElementWise:
-case framework::kBroadcast:
-case framework::kInjective:
-return LowerGroup(group,
-apply_op_schedule,
-apply_group_schedule,
-&OpLowerer::ElementwiseScheduleDetermineFunction);
-case framework::kReduction:
-return LowerGroup(group,
-apply_op_schedule,
-apply_group_schedule,
-&OpLowerer::ReduceScheduleDetermineFunction);
-case framework::kOutFusible:
-LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!";
-case framework::kNonFusible:
-return LowerGroup(group,
-apply_op_schedule,
-apply_group_schedule,
-&OpLowerer::NonFusibleScheduleDetermineFunction);
-default:
-LOG(FATAL) << "Group Pattern Kind Is Unknown!";
-}
-} else {
-LOG(FATAL) << "Previous IR Schedule Is Not Implemented!";
+switch (group->op_pattern_kind) {
+case framework::kElementWise:
+case framework::kBroadcast:
+case framework::kInjective:
+return LowerGroup(group,
+apply_op_schedule,
+apply_group_schedule,
+&OpLowerer::ElementwiseScheduleDetermineFunction);
+case framework::kReduction:
+return LowerGroup(group,
+apply_op_schedule,
+apply_group_schedule,
+&OpLowerer::ReduceScheduleDetermineFunction);
+case framework::kOutFusible:
+LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!";
+case framework::kNonFusible:
+return LowerGroup(group,
+apply_op_schedule,
+apply_group_schedule,
+&OpLowerer::NonFusibleScheduleDetermineFunction);
+default:
+LOG(FATAL) << "Group Pattern Kind Is Unknown!";
+}
-}
......
......@@ -27,8 +27,6 @@
#include "paddle/cinn/hlir/pe/broadcast.h"
#include "paddle/cinn/runtime/flags.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace framework {
......@@ -57,35 +55,19 @@ TEST(Operator, GetAttrs) {
std::string func_name = "add1";
-if (FLAGS_cinn_ir_schedule) {
-std::string out_name = "C";
-common::CINNValuePack cinn_input =
-common::CINNValuePack{{common::CINNValue(A),
-common::CINNValue(B),
-common::CINNValue(out_name)}};
-std::vector<std::string> input_output_names{"A", "B", out_name};
+std::string out_name = "C";
+common::CINNValuePack cinn_input =
+common::CINNValuePack{{common::CINNValue(A),
+common::CINNValue(B),
+common::CINNValue(out_name)}};
+std::vector<std::string> input_output_names{"A", "B", out_name};
-auto funcs = framework::GetFuncFromImpl(
-impl, cinn_input, inputs, input_output_names, func_name, target);
+auto funcs = framework::GetFuncFromImpl(
+impl, cinn_input, inputs, input_output_names, func_name, target);
-for (auto func : funcs) {
-LOG(INFO) << "Test Operator_ElementWise_Add_Test0's Strategy, func is :\n"
-<< func;
-}
-} else {
-common::CINNValuePack cinn_input =
-common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}};
-common::CINNValuePack rets = impl->fcompute(cinn_input);
-ASSERT_EQ(rets.size(), 2UL);
-rets = impl->fschedule(rets);
-ASSERT_EQ(rets.size(), 2UL);
-// the last element is a StageMap
-for (int i = 0; i < rets->size() - 1; i++) {
-ir::Expr temp = rets[i];
-inputs.push_back(temp.as_tensor_ref());
-}
-auto func = Lower(func_name, rets.back(), inputs);
-LOG(INFO) << "Test Strategy Codegen:\n" << func;
+for (auto func : funcs) {
+LOG(INFO) << "Test Operator_ElementWise_Add_Test0's Strategy, func is :\n"
+<< func;
+}
-}
......
......@@ -26,8 +26,6 @@
#include "paddle/cinn/ir/layout.h"
#include "paddle/cinn/ir/op/ir_operators.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace op {
......
......@@ -35,8 +35,6 @@
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace op {
......
......@@ -35,8 +35,6 @@
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace op {
......
......@@ -38,7 +38,6 @@
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
......
......@@ -38,8 +38,6 @@
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace op {
......
......@@ -38,7 +38,6 @@
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
......
......@@ -39,8 +39,6 @@
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace op {
......
......@@ -38,8 +38,6 @@
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace op {
......
......@@ -38,8 +38,6 @@
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace op {
......@@ -163,12 +161,9 @@ std::shared_ptr<framework::OpStrategy> StrategyForRepeat(
auto tensor_A = A.as_tensor_ref();
VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
<< ", output_shapes: " << utils::Join(output_shapes[0], ", ");
-std::string tensor_name = common::UniqName("T_Repeat_out");
-if (FLAGS_cinn_ir_schedule) {
-CHECK_EQ(pack_args.size(), 2U);
-tensor_name = pack_args[1].operator std::string();
-}
+CHECK_EQ(pack_args.size(), 2U);
+std::string tensor_name = pack_args[1].operator std::string();
std::vector<ir::Tensor> out = Repeat(tensor_A, repeats, axis, tensor_name);
CHECK(out.size() == 1U) << "The size of Repeat's output should be 1";
......@@ -186,44 +181,34 @@ std::shared_ptr<framework::OpStrategy> StrategyForRepeat(
framework::CINNSchedule repeat_schedule([=](lang::Args args,
lang::RetValue *ret) {
-if (FLAGS_cinn_ir_schedule) {
-CHECK(!args.empty())
-<< "The input argument of repeat schedule is empty! Please check.\n";
-common::CINNValuePack arg_pack = args[0];
-std::vector<Expr> vec_ast;
-for (int i = 0; i < arg_pack.size(); i++) {
-if (arg_pack[i].is_expr()) {
-Expr temp = arg_pack[i];
-vec_ast.emplace_back(temp);
-}
+CHECK(!args.empty())
+<< "The input argument of repeat schedule is empty! Please check.\n";
+common::CINNValuePack arg_pack = args[0];
+std::vector<Expr> vec_ast;
+for (int i = 0; i < arg_pack.size(); i++) {
+if (arg_pack[i].is_expr()) {
+Expr temp = arg_pack[i];
+vec_ast.emplace_back(temp);
+}
-CHECK(!vec_ast.empty());
-ir::ModuleExpr mod_expr(vec_ast);
-ir::IRSchedule ir_sch(mod_expr);
-ir_sch.MergeExprs();
-int64_t prod_size = std::accumulate(output_shapes[0].begin(),
-output_shapes[0].end(),
-1,
-std::multiplies<int>());
-if (prod_size > 1) {
-if (target.arch == Target::Arch::NVGPU) {
-pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
-} else if (target.arch == Target::Arch::X86) {
-pe::IRScheduleInjectiveCPU(
-ir_sch, output_shapes.front(), target, true);
-}
-}
+CHECK(!vec_ast.empty());
+ir::ModuleExpr mod_expr(vec_ast);
+ir::IRSchedule ir_sch(mod_expr);
+ir_sch.MergeExprs();
+int64_t prod_size = std::accumulate(output_shapes[0].begin(),
+output_shapes[0].end(),
+1,
+std::multiplies<int>());
+if (prod_size > 1) {
+if (target.arch == Target::Arch::NVGPU) {
+pe::IRCudaScheduleInjective(ir_sch, output_shapes.front(), target);
+} else if (target.arch == Target::Arch::X86) {
+pe::IRScheduleInjectiveCPU(ir_sch, output_shapes.front(), target, true);
+}
-std::vector<common::CINNValue> res{
-common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
-*ret = common::CINNValuePack{res};
-} else {
-CHECK(!args.empty())
-<< "The input argument of repeat schedule is empty! Please check.\n";
-CINNValuePack arg_pack = args[0];
-Expr out = arg_pack[0];
-CHECK(out.as_tensor());
-*ret = arg_pack;
-}
+std::vector<common::CINNValue> res{
+common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
+*ret = common::CINNValuePack{res};
});
auto strategy = std::make_shared<framework::OpStrategy>();
......
......@@ -37,8 +37,6 @@
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace op {
......
......@@ -38,8 +38,6 @@
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace op {
......
......@@ -27,8 +27,6 @@
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/utils/functional.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace op {
......
......@@ -29,8 +29,6 @@
#include "paddle/cinn/hlir/pe/broadcast.h"
#include "paddle/cinn/runtime/flags.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace framework {
......
......@@ -30,8 +30,6 @@
#include "paddle/cinn/hlir/pe/nn.h"
#include "paddle/cinn/runtime/flags.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace framework {
......@@ -49,33 +47,17 @@ Module LowerToModule(const std::string test_name,
const Target &target) {
Module::Builder builder("module", target);
-if (FLAGS_cinn_ir_schedule) {
-cinn_inputs.emplace_back(output_name);
-common::CINNValuePack cinn_input = common::CINNValuePack{cinn_inputs};
-input_names.push_back(output_name);
+cinn_inputs.emplace_back(output_name);
+common::CINNValuePack cinn_input = common::CINNValuePack{cinn_inputs};
+input_names.push_back(output_name);
-auto funcs = framework::GetFuncFromImpl(
-impl, cinn_input, inputs, input_names, func_name, target);
-for (auto func : funcs) {
-LOG(INFO) << "Test" << test_name << "'s Strategy, func is :\n" << func;
-builder.AddFunction(func);
-}
-} else {
-common::CINNValuePack cinn_input = common::CINNValuePack{cinn_inputs};
-common::CINNValuePack rets = impl->fcompute(cinn_input);
-rets = impl->fschedule(rets);
-// the last element is a StageMap
-for (int i = 0; i < rets->size() - 1; i++) {
-Expr temp = rets[i];
-inputs.push_back(temp.as_tensor_ref());
-}
-auto func = Lower("fn_" + func_name, rets.back(), inputs);
-LOG(INFO) << "Test Strategy Codegen:\n" << func;
+auto funcs = framework::GetFuncFromImpl(
+impl, cinn_input, inputs, input_names, func_name, target);
+for (auto func : funcs) {
+LOG(INFO) << "Test" << test_name << "'s Strategy, func is :\n" << func;
+builder.AddFunction(func);
+}
return builder.Build();
}
......
......@@ -39,7 +39,6 @@
#include "paddle/cinn/hlir/pe/nn.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
#include "paddle/cinn/runtime/cuda/cuda_module.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace framework {
......@@ -94,34 +93,14 @@ std::pair<ir::Module, std::string> GenReduceCode(
strategy(attrs, inputs, out_type, {output_shape}, target));
std::vector<ir::LoweredFunc> func;
-if (!FLAGS_cinn_ir_schedule) {
-common::CINNValuePack cinn_input =
-common::CINNValuePack{{common::CINNValue(X)}};
-common::CINNValuePack rets = impl->fcompute(cinn_input);
-rets = impl->fschedule(rets);
-poly::StageMap stages = rets.back();
-// the last element is a StageMap
-for (int i = 0; i < rets->size() - 1; i++) {
-Expr temp = rets[i];
-if (!temp.as_tensor_ref()->buffer.defined() &&
-!stages[temp.as_tensor_ref()]->inlined()) {
-inputs.push_back(temp.as_tensor_ref());
-}
-}
-func =
-lang::LowerVec(func_name, rets.back(), inputs, {}, {}, nullptr, target);
-} else {
-std::vector<std::string> input_output_nodes{"X", op_name};
-func = GetFuncFromImpl(impl,
-common::CINNValuePack{{common::CINNValue(X),
-common::CINNValue(op_name)}},
-inputs,
-input_output_nodes,
-func_name,
-target);
-}
+std::vector<std::string> input_output_nodes{"X", op_name};
+func = GetFuncFromImpl(
+impl,
+common::CINNValuePack{{common::CINNValue(X), common::CINNValue(op_name)}},
+inputs,
+input_output_nodes,
+func_name,
+target);
Module::Builder builder(func_name + "_builder", target);
for (auto& f : func) {
......@@ -139,11 +118,7 @@ std::pair<ir::Module, std::string> GenReduceCode(
backends::CodeGenCUDA_Dev codegen(target);
std::string source_code;
-if (!FLAGS_cinn_ir_schedule) {
-source_code = codegen.Compile(builder.Build());
-} else {
-source_code = codegen.Compile(device_module);
-}
+source_code = codegen.Compile(device_module);
// LOG(INFO) << "compiled code:\n" << device_module;
return std::pair<ir::Module, std::string>(host_module, source_code);
......@@ -385,18 +360,12 @@ void TestCaseForReduce(const float init_val,
dev_x, buffer_x->memory, buffer_x->memory_size, cudaMemcpyHostToDevice));
dim3 grid;
dim3 block;
-if (!FLAGS_cinn_ir_schedule) {
-grid = {n * c, 1, 1};
-block = {h * w, 1, 1};
-} else {
-grid = {c, 1, 1};
-int block_dim_x = n * w * h > 1024 ? 1024 : n * w * h;
-block = {block_dim_x, 1, 1};
-}
+grid = {c, 1, 1};
+int block_dim_x = n * w * h > 1024 ? 1024 : n * w * h;
+block = {block_dim_x, 1, 1};
void* args[] = {&dev_x, &dev_z};
-std::string new_test_name = test_name;
-if (FLAGS_cinn_ir_schedule) new_test_name = "fn_" + new_test_name + "_kernel";
+std::string new_test_name = "fn_" + test_name + "_kernel";
cuda_module.LaunchKernel(0, new_test_name, grid, block, args);
CUDA_CALL(cudaMemcpy(
buffer_z->memory, dev_z, buffer_z->memory_size, cudaMemcpyDeviceToHost));
......@@ -458,8 +427,7 @@ TEST(Operator, Operator_Reduction_Case_7) {
CUDA_CALL(cudaSetDevice(0));
runtime::cuda::CUDAModule cuda_module(ptx,
runtime::cuda::CUDAModule::Kind::PTX);
-std::string new_func_name = func_name;
-if (FLAGS_cinn_ir_schedule) new_func_name = "fn_" + new_func_name;
+std::string new_func_name = "fn_" + func_name;
void* reduce_sum_kernel =
cuda_module.GetFunction(0, new_func_name + "_kernel");
CHECK(reduce_sum_kernel);
......
......@@ -28,8 +28,6 @@
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace op {
......
......@@ -40,8 +40,6 @@
#include "paddle/cinn/runtime/cuda/cuda_module.h"
#include "paddle/cinn/runtime/flags.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace hlir {
namespace framework {
......
......@@ -30,7 +30,6 @@
#include "paddle/cinn/optim/tensor_write_tell.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/string.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace ir {
......@@ -58,8 +57,7 @@ LoweredFunc _LoweredFunc_::Make(const std::string& name,
n->PrepareCreateTempBufferExprs();
n->PrepareAllocTempBufferExprs();
n->AllocTempBuffer();
-bool with_expr_gen_tensor = true;
-if (FLAGS_cinn_ir_schedule) with_expr_gen_tensor = false;
+bool with_expr_gen_tensor = false;
n->PrepareBufferCastExprs(with_expr_gen_tensor);
n->PrepareArgumentExprs();
n->PrepareDeallocTempBufferExprs();
......
......@@ -37,8 +37,6 @@
#include "paddle/cinn/optim/unroll_loops.h"
#include "paddle/cinn/optim/vectorize_loops.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace optim {
......@@ -60,7 +58,7 @@ Expr Optimize(Expr e,
VectorizeLoops(&copied, target);
VLOG(4) << "After Optimize VectorizeLoops:" << copied;
#ifdef CINN_WITH_CUDA
-if (FLAGS_cinn_ir_schedule && copied.as_lowered_func()) {
+if (copied.as_lowered_func()) {
ir::SetCudaAxisInfo(&copied);
}
if (remove_gpu_for_loops) {
......@@ -93,10 +91,8 @@ Expr Optimize(Expr e,
ir::Module Optimize(const ir::Module& module, const Target& target) {
auto copied = IRCopy(Expr(module));
-if (FLAGS_cinn_ir_schedule) {
-UnrollLoop(&copied);
-VectorizeLoops(&copied, Target());
-}
+UnrollLoop(&copied);
+VectorizeLoops(&copied, Target());
VLOG(10) << "After VectorizeLoops:" << copied.as_module_ref();
RemoveScheduleBlock(&copied);
VLOG(10) << "After RemoveScheduleBlock:" << copied.as_module_ref();
......
......@@ -28,8 +28,6 @@
#include "paddle/cinn/pybind/bind.h"
#include "paddle/cinn/runtime/flags.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn::pybind {
namespace py = pybind11;
......@@ -64,41 +62,23 @@ void BindFramework(pybind11::module *m) {
}
ir::LoweredFunc func;
-if (FLAGS_cinn_ir_schedule) {
-std::string output_name = "out";
-temp_inputs.emplace_back(output_name);
-std::vector<std::string> input_output_names;
-for (const auto &input : inputs) {
-input_output_names.push_back(input->name);
-}
-input_output_names.push_back(output_name);
-std::vector<ir::LoweredFunc> funcs =
-hlir::framework::GetFuncFromImpl(
-impl,
-common::CINNValuePack{temp_inputs},
-res,
-input_output_names,
-key,
-target);
-CHECK_EQ(funcs.size(), 1U);
-func = funcs[0];
-} else {
-common::CINNValuePack C =
-impl->fcompute(common::CINNValuePack{temp_inputs});
-poly::StageMap stages = C.back();
-// make sure all the tensors in the stages before schedule
-// launch.
-for (int i = 0; i < C->size() - 1; i++) {
-ir::Expr temp = C[i];
-stages->InsertLazily(temp.as_tensor_ref());
-}
-C = impl->fschedule(C);
-for (int i = 0; i < C->size() - 1; i++) {
-ir::Expr temp = C[i];
-res.push_back(temp.as_tensor_ref());
-}
-func = Lower(key, stages, res);
+std::string output_name = "out";
+temp_inputs.emplace_back(output_name);
+std::vector<std::string> input_output_names;
+for (const auto &input : inputs) {
+input_output_names.push_back(input->name);
+}
+input_output_names.push_back(output_name);
+std::vector<ir::LoweredFunc> funcs =
+hlir::framework::GetFuncFromImpl(
+impl,
+common::CINNValuePack{temp_inputs},
+res,
+input_output_names,
+key,
+target);
+CHECK_EQ(funcs.size(), 1U);
+func = funcs[0];
return func;
});
......
......@@ -89,10 +89,6 @@ DEFINE_bool(cinn_use_cuda_vectorize,
BoolFromEnv("FLAGS_cinn_use_cuda_vectorize", false),
"Whether use cuda vectroize on schedule config");
-DEFINE_bool(cinn_ir_schedule,
-BoolFromEnv("FLAGS_cinn_ir_schedule", true),
-"Whether use reconstructed schedule primitives.");
DEFINE_bool(use_reduce_split_pass,
BoolFromEnv("FLAGS_use_reduce_split_pass", false),
"Whether use reduce split pass.");
......
......@@ -24,8 +24,6 @@
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/timer.h"
-DECLARE_bool(cinn_ir_schedule);
namespace cinn {
namespace tests {
using ir::Tensor;
......@@ -87,117 +85,88 @@ Module OpBenchmarkTester::CreateCinnModule(
auto impl = hlir::framework::OpStrategy::SelectImpl(
strategy[op](attrs, input_tensors, out_types, input_shapes_, target_));
-if (FLAGS_cinn_ir_schedule) {
-std::string output_name = "out";
-std::vector<common::CINNValue> temp_inputs;
-std::vector<ir::Tensor> all_arg_tensors;
-std::vector<std::string> input_output_names;
-for (const auto& tensor : input_tensors) {
-temp_inputs.emplace_back(tensor);
-all_arg_tensors.push_back(tensor);
-input_output_names.push_back(tensor->name);
-}
-temp_inputs.emplace_back(output_name);
-common::CINNValuePack cinn_inputs = common::CINNValuePack{temp_inputs};
-input_output_names.push_back(output_name);
-// 1.Call Op's Compute function, using the default stages and LowerVec to
-// get IR tree.
-common::CINNValuePack C = impl->fcompute(cinn_inputs);
-// 2. Collect tensors and arguments
-// Add output tensors to all_arg_tensors
-for (int i = 0; i < C->size() - 1; i++) {
-ir::Expr temp = C[i];
-// checkout whether the tensor is with buffer.
-if (!temp.as_tensor_ref()->buffer.defined() ||
-target_ != common::DefaultNVGPUTarget()) {
-all_arg_tensors.push_back(temp.as_tensor_ref());
-}
+std::string output_name = "out";
+std::vector<common::CINNValue> temp_inputs;
+std::vector<ir::Tensor> all_arg_tensors;
+std::vector<std::string> input_output_names;
+for (const auto& tensor : input_tensors) {
+temp_inputs.emplace_back(tensor);
+all_arg_tensors.push_back(tensor);
+input_output_names.push_back(tensor->name);
+}
+temp_inputs.emplace_back(output_name);
+common::CINNValuePack cinn_inputs = common::CINNValuePack{temp_inputs};
+input_output_names.push_back(output_name);
+// 1.Call Op's Compute function, using the default stages and LowerVec to
+// get IR tree.
+common::CINNValuePack C = impl->fcompute(cinn_inputs);
+// 2. Collect tensors and arguments
+// Add output tensors to all_arg_tensors
+for (int i = 0; i < C->size() - 1; i++) {
+ir::Expr temp = C[i];
+// checkout whether the tensor is with buffer.
+if (!temp.as_tensor_ref()->buffer.defined() ||
+target_ != common::DefaultNVGPUTarget()) {
+all_arg_tensors.push_back(temp.as_tensor_ref());
+}
+}
-stages = C.back();
-auto funcs = lang::LowerVec(
-op_name_, stages, all_arg_tensors, {}, {}, nullptr, target_, true);
+stages = C.back();
+auto funcs = lang::LowerVec(
+op_name_, stages, all_arg_tensors, {}, {}, nullptr, target_, true);
-std::vector<common::CINNValue> schedule_inputs;
-for (int i = 0; i < C.size() - 1; ++i) {
-CHECK(C[i].is_tensor());
-schedule_inputs.push_back(common::CINNValue(C[i]));
-}
-for (auto& f : funcs) {
-schedule_inputs.push_back(common::CINNValue(f->body));
-}
+std::vector<common::CINNValue> schedule_inputs;
+for (int i = 0; i < C.size() - 1; ++i) {
+CHECK(C[i].is_tensor());
+schedule_inputs.push_back(common::CINNValue(C[i]));
+}
+for (auto& f : funcs) {
+schedule_inputs.push_back(common::CINNValue(f->body));
+}
-// 3. Call Op's Schedule function, optimizing the IR tree by new IR
-// schedule
-common::CINNValuePack expr_pack =
-impl->fschedule(common::CINNValuePack{schedule_inputs});
+// 3. Call Op's Schedule function, optimizing the IR tree by new IR
+// schedule
+common::CINNValuePack expr_pack =
+impl->fschedule(common::CINNValuePack{schedule_inputs});
-// 4. Optimize the LoweredFunc
-std::vector<ir::LoweredFunc> res;
-for (int i = 0; i < expr_pack.size(); i++) {
+// 4. Optimize the LoweredFunc
+std::vector<ir::LoweredFunc> res;
+for (int i = 0; i < expr_pack.size(); i++) {
#ifdef CINN_WITH_CUDA
-optim::OptimizeExprGPU(&(funcs[i]->body));
+optim::OptimizeExprGPU(&(funcs[i]->body));
#endif
-if (funcs.size() > expr_pack.size()) {
-auto new_args = lang::GetArgs(funcs[i]->body, input_output_names);
-funcs[i]->args = new_args;
-}
-auto temp_buffers =
-lang::GetTempBuffers(all_arg_tensors, stages, funcs[i]->body);
-funcs[i]->temp_bufs = temp_buffers;
-funcs[i]->PrepareBufferCastExprs();
-res.push_back(funcs[i]);
-}
-for (int i = 0; i < res.size(); i++) {
-res[i] = optim::Optimize(Expr(funcs[i]), target_, false)
-.as_lowered_func_ref();
+if (funcs.size() > expr_pack.size()) {
+auto new_args = lang::GetArgs(funcs[i]->body, input_output_names);
+funcs[i]->args = new_args;
+}
+auto temp_buffers =
+lang::GetTempBuffers(all_arg_tensors, stages, funcs[i]->body);
+funcs[i]->temp_bufs = temp_buffers;
+funcs[i]->PrepareBufferCastExprs();
+res.push_back(funcs[i]);
+}
+for (int i = 0; i < res.size(); i++) {
+res[i] =
+optim::Optimize(Expr(funcs[i]), target_, false).as_lowered_func_ref();
+}
-for (auto func : res) {
-builder.AddFunction(func);
+for (auto func : res) {
+builder.AddFunction(func);
-for (const auto& arg : func->args) {
-std::vector<int> output_shape;
-if (arg.io == ir::Argument::IO::kOutput) {
-for (auto& shape_dim : arg.buffer_arg()->shape) {
-LOG(INFO) << shape_dim << ",";
-CHECK(shape_dim.is_constant());
-output_shape.push_back(
-static_cast<int>(shape_dim.get_constant()));
-}
-output_shapes_.push_back(output_shape);
-break;
-}
-}
-}
+for (const auto& arg : func->args) {
+std::vector<int> output_shape;
+if (arg.io == ir::Argument::IO::kOutput) {
+for (auto& shape_dim : arg.buffer_arg()->shape) {
+LOG(INFO) << shape_dim << ",";
+CHECK(shape_dim.is_constant());
+output_shape.push_back(static_cast<int>(shape_dim.get_constant()));
+}
+output_shapes_.push_back(output_shape);
+break;
+}
-} else {
-std::vector<common::CINNValue> temp_inputs;
-for (auto& tensor : input_tensors) {
-temp_inputs.push_back(common::CINNValue(tensor));
-}
-common::CINNValuePack C =
-impl->fcompute(common::CINNValuePack(temp_inputs));
-stages = C.back();
-C = impl->fschedule(C);
-for (int i = 0; i < C->size() - 1; i++) {
-ir::Expr temp = C[i];
-stages->InsertLazily(temp.as_tensor_ref());
-std::vector<Expr> output_shape_expr =
-temp.as_tensor_ref()->domain_without_reduce_axis();
-for (auto& shape : output_shape_expr) {
-LOG(INFO) << shape;
-output_shape.push_back(common::AutoSimplify(shape).as_int32());
-output_shapes_.push_back(output_shape);
-rets.push_back(temp.as_tensor_ref());
-}
-auto func = Lower(op_name_, stages, rets);
-LOG(INFO) << "After Lower, func is: \n" << func;
-builder.AddFunction(func);
-}
} else {
stages = CreateStages(input_tensors);
......