未验证 提交 ad0e7a26 编写于 作者: 傅剑寒 提交者: GitHub

【CINN】Integate cast_simplify into ir_simplify (#56958)

* integate cast_simplify into ir_simplify

* fix cast simplify testcase
上级 fc71459f
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include "paddle/cinn/ir/utils/ir_nodes_collector.h" #include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h" #include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/utils/string.h" #include "paddle/cinn/utils/string.h"
namespace cinn { namespace cinn {
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_mutator.h" #include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/cast_simplify.h"
namespace cinn { namespace cinn {
namespace common { namespace common {
...@@ -147,7 +146,7 @@ Expr IndiceToAbsOffset(const std::vector<Expr> &shape, ...@@ -147,7 +146,7 @@ Expr IndiceToAbsOffset(const std::vector<Expr> &shape,
for (int i = 0; i < shape.size(); i++) { for (int i = 0; i < shape.size(); i++) {
CHECK_EQ(shape[i].type(), Int(32)); CHECK_EQ(shape[i].type(), Int(32));
Expr indice_prod = indices[i]; Expr indice_prod = indices[i];
optim::CastSimplify(&indice_prod); optim::SimplifyCast(&indice_prod);
for (int j = i + 1; j < shape.size(); j++) { for (int j = i + 1; j < shape.size(); j++) {
indice_prod = RampRelatedMul(indice_prod, shape[j]); indice_prod = RampRelatedMul(indice_prod, shape[j]);
} }
......
...@@ -23,7 +23,6 @@ gather_srcs( ...@@ -23,7 +23,6 @@ gather_srcs(
compute_inline_expand.cc compute_inline_expand.cc
buffer_assign.cc buffer_assign.cc
replace_const_param_to_integer.cc replace_const_param_to_integer.cc
cast_simplify.cc
lower_intrin.cc lower_intrin.cc
cast_bool_to_int8.cc cast_bool_to_int8.cc
collect_undefined_vars.cc collect_undefined_vars.cc
......
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
namespace cinn::optim {
using cinn::common::bfloat16;
using cinn::common::float16;
namespace {
template <typename CastType, typename T>
CastType NormCastValue(T value) {
if (type_of<CastType>().is_uint() || type_of<T>().is_uint()) {
// not support uint
return static_cast<CastType>(value);
}
if (std::isinf(value)) {
return std::numeric_limits<CastType>::infinity();
} else if (std::isnan(value)) {
return std::numeric_limits<CastType>::signaling_NaN();
} else if (value >= static_cast<T>(std::numeric_limits<CastType>::max())) {
return std::numeric_limits<CastType>::max();
} else if (value <= static_cast<T>(std::numeric_limits<CastType>::lowest())) {
return std::numeric_limits<CastType>::lowest();
}
return static_cast<CastType>(value);
}
struct Mutator : ir::IRMutator<> {
using ir::IRMutator<>::Visit;
void Visit(const ir::Cast* op, Expr* expr) {
auto* node = expr->As<ir::Cast>();
Visit(&node->v(), &node->v());
if (op->type() == op->v().type()) {
*expr = op->v();
return;
}
#define __CAST_TO_TYPE(type__) \
if (auto* i = op->v().As<ir::IntImm>()) { \
*expr = Expr(static_cast<type__>(i->value)); \
} else if (auto* f = op->v().As<ir::FloatImm>()) { \
*expr = Expr(static_cast<type__>(NormCastValue<type__>(f->value))); \
} else if (auto* u = op->v().As<ir::UIntImm>()) { \
*expr = Expr(static_cast<type__>(u->value)); \
} else { \
CINN_NOT_IMPLEMENTED \
}
if (op->v().is_constant()) {
if (op->type() == type_of<int8_t>()) {
__CAST_TO_TYPE(int8_t)
} else if (op->type() == type_of<int16_t>()) {
__CAST_TO_TYPE(int16_t)
} else if (op->type() == type_of<int32_t>()) {
__CAST_TO_TYPE(int32_t)
} else if (op->type() == type_of<int64_t>()) {
__CAST_TO_TYPE(int64_t)
} else if (op->type() == type_of<uint8_t>()) {
__CAST_TO_TYPE(uint8_t)
} else if (op->type() == type_of<uint16_t>()) {
__CAST_TO_TYPE(uint16_t)
} else if (op->type() == type_of<uint32_t>()) {
__CAST_TO_TYPE(uint32_t)
} else if (op->type() == type_of<uint64_t>()) {
__CAST_TO_TYPE(uint64_t)
} else if (op->type() == type_of<float>()) {
__CAST_TO_TYPE(float)
} else if (op->type() == type_of<double>()) {
__CAST_TO_TYPE(double)
} else if (op->type() == type_of<bool>()) {
__CAST_TO_TYPE(bool)
} else if (op->type() == type_of<uint32_t>()) {
__CAST_TO_TYPE(uint32_t)
} else if (op->type() == type_of<uint64_t>()) {
__CAST_TO_TYPE(uint64_t)
} else if (op->type() == type_of<bfloat16>()) {
// Cannot simplify!!! pass
__CAST_TO_TYPE(bfloat16)
} else if (op->type() == type_of<float16>()) {
// Cannot simplify!!! pass
__CAST_TO_TYPE(float16)
} else {
CINN_NOT_IMPLEMENTED
}
}
#undef __CAST_TO_TYPE
}
};
} // namespace
void CastSimplify(Expr* e) {
Mutator mutator;
mutator.Visit(e, e);
}
} // namespace cinn::optim
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/ir.h"
namespace cinn::optim {
/**
* Simplify the Cast nodes.
*
* There are several patterns:
* 1. the source and target type are the same, drop the Cast node
* 2. for intermediate numbers, just replace the Cast node with a Node of the
* target type
*/
void CastSimplify(Expr* e);
} // namespace cinn::optim
...@@ -12,13 +12,11 @@ ...@@ -12,13 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/cinn/optim/cast_simplify.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h"
namespace cinn::optim { namespace cinn::optim {
TEST(CastSimplify, same_type) { TEST(CastSimplify, same_type) {
...@@ -26,7 +24,7 @@ TEST(CastSimplify, same_type) { ...@@ -26,7 +24,7 @@ TEST(CastSimplify, same_type) {
Expr a = ir::Cast::Make(Int(32), n); Expr a = ir::Cast::Make(Int(32), n);
LOG(INFO) << n->type(); LOG(INFO) << n->type();
LOG(INFO) << a; LOG(INFO) << a;
CastSimplify(&a); SimplifyCast(&a);
ASSERT_EQ(utils::GetStreamCnt(a), "n"); ASSERT_EQ(utils::GetStreamCnt(a), "n");
} }
...@@ -34,7 +32,7 @@ TEST(CastSimplify, Imm_int) { ...@@ -34,7 +32,7 @@ TEST(CastSimplify, Imm_int) {
Expr a = ir::Cast::Make(Int(64), Expr(1)); Expr a = ir::Cast::Make(Int(64), Expr(1));
Expr c = ir::Cast::Make(Int(32), a); Expr c = ir::Cast::Make(Int(32), a);
LOG(INFO) << c; LOG(INFO) << c;
CastSimplify(&c); SimplifyCast(&c);
LOG(INFO) << c; LOG(INFO) << c;
ASSERT_EQ(utils::GetStreamCnt(c), "1"); ASSERT_EQ(utils::GetStreamCnt(c), "1");
ASSERT_EQ(c.type(), Int(32)); ASSERT_EQ(c.type(), Int(32));
...@@ -44,7 +42,7 @@ TEST(CastSimplify, Imm_double) { ...@@ -44,7 +42,7 @@ TEST(CastSimplify, Imm_double) {
Expr a = ir::Cast::Make(Float(64), Expr(2.33)); Expr a = ir::Cast::Make(Float(64), Expr(2.33));
Expr c = ir::Cast::Make(Int(32), a); Expr c = ir::Cast::Make(Int(32), a);
LOG(INFO) << c; LOG(INFO) << c;
CastSimplify(&c); SimplifyCast(&c);
LOG(INFO) << c; LOG(INFO) << c;
ASSERT_EQ(utils::GetStreamCnt(c), "2"); ASSERT_EQ(utils::GetStreamCnt(c), "2");
ASSERT_EQ(c.type(), Int(32)); ASSERT_EQ(c.type(), Int(32));
...@@ -54,7 +52,7 @@ TEST(CastSimplify, Imm_uint) { ...@@ -54,7 +52,7 @@ TEST(CastSimplify, Imm_uint) {
Expr a = ir::Cast::Make(UInt(64), Expr(1)); Expr a = ir::Cast::Make(UInt(64), Expr(1));
Expr c = ir::Cast::Make(UInt(32), a); Expr c = ir::Cast::Make(UInt(32), a);
LOG(INFO) << c; LOG(INFO) << c;
CastSimplify(&c); SimplifyCast(&c);
LOG(INFO) << c; LOG(INFO) << c;
ASSERT_EQ(utils::GetStreamCnt(c), "1"); ASSERT_EQ(utils::GetStreamCnt(c), "1");
ASSERT_EQ(c.type(), UInt(32)); ASSERT_EQ(c.type(), UInt(32));
......
...@@ -29,13 +29,14 @@ ...@@ -29,13 +29,14 @@
#include "paddle/cinn/ir/utils/ir_mutator.h" #include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h" #include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/utils/string.h" #include "paddle/cinn/utils/string.h"
namespace cinn { namespace cinn {
namespace optim { namespace optim {
using namespace ir; // NOLINT using namespace ir; // NOLINT
using common::bfloat16;
using common::ExprToGinacConverter; using common::ExprToGinacConverter;
using common::float16;
using utils::GetStreamCnt; using utils::GetStreamCnt;
using utils::Replace; using utils::Replace;
...@@ -346,11 +347,95 @@ struct SimplifyForLoopsMutator : public ir::IRMutator<> { ...@@ -346,11 +347,95 @@ struct SimplifyForLoopsMutator : public ir::IRMutator<> {
} }
}; };
template <typename CastType, typename T>
CastType NormCastValue(T value) {
if (type_of<CastType>().is_uint() || type_of<T>().is_uint()) {
// not support uint
return static_cast<CastType>(value);
}
if (std::isinf(value)) {
return std::numeric_limits<CastType>::infinity();
} else if (std::isnan(value)) {
return std::numeric_limits<CastType>::signaling_NaN();
} else if (value >= static_cast<T>(std::numeric_limits<CastType>::max())) {
return std::numeric_limits<CastType>::max();
} else if (value <= static_cast<T>(std::numeric_limits<CastType>::lowest())) {
return std::numeric_limits<CastType>::lowest();
}
return static_cast<CastType>(value);
}
struct SimplifyCastMutator : public ir::IRMutator<> {
void operator()(Expr* expr) { ir::IRMutator<ir::Expr*>::Visit(expr, expr); }
void Visit(const ir::Cast* op, Expr* expr) {
auto* node = expr->As<ir::Cast>();
ir::IRMutator<ir::Expr*>::Visit(&node->v(), &node->v());
if (op->type() == op->v().type()) {
*expr = op->v();
return;
}
#define __CAST_TO_TYPE(type__) \
if (auto* i = op->v().As<ir::IntImm>()) { \
*expr = Expr(static_cast<type__>(i->value)); \
} else if (auto* f = op->v().As<ir::FloatImm>()) { \
*expr = Expr(static_cast<type__>(NormCastValue<type__>(f->value))); \
} else if (auto* u = op->v().As<ir::UIntImm>()) { \
*expr = Expr(static_cast<type__>(u->value)); \
} else { \
CINN_NOT_IMPLEMENTED \
}
if (op->v().is_constant()) {
if (op->type() == type_of<int8_t>()) {
__CAST_TO_TYPE(int8_t)
} else if (op->type() == type_of<int16_t>()) {
__CAST_TO_TYPE(int16_t)
} else if (op->type() == type_of<int32_t>()) {
__CAST_TO_TYPE(int32_t)
} else if (op->type() == type_of<int64_t>()) {
__CAST_TO_TYPE(int64_t)
} else if (op->type() == type_of<uint8_t>()) {
__CAST_TO_TYPE(uint8_t)
} else if (op->type() == type_of<uint16_t>()) {
__CAST_TO_TYPE(uint16_t)
} else if (op->type() == type_of<uint32_t>()) {
__CAST_TO_TYPE(uint32_t)
} else if (op->type() == type_of<uint64_t>()) {
__CAST_TO_TYPE(uint64_t)
} else if (op->type() == type_of<float>()) {
__CAST_TO_TYPE(float)
} else if (op->type() == type_of<double>()) {
__CAST_TO_TYPE(double)
} else if (op->type() == type_of<bool>()) {
__CAST_TO_TYPE(bool)
} else if (op->type() == type_of<uint32_t>()) {
__CAST_TO_TYPE(uint32_t)
} else if (op->type() == type_of<uint64_t>()) {
__CAST_TO_TYPE(uint64_t)
} else if (op->type() == type_of<bfloat16>()) {
// Cannot simplify!!! pass
__CAST_TO_TYPE(bfloat16)
} else if (op->type() == type_of<float16>()) {
// Cannot simplify!!! pass
__CAST_TO_TYPE(float16)
} else {
CINN_NOT_IMPLEMENTED
}
}
#undef __CAST_TO_TYPE
}
};
} // namespace } // namespace
void Simplify(Expr* expr) { void Simplify(Expr* expr) {
VLOG(3) << "Begin Simplify " << *expr; VLOG(3) << "Begin Simplify " << *expr;
optim::CastSimplify(expr); SimplifyCastMutator()(expr);
SimplifyRampMutator()(expr); SimplifyRampMutator()(expr);
SimplifyLoadMutator()(expr); SimplifyLoadMutator()(expr);
SimplifyStoreMutator()(expr); SimplifyStoreMutator()(expr);
...@@ -363,6 +448,7 @@ void Simplify(Expr* expr) { ...@@ -363,6 +448,7 @@ void Simplify(Expr* expr) {
ReplaceFracWithDivMutator()(expr); ReplaceFracWithDivMutator()(expr);
} }
void SimplifyCast(Expr* expr) { SimplifyCastMutator()(expr); }
void SimplifyForLoops(Expr* expr) { SimplifyForLoopsMutator()(expr); } void SimplifyForLoops(Expr* expr) { SimplifyForLoopsMutator()(expr); }
void SimplifyBlocks(Expr* expr) { SimplifyBlocksMutator()(expr); } void SimplifyBlocks(Expr* expr) { SimplifyBlocksMutator()(expr); }
......
...@@ -30,6 +30,8 @@ namespace optim { ...@@ -30,6 +30,8 @@ namespace optim {
*/ */
void Simplify(Expr *expr); void Simplify(Expr *expr);
void SimplifyCast(Expr *expr);
void SimplifyForLoops(Expr *expr); void SimplifyForLoops(Expr *expr);
void SimplifyBlocks(Expr *expr); void SimplifyBlocks(Expr *expr);
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/call_arg_list_to_pod_value.h" #include "paddle/cinn/optim/call_arg_list_to_pod_value.h"
#include "paddle/cinn/optim/cast_bool_to_int8.h" #include "paddle/cinn/optim/cast_bool_to_int8.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/optim/eliminate_broadcast_in_forloop.h" #include "paddle/cinn/optim/eliminate_broadcast_in_forloop.h"
#include "paddle/cinn/optim/extern_call_process.h" #include "paddle/cinn/optim/extern_call_process.h"
#include "paddle/cinn/optim/fold_cinn_call_arguments.h" #include "paddle/cinn/optim/fold_cinn_call_arguments.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册