Unverified commit 3757e068, authored by wangchaochaohu, committed by GitHub

Add Unittest for backward of fusion group (#22932)

* add fusion group test for backward and refine code
Parent 8fdcb43f
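This change teaches the fusion_group pass tests to cover the backward graph: each test program now appends gradient ops through a shared helper and fetches gradient variables by name (the "@GRAD" suffix). Below is a minimal, self-contained sketch of that mechanism under the fluid 1.x static-graph API; the program and variable names are illustrative, not taken from the patch.

import numpy as np
import paddle.fluid as fluid

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    x = fluid.data(name="x", shape=[32, 128], dtype="float32")
    # fluid.data marks inputs stop_gradient=True by default; clear it so
    # gradients propagate through the subgraph.
    x.stop_gradient = False
    tmp = fluid.layers.sigmoid(x)
    out = fluid.layers.relu(tmp)
    # Reduce to a scalar loss and append the backward (gradient) ops,
    # mirroring the append_gradinets() helper added in this patch.
    loss = fluid.layers.mean(out)
    fluid.backward.append_backward(loss)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
feed = {"x": np.random.random((32, 128)).astype("float32")}
# Gradient variables are registered under "<forward_name>@GRAD", so they can
# be fetched by name alongside the forward outputs.
outs = exe.run(main, feed=feed, fetch_list=[out.name, tmp.name + "@GRAD"])

The append_gradinets() helper added to PassTest wraps exactly this mean + append_backward step, so each test class only decides whether self.backward is set and how many fused ops to expect.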
@@ -24,7 +24,7 @@ namespace framework {
 namespace ir {
 namespace fusion_group {
-std::string ExtractDataType(const std::vector<Node*> nodes) {
+std::string ExtractDataType(const std::vector<Node*>& nodes) {
   std::string dtype_str = "float";
   auto data_type = nodes.back()->Var()->GetDataType();
@@ -98,6 +98,7 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
   std::vector<int> output_ids;
   std::vector<std::string> output_names =
       OperationMap::Instance().Get(op->Type()).output_names;
+
   for (auto& name : output_names) {
     PADDLE_ENFORCE_EQ(
         op->Output(name).size(), 1U,
@@ -125,9 +126,10 @@ std::string CodeGenerator::Generate(
     std::string func_name,
     const std::vector<OperationExpression>& expressions) {
   // TODO(liuyiqun): Check whether all expressions are elementwise operations.
-  std::set<int> input_ids = DistilInputIds(expressions);
-  std::set<int> output_ids = DistilOutputIds(expressions);
-  std::unordered_map<int, std::string> dtypes = DistilDtypes(expressions);
+  std::set<int> input_ids = std::move(DistilInputIds(expressions));
+  std::set<int> output_ids = std::move(DistilOutputIds(expressions));
+  std::unordered_map<int, std::string> dtypes =
+      std::move(DistilDtypes(expressions));
   TemplateVariable template_var;
   template_var.Add("func_name", func_name);
   template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtypes));
@@ -211,7 +213,7 @@ std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
 // we get the parameter list code for the expression information
 std::string CodeGenerator::EmitParameters(
     const std::set<int>& input_ids, const std::set<int>& output_ids,
-    std::unordered_map<int, std::string> dtypes) {
+    const std::unordered_map<int, std::string>& dtypes) const {
   std::stringstream ret;
   ret << "int N, ";
@@ -219,13 +221,13 @@ std::string CodeGenerator::EmitParameters(
   // from the input list.
   for (auto id : input_ids) {
     if (output_ids.find(id) == output_ids.end()) {
-      ret << dtypes[id] << "* " << ArgName(id) << ", ";
+      ret << dtypes.at(id) << "* " << ArgName(id) << ", ";
     }
   }
 
   size_t index = 0;
   for (auto id : output_ids) {
-    ret << dtypes[id] << "* " << ArgName(id);
+    ret << dtypes.at(id) << "* " << ArgName(id);
     if (index != output_ids.size() - 1) {
       ret << ", ";
     }
@@ -238,7 +240,7 @@ std::string CodeGenerator::EmitParameters(
 std::string CodeGenerator::EmitComputeBody(
     const std::vector<OperationExpression>& expressions,
     const std::set<int>& input_ids, const std::set<int>& output_ids,
-    std::unordered_map<int, std::string> dtypes) {
+    const std::unordered_map<int, std::string>& dtypes) const {
   std::ostringstream compute;
   std::unordered_set<int> used;
   for (size_t i = 0; i < expressions.size(); i++) {
@@ -251,7 +253,8 @@ std::string CodeGenerator::EmitComputeBody(
   for (auto id : input_ids) {
     if (output_ids.find(id) == output_ids.end() &&
         used.find(id) != used.end()) {
-      load << dtypes[id] << " " << TmpName(id) << " = " << VarName(id) << ";";
+      load << dtypes.at(id) << " " << TmpName(id) << " = " << VarName(id)
+           << ";";
     }
   }
   // Store temporal variables to memory.
......
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <set>
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>
 #include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
 #include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
@@ -46,14 +47,14 @@ class CodeGenerator {
       const std::vector<OperationExpression>& expressions);
 
   // we get the parameter list code for the expression information
-  std::string EmitParameters(const std::set<int>& input_ids,
-                             const std::set<int>& output_ids,
-                             std::unordered_map<int, std::string> dtypes);
+  std::string EmitParameters(
+      const std::set<int>& input_ids, const std::set<int>& output_ids,
+      const std::unordered_map<int, std::string>& dtypes) const;
 
   std::string EmitComputeBody(
       const std::vector<OperationExpression>& expressions,
       const std::set<int>& input_ids, const std::set<int>& output_ids,
-      std::unordered_map<int, std::string> dtypes);
+      const std::unordered_map<int, std::string>& dtypes) const;
 
   // Encode all var nodes in the subgraph with an unique number.
   std::unordered_map<std::string, int> EncodeVarNodes(SubGraph* subgraph);
......
@@ -25,7 +25,7 @@ namespace operators {
 static void MutableMultiTypeData(
     std::vector<paddle::framework::LoDTensor*>* var,
     const std::vector<std::string>& data_type, const platform::Place& place) {
-  for (size_t i = 0; i < (*var).size(); i++) {
+  for (size_t i = 0; i < var->size(); i++) {
     if (data_type[i] == "float") {
       (*var)[i]->mutable_data<float>(place);
     } else if (data_type[i] == "double") {
......
@@ -38,6 +38,7 @@ class PassTest(unittest.TestCase):
         self.pass_attrs = {}
         self.fused_op_type = None
         self.num_fused_ops = -1
+        self.backward = True
         np.random.seed(123)
         random.seed(124)
@@ -48,6 +49,11 @@ class PassTest(unittest.TestCase):
             places.append(fluid.CUDAPlace(0))
         return places
 
+    def append_gradinets(self, outs):
+        with fluid.program_guard(self.main_program, self.startup_program):
+            loss = fluid.layers.mean(outs)
+            fluid.backward.append_backward(loss)
+
     def check_output(self, startup_on_cpu=False, atol=1e-5):
         '''
         Check whether the fetched outputs of the origin program and the
@@ -143,7 +149,7 @@ class PassTest(unittest.TestCase):
                 np.allclose(
                     outs_opt[i], outs[i], atol=atol),
                 "Output < {} > has diff at {}, expected {} but got {}".format(
-                    self.fetch_list[i].name, str(place), outs_opt[i], outs[i]))
+                    self.fetch_list[i], str(place), outs_opt[i], outs[i]))
 
     def _check_fused_ops(self, program):
         '''
......
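Because append_backward creates gradient variables that exist in the program only under derived names ("<forward_name>@GRAD"), the tests now keep fetch_list as plain name strings, which is why the assertion message above formats self.fetch_list[i] directly instead of self.fetch_list[i].name. The sketch below shows how those derived names can be checked against a program before fetching; grad_var_names is a hypothetical helper, not part of this patch.

import paddle.fluid as fluid

def grad_var_names(program, forward_var_names):
    # append_backward registers each gradient under "<forward_name>@GRAD";
    # collect the ones that were actually created in the global block.
    names = []
    for name in forward_var_names:
        grad_name = name + "@GRAD"
        if program.global_block().has_var(grad_name):
            names.append(grad_name)
    return names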
@@ -35,10 +35,15 @@ class FusionGroupPassTest(PassTest):
             # subgraph with 2 op nodes
             tmp_2 = layers.relu(tmp_0 + tmp_1)
 
-            self.fetch_list = [tmp_2]
             self.num_fused_ops = 1
+            self.fetch_list = [tmp_2.name, tmp_1.name + "@GRAD"]
+            if self.backward:
+                self.append_gradinets(tmp_2)
+                self.num_fused_ops = 2
 
     def setUp(self):
+        self.backward = True
         self.build_program("float32")
         self.feeds = self._feed_random_data(self.feed_vars)
         self.pass_names = "fusion_group_pass"
@@ -86,8 +91,13 @@ class FusionGroupPassTest1(FusionGroupPassTest):
                 self.feed_vars[2]) * layers.tanh(self.feed_vars[3])
             tmp_2 = layers.tanh(tmp_1) + layers.sigmoid(self.feed_vars[4])
 
-            self.fetch_list = [tmp_1, tmp_2]
-            self.num_fused_ops = 1
+            if self.backward:
+                self.append_gradinets(tmp_2)
+                self.num_fused_ops = 2
+            else:
+                self.num_fused_ops = 1
+
+            self.fetch_list = [tmp_2.name, tmp_0.name + "@GRAD"]
 
 
 class FusionGroupPassTest2(FusionGroupPassTest):
@@ -98,15 +108,27 @@ class FusionGroupPassTest2(FusionGroupPassTest):
             fluid.data(
                 name="data3", shape=[128, 32], dtype=dtype))
 
             # subgraph with 3 op nodes
-            tmp_1 = layers.relu(
-                (self.feed_vars[0] - self.feed_vars[1]) * self.feed_vars[2])
+            tmp_0 = self.feed_vars[0] + self.feed_vars[1]
+            tmp_1 = layers.relu(self.feed_vars[2] * tmp_0)
             # subgraph with 2 op nodes
             tmp_2 = layers.relu(layers.sigmoid(self.feed_vars[3]))
             tmp_3 = layers.mul(tmp_1, tmp_2)
 
-            self.fetch_list = [tmp_1, tmp_2, tmp_3]
             self.num_fused_ops = 2
+            self.fetch_list = [tmp_3.name]
+
+            # TODO(wangchaochaohu): we need to deal with the condition of stop gradient
+            if self.backward:
+                self.append_gradinets(tmp_3)
+                self.num_fused_ops = 3
+
+    def setUp(self):
+        self.backward = False
+        self.build_program("float32")
+        self.feeds = self._feed_random_data(self.feed_vars)
+        self.pass_names = "fusion_group_pass"
+        self.fused_op_type = "fusion_group"
 
 
 class FusionGroupPassTestFP64(FusionGroupPassTest):
@@ -132,8 +154,12 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
             tmp_4 = layers.relu(tmp_2 + tmp_3)
             tmp_5 = layers.cast(tmp_4, dtype=dtype)
 
-            self.fetch_list = [tmp_0, tmp_1, tmp_2, tmp_3, tmp_4, tmp_5]
-            self.num_fused_ops = 2
+            self.num_fused_ops = 1
+            self.fetch_list = [tmp_5.name]
+            if self.backward:
+                self.num_fused_ops = 4
+                self.append_gradinets(tmp_5)
 
 
 class FusionGroupPassSumTest(FusionGroupPassTest):
@@ -158,9 +184,13 @@ class FusionGroupPassCastTest(FusionGroupPassTest):
             tmp_1 = layers.cast(tmp_0, dtype="double")
             tmp_2 = layers.cast(tmp_1, dtype="float32")
 
-            self.fetch_list = [tmp_0, tmp_1, tmp_2]
+            self.fetch_list = [tmp_2.name, tmp_1.name + "@GRAD"]
             self.num_fused_ops = 1
+            if self.backward:
+                self.num_fused_ops = 2
+                self.append_gradinets(tmp_2)
 
     def setUp(self):
         self.build_program("float64")
         self.feeds = self._feed_random_data(self.feed_vars)
......