Unverified commit 6b464f96, authored by Yiqun Liu, committed by GitHub

Add an operator node in unittest to make the fusing result unique. (#24617)

Parent 837dd47a
@@ -82,7 +82,7 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
 bool FusionGroupPass::GenerateCode(fusion_group::SubGraph* subgraph) const {
   fusion_group::CodeGenerator code_generator;
   std::string code_str = code_generator.Generate(subgraph);
-  VLOG(3) << code_str;
+  VLOG(4) << code_str;
   // TODO(liuyiqun): supported different places
   platform::CUDAPlace place = platform::CUDAPlace(0);
...
@@ -122,7 +122,7 @@ class FusionGroupPassTestFP64(FusionGroupPassTest):
         self.fused_op_type = "fusion_group"
-class FusionGroupPassTestFP16(FusionGroupPassTest):
+class FusionGroupPassTestCastAndFP16(FusionGroupPassTest):
     def build_program(self, dtype):
         with fluid.program_guard(self.main_program, self.startup_program):
             self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 2)
@@ -132,7 +132,7 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
             # subgraph with 2 op nodes
             tmp_0 = self.feed_vars[0] * self.feed_vars[1]
-            tmp_1 = layers.cast(tmp_0, dtype="float16")
+            tmp_1 = layers.softmax(layers.cast(tmp_0, dtype="float16"))
             tmp_2 = layers.mul(tmp_0, self.feed_vars[2])
             # subgraph with 4 op nodes
             tmp_3 = layers.cast(tmp_2, dtype="float16")
@@ -141,7 +141,7 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
         self.append_gradients(tmp_5)
-        self.num_fused_ops = 3
+        self.num_fused_ops = 4
         self.fetch_list = [tmp_5, self.grad(tmp_0)]
...
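Per the commit title and the diff above, the first subgraph in the renamed FP16 test gains a softmax node after the cast, so the fusion_group pass produces a unique fusing result and the expected num_fused_ops rises from 3 to 4. The snippet below is a minimal, hypothetical standalone sketch of that modified subgraph (elementwise_mul -> cast -> softmax) built with fluid.layers; the variable names and shapes are illustrative assumptions and are not taken from the test file itself.

# Minimal sketch of the modified first subgraph; names/shapes are illustrative
# assumptions, not part of the commit.
import paddle.fluid as fluid
import paddle.fluid.layers as layers

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    x = fluid.data(name="x", shape=[32, 128], dtype="float32")
    y = fluid.data(name="y", shape=[32, 128], dtype="float32")
    w = fluid.data(name="w", shape=[128, 64], dtype="float32")
    tmp_0 = x * y                                                # elementwise_mul
    tmp_1 = layers.softmax(layers.cast(tmp_0, dtype="float16"))  # cast + softmax (the added op node)
    tmp_2 = layers.mul(tmp_0, w)                                 # mul op, outside the elementwise chain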