Unverified commit 3af47711, authored by Yiqun Liu, committed by GitHub

Add the detection and code-generation of sqrt and square in fusion_group (#23095)

Parent d066d6f9
......@@ -25,16 +25,22 @@ namespace ir {
namespace fusion_group {
std::string ExtractDataType(const std::vector<Node*>& nodes) {
std::string dtype_str = "float";
auto data_type = nodes.back()->Var()->GetDataType();
if (data_type == proto::VarType::FP32) {
std::string dtype_str = "";
for (const auto* n : nodes) {
if (n && n->IsVar() && n->Var()) {
// The data type of all inputs/outputs must be the same, which is
// checked when detecting the subgraph.
auto dtype = n->Var()->GetDataType();
if (dtype == proto::VarType::FP32) {
dtype_str = "float";
} else if (data_type == proto::VarType::FP64) {
} else if (dtype == proto::VarType::FP64) {
dtype_str = "double";
} else if (data_type == proto::VarType::FP16) {
} else if (dtype == proto::VarType::FP16) {
dtype_str = "float16";
}
break;
}
}
return dtype_str;
}
......@@ -80,7 +86,6 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
for (auto& name : input_names) {
// Some input vars are not used in grad ops, such as
// "elementwise_add_grad", where "X", "Y" and "Out" are not used.
if ((HasInput(node, name) && op->Input(name).size() >= 1U)) {
for (size_t i = 0; i < op->Input(name).size(); i++) {
PADDLE_ENFORCE_NE(
......
......@@ -38,13 +38,13 @@ static std::string ExpandMultivariateTemplate(const std::string rhs,
int start_pos = rhs.find("[", 0);
int end_pos = rhs.find("]", 0);
std::string sum_rhs = rhs.substr(0, start_pos);
std::string sum_rhs_component =
std::string repeated_component =
rhs.substr(start_pos + 1, (end_pos - start_pos - 1));
int replace_pos = sum_rhs_component.find("?", 0);
int replace_pos = repeated_component.find("?", 0);
for (size_t i = 1; i < input_size; i++) {
std::string append_str =
sum_rhs_component.replace(replace_pos, 1, std::to_string(i));
std::string append_str = repeated_component;
append_str.replace(replace_pos, 1, std::to_string(i));
sum_rhs = sum_rhs + append_str;
}
return sum_rhs;
......
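For reference, the change to ExpandMultivariateTemplate above copies the repeated component before substituting the index, so the template fragment is no longer mutated across iterations. Below is a minimal standalone sketch of the same expansion behavior; it is not the Paddle source, and the function name `ExpandTemplate` and the `main()` driver are illustrative only:

```cpp
#include <iostream>
#include <string>

// Expand the repeated "[...]" part of a multivariate template such as
// "${0}[ + ${?}]". The "?" placeholder is replaced by the input index,
// starting from 1, and the fragment is appended input_size - 1 times.
std::string ExpandTemplate(const std::string& rhs, size_t input_size) {
  size_t start_pos = rhs.find('[');
  size_t end_pos = rhs.find(']');
  std::string expanded = rhs.substr(0, start_pos);
  std::string repeated = rhs.substr(start_pos + 1, end_pos - start_pos - 1);
  size_t replace_pos = repeated.find('?');
  for (size_t i = 1; i < input_size; ++i) {
    std::string append_str = repeated;  // copy, so the template stays intact
    append_str.replace(replace_pos, 1, std::to_string(i));
    expanded += append_str;
  }
  return expanded;
}

int main() {
  // With 4 inputs, "${0}[ + ${?}]" expands to "${0} + ${1} + ${2} + ${3}".
  std::cout << ExpandTemplate("${0}[ + ${?}]", 4) << std::endl;
  return 0;
}
```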
......@@ -20,20 +20,26 @@ namespace ir {
namespace fusion_group {
static constexpr char predefined_cuda_functions_fp32[] = R"(
__device__ inline float real_exp(float x) { return ::expf(x); }
__device__ inline float real_log(float x) { return ::logf(x); }
__device__ inline float Max(float x, float y) { return fmaxf(x, y); }
__device__ inline float Exp(float x) { return expf(x); }
__device__ inline float Log(float x) { return logf(x); }
__device__ inline float Sqrt(float x) { return sqrtf(x); }
)";
static constexpr char predefined_cuda_functions_fp64[] = R"(
__device__ inline double real_exp(double x) { return ::exp(x); }
__device__ inline double real_log(double x) { return ::log(x); }
__device__ inline double Max(double x, double y) { return fmax(x, y); }
__device__ inline double Exp(double x) { return exp(x); }
__device__ inline double Log(double x) { return log(x); }
__device__ inline double Sqrt(double x) { return sqrt(x); }
)";
static constexpr char predefined_cuda_functions_fp16[] = R"(
__device__ inline float real_exp(float x) { return ::expf(x); }
__device__ inline float real_log(float x) { return ::logf(x); }
__device__ inline float Max(float x, float y) { return fmaxf(x, y); }
__device__ inline float Exp(float x) { return expf(x); }
__device__ inline float Log(float x) { return logf(x); }
__device__ inline float Sqrt(float x) { return sqrtf(x); }
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
......
......@@ -60,52 +60,41 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
return l.size() != 0U && r.size() != 0U && l == r;
}
bool GroupDetector::IsFusionGroupOp(const Node* n) {
if (!(n && n->IsOp() && n->Op())) return false;
bool GroupDetector::CheckPrecondition(const Node* n) {
auto check_data_type = [&](const std::vector<Node*>& nodes) -> bool {
bool is_first = true;
proto::VarType::Type i_data_type = proto::VarType::FP32;
proto::VarType::Type o_data_type = proto::VarType::FP32;
for (auto* i_node : n->inputs) {
if (!i_node->Var()) return false;
if (i_node->Var()->GetType() != proto::VarType::LOD_TENSOR) {
proto::VarType::Type data_type_0;
for (auto* n : nodes) {
if (n && n->IsVar() && n->Var()) {
if (n->Var()->GetType() != proto::VarType::LOD_TENSOR) {
return false;
}
proto::VarType::Type data_type_i = n->Var()->GetDataType();
if (data_type_i == proto::VarType::FP32 ||
data_type_i == proto::VarType::FP64 ||
data_type_i == proto::VarType::FP16) {
if (is_first) {
i_data_type = i_node->Var()->GetDataType();
data_type_0 = data_type_i;
is_first = false;
} else {
if (i_data_type != i_node->Var()->GetDataType()) return false;
}
}
is_first = true;
for (auto* o_node : n->outputs) {
if (!o_node->Var()) return false;
if (o_node->Var()->GetType() != proto::VarType::LOD_TENSOR) {
} else if (data_type_0 != data_type_i) {
return false;
}
if (is_first) {
o_data_type = o_node->Var()->GetDataType();
is_first = false;
} else {
if (o_data_type != o_node->Var()->GetDataType()) return false;
return false;
}
}
}
if (!(i_data_type == proto::VarType::FP32 ||
i_data_type == proto::VarType::FP64 ||
i_data_type == proto::VarType::FP16) ||
!(o_data_type == proto::VarType::FP32 ||
o_data_type == proto::VarType::FP64 ||
o_data_type == proto::VarType::FP16))
return false;
return true;
};
return n && n->IsOp() && n->Op() && check_data_type(n->inputs) &&
check_data_type(n->outputs);
}
bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
if (IsSpecifiedOp(GetElementwiseOpTypes(), n)) {
// Check whether all inputs have the same shape.
std::vector<int64_t> shape_0;
for (size_t i = 0; i < n->inputs.size(); ++i) {
auto* in_i = n->inputs[i];
......@@ -130,7 +119,7 @@ bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()(
Graph* graph) {
auto teller = [&](const Node* n) -> bool {
return IsFusionGroupOp(n) && IsElementwiseOp(n);
return CheckPrecondition(n) && IsElementwiseOp(n);
};
return SubgraphDetector(graph, teller)();
......
......@@ -25,7 +25,7 @@ namespace fusion_group {
class GroupDetector {
protected:
bool IsFusionGroupOp(const Node* n);
bool CheckPrecondition(const Node* n);
};
class ElementwiseGroupDetector : GroupDetector {
......
......@@ -33,6 +33,8 @@ void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
fusion_group::OperationMap::Init();
int num_elementwise_groups = DetectFusionGroup(graph, 0);
AddStatis(num_elementwise_groups);
LOG(INFO) << "Detect " << num_elementwise_groups
<< " elementwise fusion groups.";
}
}
......@@ -54,7 +56,7 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
VLOG(3) << "subgraph: {\n" << DebugString(subgraph.SortedNodes()) << "}\n";
if (subgraph.IsValid(min_subgraph_size)) {
subgraph.SetFuncName("fused_elementwise_" + std::to_string(index++));
subgraph.SetFuncName("FusedElementwise" + std::to_string(index++));
if (GenerateCode(&subgraph)) {
InsertFusionGroupOp(graph, &subgraph);
num_subgraphs++;
......
......@@ -95,20 +95,29 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// sigmoid:
// out = f(x) = 1.0 / (1.0 + exp(-x))
// dx = dout * out * (1 - out)
insert_handler("sigmoid", "1.0 / (1.0 + real_exp(- ${0}))",
insert_handler("sigmoid", "1.0 / (1.0 + Exp(- ${0}))",
{"${2} * ${1} * (1.0 - ${1})"});
// tanh:
// out = f(x) = 2.0 / (1.0 + exp(-2.0 * x)) - 1.0;
// dx = dout * (1 - out * out)
insert_handler("tanh", "2.0 / (1.0 + real_exp(-2.0 * ${0})) - 1.0",
insert_handler("tanh", "2.0 / (1.0 + Exp(-2.0 * ${0})) - 1.0",
{"${2} * (1.0 - ${1} * ${1})"});
// cast
// out = static_cast<T>(d)
// dx = static_cast<T>(d_out)
// cast:
// out = static_cast<T>(x)
// TODO(wangchaochaohu): This is not the complete definition of
// the cast Op. We need to refine it later.
insert_handler("cast", "${0}", {"${0}"});
insert_handler("cast", "${0}", {});
// sqrt:
// out = x^(1/2)
// dx = dout * 0.5 / out
insert_handler("sqrt", "Sqrt(${0})", {"${2} * 0.5 / ${1}"});
// square:
// out = x^2
// dx = dout * 2.0 * x
insert_handler("square", "${0} * ${0}", {"${2} * 2.0 * ${0}"});
}
void OperationMap::InsertBinaryElementwiseOperations() {
......@@ -168,9 +177,13 @@ void OperationMap::InsertMultivariateElementwiseOperations() {
Insert(type, num_oprands, op_type, expr, grad_exprs, {"X"}, {"Out"});
};
// here [] represent the number of input is positive(>=0).
// if input list size of Sum Op is 3, It will expand as
// ${0} + ${1} + ${2}
// sum:
// out = x_0 + x_1 + ... + x_N-1
//
// For sum with N inputs, the expression inside "[]" will be expanded
// N - 1 times. The ${?} is replaced by the index of the input, starting
// from 1. For example, for sum with 4 inputs, the expanded expression is:
// ${0} + ${1} + ${2} + ${3}
insert_handler("sum", "${0}[ + ${?}]", {});
}
......
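The backward expressions registered above for sqrt and square can be verified with a short chain-rule check (the ${0}/${1}/${2} placeholders stand for x, out, and dout, as in the sigmoid and tanh comments). This worked check is added here for exposition and is not part of the commit:

```latex
% sqrt: out = x^{1/2}
%   => dx = dout * 0.5 / out, i.e. "${2} * 0.5 / ${1}"
\frac{d\,\mathrm{out}}{dx} = \tfrac{1}{2}\,x^{-1/2} = \frac{0.5}{\mathrm{out}}
\qquad\Rightarrow\qquad
dx = \mathrm{dout}\cdot\frac{0.5}{\mathrm{out}}

% square: out = x^2
%   => dx = dout * 2.0 * x, i.e. "${2} * 2.0 * ${0}"
\frac{d\,\mathrm{out}}{dx} = 2x
\qquad\Rightarrow\qquad
dx = \mathrm{dout}\cdot 2x
```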
......@@ -38,7 +38,6 @@ class PassTest(unittest.TestCase):
self.pass_attrs = {}
self.fused_op_type = None
self.num_fused_ops = -1
self.backward = True
np.random.seed(123)
random.seed(124)
......@@ -49,7 +48,11 @@ class PassTest(unittest.TestCase):
places.append(fluid.CUDAPlace(0))
return places
def append_gradinets(self, outs):
def grad(self, var):
grad_name = var.name + "@GRAD"
return self.main_program.global_block().var(grad_name)
def append_gradients(self, outs):
with fluid.program_guard(self.main_program, self.startup_program):
loss = fluid.layers.mean(outs)
fluid.backward.append_backward(loss)
......
......@@ -35,15 +35,12 @@ class FusionGroupPassTest(PassTest):
# subgraph with 2 op nodes
tmp_2 = layers.relu(tmp_0 + tmp_1)
self.num_fused_ops = 1
self.fetch_list = [tmp_2.name, tmp_1.name + "@GRAD"]
self.append_gradients(tmp_2)
if self.backward:
self.append_gradinets(tmp_2)
self.num_fused_ops = 2
self.fetch_list = [tmp_2, self.grad(tmp_1)]
def setUp(self):
self.backward = True
self.build_program("float32")
self.feeds = self._feed_random_data(self.feed_vars)
self.pass_names = "fusion_group_pass"
......@@ -91,13 +88,10 @@ class FusionGroupPassTest1(FusionGroupPassTest):
self.feed_vars[2]) * layers.tanh(self.feed_vars[3])
tmp_2 = layers.tanh(tmp_1) + layers.sigmoid(self.feed_vars[4])
if self.backward:
self.append_gradinets(tmp_2)
self.num_fused_ops = 2
else:
self.num_fused_ops = 1
self.append_gradients(tmp_2)
self.fetch_list = [tmp_2.name, tmp_0.name + "@GRAD"]
self.num_fused_ops = 2
self.fetch_list = [tmp_2, self.grad(tmp_0)]
class FusionGroupPassTest2(FusionGroupPassTest):
......@@ -115,20 +109,11 @@ class FusionGroupPassTest2(FusionGroupPassTest):
tmp_2 = layers.relu(layers.sigmoid(self.feed_vars[3]))
tmp_3 = layers.mul(tmp_1, tmp_2)
self.num_fused_ops = 2
self.fetch_list = [tmp_3.name]
#TODO(wangchaochaohu): we need to deal with the condition of stop gradient
if self.backward:
self.append_gradinets(tmp_3)
self.num_fused_ops = 3
# TODO(wangchaochaohu): support the case when some vars are set
# stop_gradient = True.
def setUp(self):
self.backward = False
self.build_program("float32")
self.feeds = self._feed_random_data(self.feed_vars)
self.pass_names = "fusion_group_pass"
self.fused_op_type = "fusion_group"
self.num_fused_ops = 2
self.fetch_list = [tmp_3]
class FusionGroupPassTestFP64(FusionGroupPassTest):
......@@ -147,32 +132,41 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
fluid.data(
name="data2", shape=[128, 128], dtype=dtype))
# subgraph with 2 op nodes
tmp_0 = self.feed_vars[0] * self.feed_vars[1]
tmp_1 = layers.mul(tmp_0, self.feed_vars[2])
tmp_3 = layers.cast(tmp_1, dtype="float16")
tmp_2 = layers.cast(tmp_0, dtype="float16")
tmp_4 = layers.relu(tmp_2 + tmp_3)
tmp_1 = layers.cast(tmp_0, dtype="float16")
tmp_2 = layers.mul(tmp_0, self.feed_vars[2])
# subgraph with 4 op nodes
tmp_3 = layers.cast(tmp_2, dtype="float16")
tmp_4 = layers.relu(tmp_1 + tmp_3)
tmp_5 = layers.cast(tmp_4, dtype=dtype)
self.num_fused_ops = 1
self.fetch_list = [tmp_5.name]
self.append_gradients(tmp_5)
if self.backward:
self.num_fused_ops = 4
self.append_gradinets(tmp_5)
self.num_fused_ops = 3
self.fetch_list = [tmp_5, self.grad(tmp_0)]
class FusionGroupPassSumTest(FusionGroupPassTest):
def build_program(self, dtype):
with fluid.program_guard(self.main_program, self.startup_program):
self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 5)
self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 3)
self.feed_vars.append(
fluid.data(
name="data3", shape=[128, 128], dtype=dtype))
tmp_0 = layers.elementwise_add(self.feed_vars[0], self.feed_vars[1])
tmp_1 = layers.sum([tmp_0, self.feed_vars[2], self.feed_vars[3]])
tmp_2 = layers.sum([tmp_1, self.feed_vars[4]])
# subgraph with 2 op nodes
tmp_0 = layers.sum(
[self.feed_vars[0], self.feed_vars[1], self.feed_vars[2]])
tmp_1 = layers.sqrt(tmp_0)
tmp_2 = layers.mul(tmp_0, self.feed_vars[3])
# subgraph with 2 op nodes
tmp_3 = layers.square(layers.sum([tmp_1, tmp_2]))
self.append_gradients(tmp_3)
self.fetch_list = [tmp_0, tmp_1, tmp_2]
self.num_fused_ops = 1
self.num_fused_ops = 3
self.fetch_list = [tmp_3, self.grad(tmp_0)]
class FusionGroupPassCastTest(FusionGroupPassTest):
......@@ -184,12 +178,10 @@ class FusionGroupPassCastTest(FusionGroupPassTest):
tmp_1 = layers.cast(tmp_0, dtype="double")
tmp_2 = layers.cast(tmp_1, dtype="float32")
self.fetch_list = [tmp_2.name, tmp_1.name + "@GRAD"]
self.num_fused_ops = 1
self.append_gradients(tmp_2)
if self.backward:
self.num_fused_ops = 2
self.append_gradinets(tmp_2)
self.fetch_list = [tmp_2, self.grad(tmp_0)]
def setUp(self):
self.build_program("float64")
......