Unverified commit 3757e068, authored by wangchaochaohu, committed by GitHub

Add Unittest for backward of fusion group (#22932)

* add fusion group test for backward and refine code
Parent 8fdcb43f
......@@ -24,7 +24,7 @@ namespace framework {
namespace ir {
namespace fusion_group {
std::string ExtractDataType(const std::vector<Node*> nodes) {
std::string ExtractDataType(const std::vector<Node*>& nodes) {
std::string dtype_str = "float";
auto data_type = nodes.back()->Var()->GetDataType();
......@@ -98,6 +98,7 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
std::vector<int> output_ids;
std::vector<std::string> output_names =
OperationMap::Instance().Get(op->Type()).output_names;
for (auto& name : output_names) {
PADDLE_ENFORCE_EQ(
op->Output(name).size(), 1U,
......@@ -125,9 +126,10 @@ std::string CodeGenerator::Generate(
std::string func_name,
const std::vector<OperationExpression>& expressions) {
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
std::set<int> input_ids = DistilInputIds(expressions);
std::set<int> output_ids = DistilOutputIds(expressions);
std::unordered_map<int, std::string> dtypes = DistilDtypes(expressions);
std::set<int> input_ids = std::move(DistilInputIds(expressions));
std::set<int> output_ids = std::move(DistilOutputIds(expressions));
std::unordered_map<int, std::string> dtypes =
std::move(DistilDtypes(expressions));
TemplateVariable template_var;
template_var.Add("func_name", func_name);
template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtypes));
......@@ -211,7 +213,7 @@ std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
// Emit the parameter list of the generated kernel from the expression information.
std::string CodeGenerator::EmitParameters(
const std::set<int>& input_ids, const std::set<int>& output_ids,
std::unordered_map<int, std::string> dtypes) {
const std::unordered_map<int, std::string>& dtypes) const {
std::stringstream ret;
ret << "int N, ";
......@@ -219,13 +221,13 @@ std::string CodeGenerator::EmitParameters(
// from the input list.
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end()) {
ret << dtypes[id] << "* " << ArgName(id) << ", ";
ret << dtypes.at(id) << "* " << ArgName(id) << ", ";
}
}
size_t index = 0;
for (auto id : output_ids) {
ret << dtypes[id] << "* " << ArgName(id);
ret << dtypes.at(id) << "* " << ArgName(id);
if (index != output_ids.size() - 1) {
ret << ", ";
}
......@@ -238,7 +240,7 @@ std::string CodeGenerator::EmitParameters(
std::string CodeGenerator::EmitComputeBody(
const std::vector<OperationExpression>& expressions,
const std::set<int>& input_ids, const std::set<int>& output_ids,
std::unordered_map<int, std::string> dtypes) {
const std::unordered_map<int, std::string>& dtypes) const {
std::ostringstream compute;
std::unordered_set<int> used;
for (size_t i = 0; i < expressions.size(); i++) {
......@@ -251,7 +253,8 @@ std::string CodeGenerator::EmitComputeBody(
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end() &&
used.find(id) != used.end()) {
load << dtypes[id] << " " << TmpName(id) << " = " << VarName(id) << ";";
load << dtypes.at(id) << " " << TmpName(id) << " = " << VarName(id)
<< ";";
}
}
// Store temporal variables to memory.
......
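For reference, a minimal Python sketch of the parameter list that EmitParameters builds for the generated kernel. The ids, dtypes, and the arg<id> naming below are illustrative assumptions, not taken verbatim from the commit: inputs that are not also outputs come first, then the outputs, each rendered as "<dtype>* arg<id>" after the leading "int N".

def emit_parameters(input_ids, output_ids, dtypes):
    # Mimics CodeGenerator::EmitParameters: skip inputs that are also
    # outputs, then append every output; dtypes maps var id -> type string.
    parts = ["int N"]
    for i in sorted(input_ids):
        if i not in output_ids:
            parts.append("%s* arg%d" % (dtypes[i], i))
    for i in sorted(output_ids):
        parts.append("%s* arg%d" % (dtypes[i], i))
    return ", ".join(parts)

# Example with made-up ids: two inputs (0, 1) and one output (2), all float.
print(emit_parameters({0, 1}, {2}, {0: "float", 1: "float", 2: "float"}))
# -> int N, float* arg0, float* arg1, float* arg2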
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
......@@ -46,14 +47,14 @@ class CodeGenerator {
const std::vector<OperationExpression>& expressions);
// Emit the parameter list of the generated kernel from the expression information.
std::string EmitParameters(const std::set<int>& input_ids,
const std::set<int>& output_ids,
std::unordered_map<int, std::string> dtypes);
std::string EmitParameters(
const std::set<int>& input_ids, const std::set<int>& output_ids,
const std::unordered_map<int, std::string>& dtypes) const;
std::string EmitComputeBody(
const std::vector<OperationExpression>& expressions,
const std::set<int>& input_ids, const std::set<int>& output_ids,
std::unordered_map<int, std::string> dtypes);
const std::unordered_map<int, std::string>& dtypes) const;
// Encode all var nodes in the subgraph with a unique number.
std::unordered_map<std::string, int> EncodeVarNodes(SubGraph* subgraph);
......
......@@ -25,7 +25,7 @@ namespace operators {
static void MutableMultiTypeData(
std::vector<paddle::framework::LoDTensor*>* var,
const std::vector<std::string>& data_type, const platform::Place& place) {
for (size_t i = 0; i < (*var).size(); i++) {
for (size_t i = 0; i < var->size(); i++) {
if (data_type[i] == "float") {
(*var)[i]->mutable_data<float>(place);
} else if (data_type[i] == "double") {
......
......@@ -38,6 +38,7 @@ class PassTest(unittest.TestCase):
self.pass_attrs = {}
self.fused_op_type = None
self.num_fused_ops = -1
self.backward = True
np.random.seed(123)
random.seed(124)
......@@ -48,6 +49,11 @@ class PassTest(unittest.TestCase):
places.append(fluid.CUDAPlace(0))
return places
def append_gradients(self, outs):
with fluid.program_guard(self.main_program, self.startup_program):
loss = fluid.layers.mean(outs)
fluid.backward.append_backward(loss)
def check_output(self, startup_on_cpu=False, atol=1e-5):
'''
Check whether the fetched outputs of the origin program and the
......@@ -143,7 +149,7 @@ class PassTest(unittest.TestCase):
np.allclose(
outs_opt[i], outs[i], atol=atol),
"Output < {} > has diff at {}, expected {} but got {}".format(
self.fetch_list[i].name, str(place), outs_opt[i], outs[i]))
self.fetch_list[i], str(place), outs_opt[i], outs[i]))
def _check_fused_ops(self, program):
'''
......
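The new append_gradients helper is the piece the fusion-group backward tests rely on: it reduces a forward output to a scalar with fluid.layers.mean and calls fluid.backward.append_backward, which appends the gradient ops and creates variables named "<var>@GRAD" that the tests can then fetch. A self-contained sketch of that pattern against the 1.x fluid API assumed by these tests (variable names and shapes here are illustrative):

import numpy as np
import paddle.fluid as fluid

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    x = fluid.data(name="x", shape=[128, 32], dtype="float32")
    x.stop_gradient = False  # make sure gradients propagate back through x
    tmp = fluid.layers.relu(x)
    out = fluid.layers.tanh(tmp)
    # Same steps as append_gradients: scalar loss, then append backward ops.
    loss = fluid.layers.mean(out)
    fluid.backward.append_backward(loss)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
feed = {"x": np.random.random((128, 32)).astype("float32")}
# Gradient variables follow the "<name>@GRAD" convention used in the tests.
outs = exe.run(main_program, feed=feed,
               fetch_list=[out.name, tmp.name + "@GRAD"])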
......@@ -35,10 +35,15 @@ class FusionGroupPassTest(PassTest):
# subgraph with 2 op nodes
tmp_2 = layers.relu(tmp_0 + tmp_1)
self.fetch_list = [tmp_2]
self.num_fused_ops = 1
self.fetch_list = [tmp_2.name, tmp_1.name + "@GRAD"]
if self.backward:
self.append_gradients(tmp_2)
self.num_fused_ops = 2
def setUp(self):
self.backward = True
self.build_program("float32")
self.feeds = self._feed_random_data(self.feed_vars)
self.pass_names = "fusion_group_pass"
......@@ -86,8 +91,13 @@ class FusionGroupPassTest1(FusionGroupPassTest):
self.feed_vars[2]) * layers.tanh(self.feed_vars[3])
tmp_2 = layers.tanh(tmp_1) + layers.sigmoid(self.feed_vars[4])
self.fetch_list = [tmp_1, tmp_2]
self.num_fused_ops = 1
if self.backward:
self.append_gradients(tmp_2)
self.num_fused_ops = 2
else:
self.num_fused_ops = 1
self.fetch_list = [tmp_2.name, tmp_0.name + "@GRAD"]
class FusionGroupPassTest2(FusionGroupPassTest):
......@@ -98,15 +108,27 @@ class FusionGroupPassTest2(FusionGroupPassTest):
fluid.data(
name="data3", shape=[128, 32], dtype=dtype))
# subgraph with 3 op nodes
tmp_1 = layers.relu(
(self.feed_vars[0] - self.feed_vars[1]) * self.feed_vars[2])
# subgraph with 3 op nodes
tmp_0 = self.feed_vars[0] + self.feed_vars[1]
tmp_1 = layers.relu(self.feed_vars[2] * tmp_0)
# subgraph with 2 op nodes
tmp_2 = layers.relu(layers.sigmoid(self.feed_vars[3]))
tmp_3 = layers.mul(tmp_1, tmp_2)
self.fetch_list = [tmp_1, tmp_2, tmp_3]
self.num_fused_ops = 2
self.fetch_list = [tmp_3.name]
# TODO(wangchaochaohu): we need to deal with the condition of stop gradient
if self.backward:
self.append_gradients(tmp_3)
self.num_fused_ops = 3
def setUp(self):
self.backward = False
self.build_program("float32")
self.feeds = self._feed_random_data(self.feed_vars)
self.pass_names = "fusion_group_pass"
self.fused_op_type = "fusion_group"
class FusionGroupPassTestFP64(FusionGroupPassTest):
......@@ -132,8 +154,12 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
tmp_4 = layers.relu(tmp_2 + tmp_3)
tmp_5 = layers.cast(tmp_4, dtype=dtype)
self.fetch_list = [tmp_0, tmp_1, tmp_2, tmp_3, tmp_4, tmp_5]
self.num_fused_ops = 2
self.num_fused_ops = 1
self.fetch_list = [tmp_5.name]
if self.backward:
self.num_fused_ops = 4
self.append_gradients(tmp_5)
class FusionGroupPassSumTest(FusionGroupPassTest):
......@@ -158,9 +184,13 @@ class FusionGroupPassCastTest(FusionGroupPassTest):
tmp_1 = layers.cast(tmp_0, dtype="double")
tmp_2 = layers.cast(tmp_1, dtype="float32")
self.fetch_list = [tmp_0, tmp_1, tmp_2]
self.fetch_list = [tmp_2.name, tmp_1.name + "@GRAD"]
self.num_fused_ops = 1
if self.backward:
self.num_fused_ops = 2
self.append_gradients(tmp_2)
def setUp(self):
self.build_program("float64")
self.feeds = self._feed_random_data(self.feed_vars)
......
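All of the tests above follow the same shape: build a small elementwise subgraph, optionally call append_gradients so the gradient ops are generated, and set num_fused_ops to the number of fusion_group ops expected after the pass (the forward subgraph and its gradient subgraph fuse separately). A hypothetical subclass sketch of that pattern, assuming it lives in the same test module as FusionGroupPassTest; the layers, shapes, and expected op count are illustrative and not part of this commit:

import unittest
import paddle.fluid as fluid
import paddle.fluid.layers as layers

class FusionGroupPassBackwardDemo(FusionGroupPassTest):
    def build_program(self, dtype):
        with fluid.program_guard(self.main_program, self.startup_program):
            self.feed_vars = []
            for i in range(3):
                var = fluid.data(name="data%d" % i, shape=[32, 64], dtype=dtype)
                var.stop_gradient = False
                self.feed_vars.append(var)

            # subgraph with 3 op nodes (mul, add, relu)
            tmp_0 = self.feed_vars[0] * self.feed_vars[1]
            tmp_1 = layers.relu(tmp_0 + self.feed_vars[2])

            self.fetch_list = [tmp_1.name, tmp_0.name + "@GRAD"]
            if self.backward:
                self.append_gradients(tmp_1)
                # forward subgraph + its gradient subgraph
                self.num_fused_ops = 2

if __name__ == "__main__":
    unittest.main()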