Unverified · Commit 3af47711 authored by Yiqun Liu, committed by GitHub

Add the detection and code-generation of sqrt and square in fusion_group (#23095)

Parent d066d6f9
@@ -25,15 +25,21 @@ namespace ir {
 namespace fusion_group {
 
 std::string ExtractDataType(const std::vector<Node*>& nodes) {
-  std::string dtype_str = "float";
-  auto data_type = nodes.back()->Var()->GetDataType();
-
-  if (data_type == proto::VarType::FP32) {
-    dtype_str = "float";
-  } else if (data_type == proto::VarType::FP64) {
-    dtype_str = "double";
-  } else if (data_type == proto::VarType::FP16) {
-    dtype_str = "float16";
+  std::string dtype_str = "";
+  for (const auto* n : nodes) {
+    if (n && n->IsVar() && n->Var()) {
+      // The data type of all inputs/outputs must be the same, which are
+      // checked when detecting the subgraph.
+      auto dtype = n->Var()->GetDataType();
+      if (dtype == proto::VarType::FP32) {
+        dtype_str = "float";
+      } else if (dtype == proto::VarType::FP64) {
+        dtype_str = "double";
+      } else if (dtype == proto::VarType::FP16) {
+        dtype_str = "float16";
+      }
+      break;
+    }
   }
   return dtype_str;
@@ -80,7 +86,6 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
     for (auto& name : input_names) {
       // Some input vars are not used in grad ops, such as
       // "elementwise_add_grad", where "X", "Y" and "Out" are not used.
-
      if ((HasInput(node, name) && op->Input(name).size() >= 1U)) {
         for (size_t i = 0; i < op->Input(name).size(); i++) {
           PADDLE_ENFORCE_NE(
......
@@ -38,13 +38,13 @@ static std::string ExpandMultivariateTemplate(const std::string rhs,
   int start_pos = rhs.find("[", 0);
   int end_pos = rhs.find("]", 0);
   std::string sum_rhs = rhs.substr(0, start_pos);
-  std::string sum_rhs_component =
+  std::string repeated_component =
       rhs.substr(start_pos + 1, (end_pos - start_pos - 1));
-  int replace_pos = sum_rhs_component.find("?", 0);
+  int replace_pos = repeated_component.find("?", 0);
   for (size_t i = 1; i < input_size; i++) {
-    std::string append_str =
-        sum_rhs_component.replace(replace_pos, 1, std::to_string(i));
+    std::string append_str = repeated_component;
+    append_str.replace(replace_pos, 1, std::to_string(i));
     sum_rhs = sum_rhs + append_str;
   }
   return sum_rhs;
......
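For context on how the bracket template above is consumed: the part of the expression inside "[ ]" is repeated once per additional input, with "${?}" replaced by the input index. A minimal standalone sketch of that idea follows (simplified helper name and no error handling; this is an illustration, not the exact Paddle function):

// Simplified sketch of the bracket-template expansion shown in the hunk above.
// Assumption: rhs contains exactly one "[...]" section with a single "?".
#include <iostream>
#include <string>

static std::string ExpandTemplate(const std::string& rhs, size_t input_size) {
  size_t start_pos = rhs.find('[');
  size_t end_pos = rhs.find(']');
  std::string expanded = rhs.substr(0, start_pos);
  // The part inside "[]" is repeated once per extra input.
  std::string repeated = rhs.substr(start_pos + 1, end_pos - start_pos - 1);
  size_t replace_pos = repeated.find('?');
  for (size_t i = 1; i < input_size; ++i) {
    std::string append_str = repeated;
    append_str.replace(replace_pos, 1, std::to_string(i));  // "${?}" -> "${i}"
    expanded += append_str;
  }
  return expanded;
}

int main() {
  // Prints: ${0} + ${1} + ${2}
  std::cout << ExpandTemplate("${0}[ + ${?}]", 3) << std::endl;
  return 0;
}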
@@ -20,20 +20,26 @@ namespace ir {
 namespace fusion_group {
 
 static constexpr char predefined_cuda_functions_fp32[] = R"(
-__device__ inline float real_exp(float x) { return ::expf(x); }
-__device__ inline float real_log(float x) { return ::logf(x); }
+__device__ inline float Max(float x, float y) { return fmaxf(x, y); }
+__device__ inline float Exp(float x) { return expf(x); }
+__device__ inline float Log(float x) { return logf(x); }
+__device__ inline float Sqrt(float x) { return sqrtf(x); }
 )";
 
 static constexpr char predefined_cuda_functions_fp64[] = R"(
-__device__ inline double real_exp(double x) { return ::exp(x); }
-__device__ inline double real_log(double x) { return ::log(x); }
+__device__ inline double Max(double x, double y) { return fmax(x, y); }
+__device__ inline double Exp(double x) { return exp(x); }
+__device__ inline double Log(double x) { return log(x); }
+__device__ inline double Sqrt(double x) { return sqrt(x); }
 )";
 
 static constexpr char predefined_cuda_functions_fp16[] = R"(
-__device__ inline float real_exp(float x) { return ::expf(x); }
-__device__ inline float real_log(float x) { return ::logf(x); }
+__device__ inline float Max(float x, float y) { return fmaxf(x, y); }
+__device__ inline float Exp(float x) { return expf(x); }
+__device__ inline float Log(float x) { return logf(x); }
+__device__ inline float Sqrt(float x) { return sqrtf(x); }
 
 #define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
 #define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
......
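A note on these helpers: the generated kernel source is assembled by prepending one of the dtype-specific blocks above, so an emitted expression can always spell Sqrt/Exp/Log/Max the same way and the float or double intrinsic is picked by whichever preamble was included. The following is a hand-written, hedged illustration of that pattern (toy kernel and launch code only; this is not what fusion_group emits):

// Hypothetical illustration: fp32-style helpers prepended to a tiny elementwise
// kernel, so Sqrt/Exp below resolve to the float intrinsics sqrtf/expf.
#include <cstdio>
#include <cuda_runtime.h>

__device__ inline float Exp(float x) { return expf(x); }
__device__ inline float Sqrt(float x) { return sqrtf(x); }

// Toy kernel (not generated by the pass): y = sqrt(x), z = sigmoid(x).
__global__ void ToyElementwise(const float* x, float* y, float* z, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    y[i] = Sqrt(x[i]);
    z[i] = 1.0f / (1.0f + Exp(-x[i]));
  }
}

int main() {
  const int n = 4;
  float h_x[n] = {1.0f, 4.0f, 9.0f, 16.0f};
  float h_y[n], h_z[n];
  float *d_x, *d_y, *d_z;
  cudaMalloc(&d_x, n * sizeof(float));
  cudaMalloc(&d_y, n * sizeof(float));
  cudaMalloc(&d_z, n * sizeof(float));
  cudaMemcpy(d_x, h_x, n * sizeof(float), cudaMemcpyHostToDevice);
  ToyElementwise<<<1, n>>>(d_x, d_y, d_z, n);
  cudaMemcpy(h_y, d_y, n * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_z, d_z, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) {
    printf("sqrt(%g) = %g, sigmoid(%g) = %g\n", h_x[i], h_y[i], h_x[i], h_z[i]);
  }
  cudaFree(d_x);
  cudaFree(d_y);
  cudaFree(d_z);
  return 0;
}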
@@ -60,52 +60,41 @@ static bool IsEqualAndNotEmpty(const std::vector<int64_t>& l,
   return l.size() != 0U && r.size() != 0U && l == r;
 }
 
-bool GroupDetector::IsFusionGroupOp(const Node* n) {
-  if (!(n && n->IsOp() && n->Op())) return false;
-
-  bool is_first = true;
-  proto::VarType::Type i_data_type = proto::VarType::FP32;
-  proto::VarType::Type o_data_type = proto::VarType::FP32;
-
-  for (auto* i_node : n->inputs) {
-    if (!i_node->Var()) return false;
-    if (i_node->Var()->GetType() != proto::VarType::LOD_TENSOR) {
-      return false;
-    }
-    if (is_first) {
-      i_data_type = i_node->Var()->GetDataType();
-      is_first = false;
-    } else {
-      if (i_data_type != i_node->Var()->GetDataType()) return false;
-    }
-  }
-
-  is_first = true;
-  for (auto* o_node : n->outputs) {
-    if (!o_node->Var()) return false;
-    if (o_node->Var()->GetType() != proto::VarType::LOD_TENSOR) {
-      return false;
-    }
-    if (is_first) {
-      o_data_type = o_node->Var()->GetDataType();
-      is_first = false;
-    } else {
-      if (o_data_type != o_node->Var()->GetDataType()) return false;
-    }
-  }
-
-  if (!(i_data_type == proto::VarType::FP32 ||
-        i_data_type == proto::VarType::FP64 ||
-        i_data_type == proto::VarType::FP16) ||
-      !(o_data_type == proto::VarType::FP32 ||
-        o_data_type == proto::VarType::FP64 ||
-        o_data_type == proto::VarType::FP16))
-    return false;
-
-  return true;
+bool GroupDetector::CheckPrecondition(const Node* n) {
+  auto check_data_type = [&](const std::vector<Node*>& nodes) -> bool {
+    bool is_first = true;
+    proto::VarType::Type data_type_0;
+    for (auto* n : nodes) {
+      if (n && n->IsVar() && n->Var()) {
+        if (n->Var()->GetType() != proto::VarType::LOD_TENSOR) {
+          return false;
+        }
+
+        proto::VarType::Type data_type_i = n->Var()->GetDataType();
+        if (data_type_i == proto::VarType::FP32 ||
+            data_type_i == proto::VarType::FP64 ||
+            data_type_i == proto::VarType::FP16) {
+          if (is_first) {
+            data_type_0 = data_type_i;
+            is_first = false;
+          } else if (data_type_0 != data_type_i) {
+            return false;
+          }
+        } else {
+          return false;
+        }
+      }
+    }
+    return true;
+  };
+
+  return n && n->IsOp() && n->Op() && check_data_type(n->inputs) &&
+         check_data_type(n->outputs);
 }
 
 bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
   if (IsSpecifiedOp(GetElementwiseOpTypes(), n)) {
+    // Check whether all inputs have the same shape.
     std::vector<int64_t> shape_0;
     for (size_t i = 0; i < n->inputs.size(); ++i) {
       auto* in_i = n->inputs[i];
@@ -130,7 +119,7 @@ bool ElementwiseGroupDetector::IsElementwiseOp(const Node* n) {
 
 std::vector<std::vector<Node*>> ElementwiseGroupDetector::operator()(
     Graph* graph) {
   auto teller = [&](const Node* n) -> bool {
-    return IsFusionGroupOp(n) && IsElementwiseOp(n);
+    return CheckPrecondition(n) && IsElementwiseOp(n);
   };
 
   return SubgraphDetector(graph, teller)();
......
@@ -25,7 +25,7 @@ namespace fusion_group {
 
 class GroupDetector {
  protected:
-  bool IsFusionGroupOp(const Node* n);
+  bool CheckPrecondition(const Node* n);
 };
 
 class ElementwiseGroupDetector : GroupDetector {
......
@@ -33,6 +33,8 @@ void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
     fusion_group::OperationMap::Init();
     int num_elementwise_groups = DetectFusionGroup(graph, 0);
     AddStatis(num_elementwise_groups);
+    LOG(INFO) << "Detect " << num_elementwise_groups
+              << " elementwise fusion groups.";
   }
 }
@@ -54,7 +56,7 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
     VLOG(3) << "subgraph: {\n" << DebugString(subgraph.SortedNodes()) << "}\n";
 
     if (subgraph.IsValid(min_subgraph_size)) {
-      subgraph.SetFuncName("fused_elementwise_" + std::to_string(index++));
+      subgraph.SetFuncName("FusedElementwise" + std::to_string(index++));
       if (GenerateCode(&subgraph)) {
         InsertFusionGroupOp(graph, &subgraph);
         num_subgraphs++;
......
@@ -95,20 +95,29 @@ void OperationMap::InsertUnaryElementwiseOperations() {
   // sigmoid:
   //  out = f(x) = 1.0 / (1.0 + exp(-x))
   //  dx = dout * out * (1 - out)
-  insert_handler("sigmoid", "1.0 / (1.0 + real_exp(- ${0}))",
+  insert_handler("sigmoid", "1.0 / (1.0 + Exp(- ${0}))",
                  {"${2} * ${1} * (1.0 - ${1})"});
   // tanh:
   //  out = f(x) = 2.0 / (1.0 + exp(-2.0 * x)) - 1.0;
   //  dx = dout * (1 - out * out)
-  insert_handler("tanh", "2.0 / (1.0 + real_exp(-2.0 * ${0})) - 1.0",
+  insert_handler("tanh", "2.0 / (1.0 + Exp(-2.0 * ${0})) - 1.0",
                  {"${2} * (1.0 - ${1} * ${1})"});
-  // cast
-  //  out = static_cast<T>(d)
-  //  dx = static_cast<T>(d_out)
+  // cast:
+  //  out = static_cast<T>(x)
   // TODO(wangchaochaohu): This is not the compelete definition of
   // cast Op, We need refine it later.
-  insert_handler("cast", "${0}", {"${0}"});
+  insert_handler("cast", "${0}", {});
+  // sqrt:
+  //  out = x^(1/2)
+  //  dx = dout * 0.5 / out
+  insert_handler("sqrt", "Sqrt(${0})", {"${2} * 0.5 / ${1}"});
+  // square:
+  //  out = x^2
+  //  dx = dout * 2.0 * x
+  insert_handler("square", "${0} * ${0}", {"${2} * 2.0 * ${0}"});
 }
 
 void OperationMap::InsertBinaryElementwiseOperations() {
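To make the new sqrt/square entries concrete: judging from the surrounding comments, ${0} stands for the input x, ${1} for the output out, and ${2} for the incoming gradient dout, so the substituted expressions amount to out = Sqrt(x) with dx = dout * 0.5 / out, and out = x * x with dx = dout * 2.0 * x. Below is a small self-contained check of those formulas against a finite difference (illustrative host code under that assumed placeholder convention, not Paddle's generated kernel):

// Illustrative check (not Paddle code) of the registered forward/backward
// expressions, assuming the placeholder convention ${0} = x, ${1} = out, ${2} = dout:
//   sqrt:   out = Sqrt(x),  dx = dout * 0.5 / out
//   square: out = x * x,    dx = dout * 2.0 * x
#include <cmath>
#include <cstdio>

int main() {
  const double x = 2.0, dout = 1.0, eps = 1e-6;

  // sqrt: analytic gradient vs. central finite difference.
  double out = std::sqrt(x);
  double dx = dout * 0.5 / out;  // "${2} * 0.5 / ${1}"
  double fd = (std::sqrt(x + eps) - std::sqrt(x - eps)) / (2 * eps);
  printf("sqrt:   dx = %.6f, finite-diff = %.6f\n", dx, fd);

  // square: analytic gradient vs. central finite difference.
  double out2 = x * x;           // "${0} * ${0}"
  double dx2 = dout * 2.0 * x;   // "${2} * 2.0 * ${0}"
  double fd2 = ((x + eps) * (x + eps) - (x - eps) * (x - eps)) / (2 * eps);
  printf("square: out = %.6f, dx = %.6f, finite-diff = %.6f\n", out2, dx2, fd2);
  return 0;
}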
@@ -168,9 +177,13 @@ void OperationMap::InsertMultivariateElementwiseOperations() {
     Insert(type, num_oprands, op_type, expr, grad_exprs, {"X"}, {"Out"});
   };
 
-  // here [] represent the number of input is positive(>=0).
-  // if input list size of Sum Op is 3, It will expand as
-  // ${0} + ${1} + ${2}
+  // sum:
+  //  out = x_0 + x_1 + ... + x_N-1
+  //
+  // For sum with N inputs, the expression inside "[]" will be expanded
+  // N - 1 times. The ${?} represents the number of inputs starting with is 1.
+  // For example, sum with 4 inputs, the expanded expression is:
+  //  ${0} + ${1} + ${2} + ${3}
   insert_handler("sum", "${0}[ + ${?}]", {});
 }
......
@@ -38,7 +38,6 @@ class PassTest(unittest.TestCase):
         self.pass_attrs = {}
         self.fused_op_type = None
         self.num_fused_ops = -1
-        self.backward = True
 
         np.random.seed(123)
         random.seed(124)
@@ -49,7 +48,11 @@ class PassTest(unittest.TestCase):
             places.append(fluid.CUDAPlace(0))
         return places
 
-    def append_gradinets(self, outs):
+    def grad(self, var):
+        grad_name = var.name + "@GRAD"
+        return self.main_program.global_block().var(grad_name)
+
+    def append_gradients(self, outs):
         with fluid.program_guard(self.main_program, self.startup_program):
             loss = fluid.layers.mean(outs)
             fluid.backward.append_backward(loss)
......
@@ -35,15 +35,12 @@ class FusionGroupPassTest(PassTest):
             # subgraph with 2 op nodes
             tmp_2 = layers.relu(tmp_0 + tmp_1)
 
-            self.num_fused_ops = 1
-            self.fetch_list = [tmp_2.name, tmp_1.name + "@GRAD"]
-            if self.backward:
-                self.append_gradinets(tmp_2)
-                self.num_fused_ops = 2
+            self.append_gradients(tmp_2)
+
+            self.num_fused_ops = 2
+            self.fetch_list = [tmp_2, self.grad(tmp_1)]
 
     def setUp(self):
-        self.backward = True
         self.build_program("float32")
         self.feeds = self._feed_random_data(self.feed_vars)
         self.pass_names = "fusion_group_pass"
@@ -91,13 +88,10 @@ class FusionGroupPassTest1(FusionGroupPassTest):
                 self.feed_vars[2]) * layers.tanh(self.feed_vars[3])
             tmp_2 = layers.tanh(tmp_1) + layers.sigmoid(self.feed_vars[4])
 
-            if self.backward:
-                self.append_gradinets(tmp_2)
-                self.num_fused_ops = 2
-            else:
-                self.num_fused_ops = 1
+            self.append_gradients(tmp_2)
 
-            self.fetch_list = [tmp_2.name, tmp_0.name + "@GRAD"]
+            self.num_fused_ops = 2
+            self.fetch_list = [tmp_2, self.grad(tmp_0)]
 
 
 class FusionGroupPassTest2(FusionGroupPassTest):
@@ -115,20 +109,11 @@ class FusionGroupPassTest2(FusionGroupPassTest):
             tmp_2 = layers.relu(layers.sigmoid(self.feed_vars[3]))
             tmp_3 = layers.mul(tmp_1, tmp_2)
 
-            self.num_fused_ops = 2
-            self.fetch_list = [tmp_3.name]
-            #TODO(wangchaochaohu): we need to deal with the condition of stop gradient
-            if self.backward:
-                self.append_gradinets(tmp_3)
-                self.num_fused_ops = 3
-
-    def setUp(self):
-        self.backward = False
-        self.build_program("float32")
-        self.feeds = self._feed_random_data(self.feed_vars)
-        self.pass_names = "fusion_group_pass"
-        self.fused_op_type = "fusion_group"
+            # TODO(wangchaochaohu): support the case when some vars are set
+            # stop_gradient = True.
+            self.num_fused_ops = 2
+            self.fetch_list = [tmp_3]
 
 
 class FusionGroupPassTestFP64(FusionGroupPassTest):
@@ -147,32 +132,41 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
                 fluid.data(
                     name="data2", shape=[128, 128], dtype=dtype))
 
+            # subgraph with 2 op nodes
             tmp_0 = self.feed_vars[0] * self.feed_vars[1]
-            tmp_1 = layers.mul(tmp_0, self.feed_vars[2])
-            tmp_3 = layers.cast(tmp_1, dtype="float16")
-            tmp_2 = layers.cast(tmp_0, dtype="float16")
-            tmp_4 = layers.relu(tmp_2 + tmp_3)
+            tmp_1 = layers.cast(tmp_0, dtype="float16")
+            tmp_2 = layers.mul(tmp_0, self.feed_vars[2])
+            # subgraph with 4 op nodes
+            tmp_3 = layers.cast(tmp_2, dtype="float16")
+            tmp_4 = layers.relu(tmp_1 + tmp_3)
             tmp_5 = layers.cast(tmp_4, dtype=dtype)
 
-            self.num_fused_ops = 1
-            self.fetch_list = [tmp_5.name]
-            if self.backward:
-                self.num_fused_ops = 4
-                self.append_gradinets(tmp_5)
+            self.append_gradients(tmp_5)
+
+            self.num_fused_ops = 3
+            self.fetch_list = [tmp_5, self.grad(tmp_0)]
 
 
 class FusionGroupPassSumTest(FusionGroupPassTest):
     def build_program(self, dtype):
         with fluid.program_guard(self.main_program, self.startup_program):
-            self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 5)
+            self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 3)
+            self.feed_vars.append(
+                fluid.data(
+                    name="data3", shape=[128, 128], dtype=dtype))
 
-            tmp_0 = layers.elementwise_add(self.feed_vars[0], self.feed_vars[1])
-            tmp_1 = layers.sum([tmp_0, self.feed_vars[2], self.feed_vars[3]])
-            tmp_2 = layers.sum([tmp_1, self.feed_vars[4]])
+            # subgraph with 2 op nodes
+            tmp_0 = layers.sum(
+                [self.feed_vars[0], self.feed_vars[1], self.feed_vars[2]])
+            tmp_1 = layers.sqrt(tmp_0)
+            tmp_2 = layers.mul(tmp_0, self.feed_vars[3])
+            # subgraph with 2 op nodes
+            tmp_3 = layers.square(layers.sum([tmp_1, tmp_2]))
 
-            self.fetch_list = [tmp_0, tmp_1, tmp_2]
-            self.num_fused_ops = 1
+            self.append_gradients(tmp_3)
+
+            self.num_fused_ops = 3
+            self.fetch_list = [tmp_3, self.grad(tmp_0)]
 
 
 class FusionGroupPassCastTest(FusionGroupPassTest):
@@ -184,12 +178,10 @@ class FusionGroupPassCastTest(FusionGroupPassTest):
             tmp_1 = layers.cast(tmp_0, dtype="double")
             tmp_2 = layers.cast(tmp_1, dtype="float32")
 
-            self.fetch_list = [tmp_2.name, tmp_1.name + "@GRAD"]
-            self.num_fused_ops = 1
-            if self.backward:
-                self.num_fused_ops = 2
-                self.append_gradinets(tmp_2)
+            self.append_gradients(tmp_2)
+
+            self.num_fused_ops = 2
+            self.fetch_list = [tmp_2, self.grad(tmp_0)]
 
     def setUp(self):
         self.build_program("float64")
......